
In pandas, how to pivot a dataframe on a categorical series with missing categories?

I have a pandas dataframe with a categorical series that has missing categories.

In the example shown below, group has the categories "a", "b", and "c", but there are no cases of "c" in the dataframe.

import pandas as pd
dfr = pd.DataFrame({
    "id": ["111", "222", "111", "333"], 
    "group": ["a", "a", "b", "b"], 
    "value": [1, 4, 9, 16]})
dfr["group"] = pd.Categorical(dfr["group"], categories=["a", "b", "c"])
dfr.pivot(index="id", columns="group")

The resulting pivoted dataframe has columns a and b. I expected a c column as well, containing only missing values.

      value      
group     a     b
id               
111     1.0   9.0
222     4.0   NaN
333     NaN  16.0

How can I pivot a dataframe on a categorical series to include columns with all categories, regardless of whether they were present in the original dataframe?

Asked Dec 01 '21 15:12 by Richie Cotton


2 Answers

pd.pivot_table has a dropna argument that controls whether columns whose entries are all NaN are dropped.

Set it to False to keep them:

import pandas as pd
dfr = pd.DataFrame({
    "id": ["111", "222", "111", "333"], 
    "group": ["a", "a", "b", "b"], 
    "value": [1, 4, 9, 16]})
dfr["group"] = pd.Categorical(dfr["group"], categories=["a", "b", "c"])
pd.pivot_table(dfr, index="id", columns="group", dropna=False)
Answered Nov 01 '22 18:11 by Learning is a mess
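
A minimal sketch of the pivot_table approach above, with two assumptions that the answer does not state: values="value" is passed so the result has flat columns, and observed=False is passed explicitly so that unused categories are kept regardless of the default for that parameter in the installed pandas version. Note that pivot_table aggregates (mean by default), so this relies on the value column being numeric.

```python
import pandas as pd

dfr = pd.DataFrame({
    "id": ["111", "222", "111", "333"],
    "group": ["a", "a", "b", "b"],
    "value": [1, 4, 9, 16],
})
dfr["group"] = pd.Categorical(dfr["group"], categories=["a", "b", "c"])

# observed=False asks the underlying groupby to keep categories with no rows;
# dropna=False keeps the resulting all-NaN column instead of dropping it.
table = pd.pivot_table(
    dfr,
    index="id",
    columns="group",
    values="value",   # assumption: select the single value column explicitly
    dropna=False,
    observed=False,   # assumption: set explicitly rather than relying on the default
)
print(table)
```

The printed table should have one column per category, with the c column entirely missing.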


You can reindex. This works even if your value column is not numeric (unlike pivot_table with its default mean aggregation):

output = dfr.pivot(index="id", columns="group").reindex(
    columns=pd.MultiIndex.from_product([["value"], dfr["group"].cat.categories])
)

>>> output
    value          
        a     b   c
id                 
111   1.0   9.0 NaN
222   4.0   NaN NaN
333   NaN  16.0 NaN
Answered Nov 01 '22 17:11 by not_speshal
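
A possible variation on the reindex approach above, offered as a sketch rather than as the answerer's method: if only one value column is needed, passing values="value" to pivot gives flat columns, so the reindex can use the categories directly and no MultiIndex has to be built. The name wide is just illustrative.

```python
import pandas as pd

dfr = pd.DataFrame({
    "id": ["111", "222", "111", "333"],
    "group": ["a", "a", "b", "b"],
    "value": [1, 4, 9, 16],
})
dfr["group"] = pd.Categorical(dfr["group"], categories=["a", "b", "c"])

# Pivot a single value column so the result has one column level.
wide = dfr.pivot(index="id", columns="group", values="value")

# Reindex the columns against every declared category; categories that
# never occur in the data become columns filled with NaN.
wide = wide.reindex(columns=dfr["group"].cat.categories)
print(wide)
```

Because pivot does no aggregation, this should also work when the value column holds strings or other non-numeric data, as long as each (id, group) pair appears at most once.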