
Cannot understand the behaviour of pandas case_when used on Series with different indexes

I am trying to use Series.case_when in pandas and I do not understand why it behaves as shown below. I have marked the results that look odd to me. It seems to be related to the index of the Series, but why?

import pandas as pd
print(pd.__version__)
# 2.3.0
a = pd.Series([1, 2, 3, 4, 5], index=['a', 'b', 'c', 'd', 'e'], dtype='int')
b = pd.Series([1, 2, 3, 4, 5], index=['A', 'B', 'C', 'D', 'E'], dtype='int')
res = a.case_when(
    [(a.gt(3), 'greater than 3'),
     (a.lt(3), 'less than 3')])
print(res)
# a       less than 3
# b       less than 3
# c                 3
# d    greater than 3
# e    greater than 3
res = a.case_when(
    [(a.gt(3), 'greater than 3'),
     (b.lt(3), 'less than 3')])
print(res)
# a       less than 3
# b       less than 3
# c       less than 3  <- why is this not 3?
# d    greater than 3
# e    greater than 3
res = a.case_when(
    [(b.gt(3), 'greater than 3'),
     (b.lt(3), 'less than 3')])
print(res)
# a    greater than 3 <- why is this not less than 3?
# b    greater than 3 <- why is this not less than 3?
# c    greater than 3 <- why is this not 3?
# d    greater than 3
# e    greater than 3
res = a.case_when(
    [(b.gt(3).to_list(), 'greater than 3'),
     (b.lt(3).to_list(), 'less than 3')])
print(res)
# a       less than 3
# b       less than 3
# c                 3
# d    greater than 3
# e    greater than 3
karpan asked Oct 31 '25 02:10


2 Answers

case_when() uses mask() and the behaviour stems from there.

For example

a.mask(b.lt(3),'Replaced') 

leads to the output

a    Replaced
b    Replaced
c    Replaced
d    Replaced
e    Replaced
dtype: object

The reason for that behaviour is described in the documentation for mask():

The mask method is an application of the if-then idiom. For each element in the calling DataFrame, if cond is False the element is used; otherwise the corresponding element from the DataFrame other is used.

So for each element in a, mask() checks the condition. But if the condition is a Series (as it is in your example), then for each index label in a, mask() tries to find the value with that label in the condition. Here it cannot, because

b.lt(3)

leads to a Series with the same index as b:

A     True
B     True
C    False
D    False
E    False
dtype: bool

The idea behind providing a Series as a condition is exactly this kind of index lookup. You get the behaviour you want here only when you pass the condition as a list, or when a and b have identical indices (for example, when both use the default index).
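To make the difference concrete, here is a minimal sketch that applies the same condition to a in three ways. The variable names by_label, by_position and relabelled are only illustrative, and relabelling with set_axis() assumes that the two Series are meant to correspond row by row.

```python
import pandas as pd

a = pd.Series([1, 2, 3, 4, 5], index=["a", "b", "c", "d", "e"])
b = pd.Series([1, 2, 3, 4, 5], index=["A", "B", "C", "D", "E"])

cond = b.lt(3)

# Series condition with foreign labels: no label of `a` is found in `cond`,
# so every element ends up replaced.
by_label = a.mask(cond, "Replaced")

# Plain list: there is no index to look up, so the condition is applied
# by position.
by_position = a.mask(cond.to_list(), "Replaced")

# Condition relabelled with the index of `a`: the lookup by label now
# finds a value for every element.
relabelled = a.mask(cond.set_axis(a.index), "Replaced")

print(by_label.to_list())
print(by_position.to_list())
print(relabelled.to_list())
```

Only the first call replaces everything; the other two replace just the first two elements.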

However, I do not understand why a failure to find a boolean value for the given index in cond leads to masking/replacement. I would expect the opposite.

Oskar Hofmann answered Nov 02 '25 16:11


Alignment

The issue you're having is caused by alignment. For that reason, I want to briefly explain what alignment is and why it exists.

Before pandas applies the condition to decide what to replace, it aligns the condition to the Series it is masking.

When performing alignment, pandas matches up identically labelled rows in the two Series. This is one of the things that makes it very handy for manipulating unstructured data.

For example, suppose you have two series, with data for elements a, b, and c, but in the wrong order.

import pandas as pd


a = pd.Series([1, 2, 3], index=['a', 'b', 'c'], dtype='int')
b = pd.Series([3, 2, 1], index=['c', 'b', 'a'], dtype='int')

print(a + b)

Rather than adding the elements up by position, it will match up a with a, b with b, and c with c.

a    2
b    4
c    6
dtype: int64

This raises the question: what happens when you attempt alignment and the indices don't match? You can ask for alignment explicitly with the align() method.

import pandas as pd


a = pd.Series([1, 2, 3], index=['a', 'b', 'c'], dtype='int')
b = pd.Series([1], index=['a'], dtype='int')

print(b.align(a)[0])

Output:

a    1.0
b    NaN
c    NaN
dtype: float64

In this case, it fills the missing values with NaN.
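Applied to the Series from the question, the same thing happens to a whole condition. The sketch below aligns b.lt(3) to a; using join="right" to keep only the labels of a is my assumption about the closest manual equivalent of what happens during masking.

```python
import pandas as pd

a = pd.Series([1, 2, 3, 4, 5], index=["a", "b", "c", "d", "e"])
b = pd.Series([1, 2, 3, 4, 5], index=["A", "B", "C", "D", "E"])

# Align the condition to `a`, keeping only the labels of `a`.
# align() returns both aligned objects; only the condition is needed here.
aligned_cond, _ = b.lt(3).align(a, join="right")

print(aligned_cond)
print(aligned_cond.isna().all())
```

None of the labels a to e exist in the condition, so the aligned condition contains no True or False values at all, only missing ones.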

Alignment for case_when()

How does Series.case_when() handle misaligned values? It is implemented in terms of Series.mask(). Here's what the docs say about how it handles those.

The mask method is an application of the if-then idiom. For each element in the calling DataFrame, if cond is False the element is used; otherwise the corresponding element from the DataFrame other is used. If the axis of other does not align with axis of cond Series/DataFrame, the misaligned index positions will be filled with True.

Source.

In other words, Series.mask() replaces any element whose label is missing from the condition. case_when() calls Series.mask() as shown here, where default is the current state of the column, so case_when() treats a missing index label as if that condition matched.

default = default.mask(
    condition, other=replacement, axis=0, inplace=False, level=None
)

Source.

Put differently, the rule that case_when() implements can be thought of like this (see note 1):

  • For each row:
    • Loop through the conditions, stopping on the first True condition.
    • For the first True condition, replace the value with the associated replacement.
    • If no element of the condition matches the index for this row, that condition is treated as True.
    • If no condition is True or missing, then keep the original value.
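The rule above can be written out as a slow, row-by-row helper. This is a sketch for illustration only: case_when_by_rule is a hypothetical function, not part of pandas, and it assumes every condition is a boolean Series.

```python
import pandas as pd


def case_when_by_rule(series, caselist):
    """Row-by-row reading of the rule (hypothetical helper, not pandas code)."""
    out = []
    for label, value in series.items():
        for cond, replacement in caselist:
            # A label that is absent from the condition counts as True.
            if cond.get(label, True):
                value = replacement
                break  # stop at the first condition that is True or missing
        out.append(value)
    return pd.Series(out, index=series.index)


a = pd.Series([1, 2, 3, 4, 5], index=["a", "b", "c", "d", "e"])
b = pd.Series([1, 2, 3, 4, 5], index=["A", "B", "C", "D", "E"])

same_index = case_when_by_rule(
    a, [(a.gt(3), "greater than 3"), (a.lt(3), "less than 3")])
mixed_index = case_when_by_rule(
    a, [(a.gt(3), "greater than 3"), (b.lt(3), "less than 3")])
other_index = case_when_by_rule(
    a, [(b.gt(3), "greater than 3"), (b.lt(3), "less than 3")])

print(same_index.to_list())
# ['less than 3', 'less than 3', 3, 'greater than 3', 'greater than 3']
print(mixed_index.to_list())
# ['less than 3', 'less than 3', 'less than 3', 'greater than 3', 'greater than 3']
print(other_index.to_list())
# ['greater than 3', 'greater than 3', 'greater than 3', 'greater than 3', 'greater than 3']
```

These three results match the first three outputs shown in the question.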

Workaround

It is also worth mentioning that you can opt out of alignment. If the conditions are NumPy arrays, they carry no index, so pandas will not re-align them and applies them by position instead.

Example:

a.case_when(
    [(b.gt(3).values, 'greater than 3'),
     (b.lt(3).values, 'less than 3')])

You are doing something similar in your example with .to_list(). This is another way to opt out of alignment, although it is more expensive than .values.


1: This is not how case_when() is implemented - it actually loops over the conditions in reverse order, and unconditionally makes replacements, rather than stopping at the first match. In cases where multiple replacements are made, this means that the replacement closest to the beginning of the list is the one which actually sticks. It does this so that it can be implemented in a vectorized fashion.
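A rough sketch of that reverse-order loop, written with Series.mask() only. It is a simplification rather than the actual pandas source, and the up-front astype(object) is my own shortcut to make room for the string replacements.

```python
import pandas as pd

a = pd.Series([1, 2, 3, 4, 5], index=["a", "b", "c", "d", "e"])
b = pd.Series([1, 2, 3, 4, 5], index=["A", "B", "C", "D", "E"])

caselist = [(a.gt(3), "greater than 3"),
            (b.lt(3), "less than 3")]

# Shortcut for this sketch: switch to object dtype so strings fit.
result = a.astype(object)

# Apply the conditions last to first. Every mask() call replaces whatever
# its condition selects, so an earlier entry in the list overwrites a later one.
for cond, replacement in reversed(caselist):
    result = result.mask(cond, other=replacement)

print(result.to_list())
# ['less than 3', 'less than 3', 'less than 3', 'greater than 3', 'greater than 3']
```

The misaligned b.lt(3) first replaces every row, and a.gt(3) then overwrites rows d and e, which reproduces the second output in the question.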

Nick ODell answered Nov 02 '25 17:11


