Pandas get topmost n records within each group

Python, Pandas, Greatest N-per-Group, Window Functions, Top N

Python Problem Overview


Suppose I have a pandas DataFrame like this:

import pandas as pd
df = pd.DataFrame({'id': [1, 1, 1, 2, 2, 2, 2, 3, 4], 'value': [1, 2, 3, 1, 2, 3, 4, 1, 1]})

which looks like:

   id  value
0   1      1
1   1      2
2   1      3
3   2      1
4   2      2
5   2      3
6   2      4
7   3      1
8   4      1

I want to get a new DataFrame with the top 2 records for each id, like this:

   id  value
0   1      1
1   1      2
3   2      1
4   2      2
7   3      1
8   4      1

I can do it by numbering the records within each group after a groupby:

dfN = df.groupby('id').apply(lambda x:x['value'].reset_index()).reset_index()

which looks like:

   id  level_1  index  value
0   1        0      0      1
1   1        1      1      2
2   1        2      2      3
3   2        0      3      1
4   2        1      4      2
5   2        2      5      3
6   2        3      6      4
7   3        0      7      1
8   4        0      8      1

Then, to get the desired output:

dfN[dfN['level_1'] <= 1][['id', 'value']]

Output:

   id  value
0   1      1
1   1      2
3   2      1
4   2      2
7   3      1
8   4      1

But is there a more efficient or elegant way to do this? And is there a more elegant way to number records within each group (like the SQL window function row_number())?
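For the row-numbering part of the question, a minimal sketch: groupby().cumcount() numbers the rows of each group 0, 1, 2, ... in their current order, which behaves like SQL's row_number() minus one. Sorting first mimics an ORDER BY inside the window, and because column assignment aligns on the index, the numbers land back on the original rows. The column names row_number and rank_desc are illustrative only.

```python
import pandas as pd

df = pd.DataFrame({'id': [1, 1, 1, 2, 2, 2, 2, 3, 4],
                   'value': [1, 2, 3, 1, 2, 3, 4, 1, 1]})

# Like ROW_NUMBER() OVER (PARTITION BY id) - 1, using the current row order.
df['row_number'] = df.groupby('id').cumcount()

# Keep the first two rows of each group.
top2 = df[df['row_number'] < 2][['id', 'value']]

# Like ROW_NUMBER() OVER (PARTITION BY id ORDER BY value DESC) - 1:
# sort first (stable, so ties keep their original order), then number;
# the result is assigned back by index label.
df['rank_desc'] = (df.sort_values('value', ascending=False, kind='stable')
                     .groupby('id')
                     .cumcount())
print(top2)
```

The filtered top2 frame reproduces the desired output shown in the question, without the intermediate level_1 and index columns.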

Python Solutions


Solution 1 - Python

Did you try

df.groupby('id').head(2)

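Output generated (by older pandas versions; since pandas 0.14, groupby head acts like a filter and returns the selected rows with their original index, without the extra id level):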

       id  value
id             
1  0   1      1
   1   1      2 
2  3   2      1
   4   2      2
3  7   3      1
4  8   4      1

(Keep in mind that you may need to sort the data first, since head takes the first n rows of each group in their current order.)
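A minimal sketch of sorting before head, assuming "top 2" should mean the two largest values per id rather than the first two rows. The final sort_index is optional and only restores the original row order.

```python
import pandas as pd

df = pd.DataFrame({'id': [1, 1, 1, 2, 2, 2, 2, 3, 4],
                   'value': [1, 2, 3, 1, 2, 3, 4, 1, 1]})

# Sort descending so the largest values come first in each group;
# kind='stable' keeps the original order among tied values.
top2_largest = (df.sort_values('value', ascending=False, kind='stable')
                  .groupby('id')
                  .head(2)          # first 2 rows of each group after sorting
                  .sort_index())    # optional: back to the original row order
print(top2_largest)
```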

EDIT: As mentioned by the questioner, use

df.groupby('id').head(2).reset_index(drop=True)

to remove the MultiIndex and flatten the results:

    id  value
0   1      1
1   1      2
2   2      1
3   2      2
4   3      1
5   4      1

Solution 2 - Python

Since pandas 0.14.1, you can call nlargest and nsmallest on a groupby object:

In [23]: df.groupby('id')['value'].nlargest(2)
Out[23]: 
id   
1   2    3
    1    2
2   6    4
    5    3
3   7    1
4   8    1
dtype: int64

One quirk is that the original index is included in the result as well, which can be useful depending on what your original index was.

If you're not interested in it, you can call .reset_index(level=1, drop=True) to drop it.

(Note: at the time of writing, this worked only with Series and SeriesGroupBy; support for DataFrameGroupBy was expected in 0.17.1.)
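A short sketch of both uses of that extra index level: dropping it to keep only the id level, or using it (level 1 holds the original row labels) to pull the full rows back out of the DataFrame. The variable names are illustrative.

```python
import pandas as pd

df = pd.DataFrame({'id': [1, 1, 1, 2, 2, 2, 2, 3, 4],
                   'value': [1, 2, 3, 1, 2, 3, 4, 1, 1]})

# MultiIndex: level 0 is the group key (id), level 1 the original row label.
top = df.groupby('id')['value'].nlargest(2)

# Drop the original row labels, keeping only id as the index.
values_only = top.reset_index(level=1, drop=True)

# Or use the original row labels to select the complete rows from df.
full_rows = df.loc[top.index.get_level_values(1)]
print(values_only)
print(full_rows)
```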

Solution 3 - Python

Sometimes sorting the whole dataset up front is very time-consuming. Instead, we can group first and take the top k rows within each group:

topk = 2  # number of rows to keep per group
g = df.groupby('id').apply(lambda x: x.nlargest(topk, 'value')).reset_index(drop=True)

Solution 4 - Python

df.groupby('id').apply(lambda x: x.sort_values(by='value', ascending=False).head(2).reset_index(drop=True))
  • Here, sort_values with ascending=False gives results similar to nlargest, and ascending=True gives results similar to nsmallest.
  • The argument to head plays the same role as the n argument to nlargest: the number of rows to keep for each group.
  • The reset_index call is optional.
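The same selection can also be sketched without apply by ranking within each group and filtering; this is a swapped-in technique, not part of the answer above. rank(method='first') gives each row a distinct 1-based position in its group (ties broken by order of appearance), and the ascending flag switches between nlargest-like and nsmallest-like behaviour. The original row order and index are preserved.

```python
import pandas as pd

df = pd.DataFrame({'id': [1, 1, 1, 2, 2, 2, 2, 3, 4],
                   'value': [1, 2, 3, 1, 2, 3, 4, 1, 1]})

n = 2  # rows to keep per group, like the argument to head() or nlargest()

# Largest value in each group gets rank 1 (like nlargest).
rank_desc = df.groupby('id')['value'].rank(method='first', ascending=False)
# Smallest value in each group gets rank 1 (like nsmallest).
rank_asc = df.groupby('id')['value'].rank(method='first', ascending=True)

largest = df[rank_desc <= n]
smallest = df[rank_asc <= n]
print(largest)
print(smallest)
```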

Solution 5 - Python

This works for duplicated values

If there are duplicated values among the top n and you want only unique values, you can do this:

import pandas as pd

ifile = "https://raw.githubusercontent.com/bhishanpdl/Shared/master/data/twitter_employee.tsv"
df = pd.read_csv(ifile,delimiter='\t')
print(df.query("department == 'Audit'")[['id','first_name','last_name','department','salary']])

    id first_name last_name department  salary
24  12   Shandler      Bing      Audit  110000
25  14      Jason       Tom      Audit  100000
26  16     Celine    Anston      Audit  100000
27  15    Michale   Jackson      Audit   70000

If we do not remove duplicates, the top 3 salaries for the Audit department are 110k, 100k, and 100k.
To get distinct salaries for each department, we can do this:

(df.groupby('department')['salary']
 .apply(lambda ser: ser.drop_duplicates().nlargest(3))
 .droplevel(level=1)
 .sort_index()
 .reset_index()
)

This gives

   department  salary
0       Audit  110000
1       Audit  100000
2       Audit   70000
3  Management  250000
4  Management  200000
5  Management  150000
6       Sales  220000
7       Sales  200000
8       Sales  150000
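An alternative sketch without apply: drop duplicate (department, salary) pairs, sort salaries in descending order within each department, then take the first 3 rows per group with head. To keep it self-contained it uses a small frame instead of the remote TSV; the Audit rows come from the table above, while the Sales rows are hypothetical values made up for illustration.

```python
import pandas as pd

# Audit rows are from the table above; the Sales rows are hypothetical.
df = pd.DataFrame({
    'department': ['Audit', 'Audit', 'Audit', 'Audit', 'Sales', 'Sales', 'Sales'],
    'salary': [110000, 100000, 100000, 70000, 90000, 90000, 80000],
})

top3_distinct = (df.drop_duplicates(['department', 'salary'])  # one row per distinct salary
                   .sort_values(['department', 'salary'], ascending=[True, False])
                   .groupby('department')
                   .head(3)                                    # top 3 distinct salaries
                   .reset_index(drop=True))
print(top3_distinct)
```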





Attributions

All content for this solution is sourced from the original question on Stackoverflow.

The content on this page is licensed under the Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license.

Content Type          Original Author
Question              Roman Pekar
Solution 1 - Python   dorvak
Solution 2 - Python   LondonRob
Solution 3 - Python   Chaffee Chen
Solution 4 - Python   Pragatheeswaran
Solution 5 - Python   BhishanPoudel