
Cannot understand the behaviour of pandas case_when used on Series with different indexes


I am trying to use the case_when() method of a pandas Series, and I don't understand why it behaves as shown below. I have marked the results that look odd to me. It seems to have something to do with the index of the Series, but why?

import pandas as pd
print(pd.__version__)
# 2.3.0
a = pd.Series([1, 2, 3, 4, 5], index=['a', 'b', 'c', 'd', 'e'], dtype='int')
b = pd.Series([1, 2, 3, 4, 5], index=['A', 'B', 'C', 'D', 'E'], dtype='int')
res = a.case_when(
    [(a.gt(3), 'greater than 3'),
     (a.lt(3), 'less than 3')])
print(res)
# a       less than 3
# b       less than 3
# c                 3
# d    greater than 3
# e    greater than 3
res = a.case_when(
    [(a.gt(3), 'greater than 3'),
     (b.lt(3), 'less than 3')])
print(res)
# a       less than 3
# b       less than 3
# c       less than 3  <- why is this not 3?
# d    greater than 3
# e    greater than 3
res = a.case_when(
    [(b.gt(3), 'greater than 3'),
     (b.lt(3), 'less than 3')])
print(res)
# a    greater than 3 <- why is this not less than 3?
# b    greater than 3 <- why is this not less than 3?
# c    greater than 3 <- why is this not 3?
# d    greater than 3
# e    greater than 3
res = a.case_when(
    [(b.gt(3).to_list(), 'greater than 3'),
     (b.lt(3).to_list(), 'less than 3')])
print(res)
# a       less than 3
# b       less than 3
# c                 3
# d    greater than 3
# e    greater than 3

Solution

  • Alignment

    The issue you're having is caused by alignment. For that reason, I want to briefly explain what alignment is and why it exists.

    Before case_when() applies a condition to decide what to replace, pandas aligns the condition to the Series it is masking.

    When performing alignment, pandas matches up rows that have identical index labels, regardless of their position. This is one of the things that makes it convenient to combine data in pandas whose rows are not in a consistent order.

    For example, suppose you have two Series with data for the elements a, b, and c, but listed in different orders.

    import pandas as pd
    
    
    a = pd.Series([1, 2, 3], index=['a', 'b', 'c'], dtype='int')
    b = pd.Series([3, 2, 1], index=['c', 'b', 'a'], dtype='int')
    
    print(a + b)
    

    Rather than adding the elements by position, pandas matches up a with a, b with b, and c with c:

    a    2
    b    4
    c    6
    dtype: int64
    

    This raises the question: what happens when you attempt alignment and the indices don't match? You can request alignment explicitly with the align() method.

    import pandas as pd
    
    
    a = pd.Series([1, 2, 3], index=['a', 'b', 'c'], dtype='int')
    b = pd.Series([1], index=['a'], dtype='int')
    
    b.align(a)[0]
    

    Output:

    a    1.0
    b    NaN
    c    NaN
    dtype: float64
    

    In this case, pandas fills the labels that b lacks with NaN, and the dtype becomes float64 because int64 cannot hold NaN.
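    Applied to the Series from the question, the same kind of alignment explains the surprising results. The sketch below reindexes the condition b.lt(3) onto a's labels, which is roughly the alignment step that mask() performs internally (treat it as an illustration, not pandas' exact code path). Since labels are case-sensitive, no label of b appears in a, and every aligned value is missing.

```python
import pandas as pd

a = pd.Series([1, 2, 3, 4, 5], index=['a', 'b', 'c', 'd', 'e'])
b = pd.Series([1, 2, 3, 4, 5], index=['A', 'B', 'C', 'D', 'E'])

# The condition is labelled 'A'..'E', while a is labelled 'a'..'e'.
cond = b.lt(3)

# Reindex the condition onto a's labels, as alignment would.
# None of a's labels exist in cond, so every aligned value is NaN.
aligned = cond.reindex(a.index)
print(aligned)
print(aligned.isna().all())  # True
```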

    Alignment for case_when()

    How does Series.case_when() handle misaligned values? It is implemented in terms of Series.mask(), whose documentation describes how misaligned positions are treated:

    The mask method is an application of the if-then idiom. For each element in the calling DataFrame, if cond is False the element is used; otherwise the corresponding element from the DataFrame other is used. If the axis of other does not align with axis of cond Series/DataFrame, the misaligned index positions will be filled with True.

    (Source: pandas documentation for Series.mask().)

    In other words, if an element's label has no counterpart in the condition, Series.mask() replaces that element. case_when() calls Series.mask() as shown below, where default holds the result built up so far, so case_when() treats any label missing from a condition as if that condition matched.

    default = default.mask(
        condition, other=replacement, axis=0, inplace=False, level=None
    )
    

    (Source: pandas source code for Series.case_when().)
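    To see that rule in isolation, the sketch below calls mask() directly with the misaligned condition from the question, and then with the same condition relabelled to a's index via set_axis() (a relabelling chosen here for illustration; case_when() does not do this itself). Only the relabelled version leaves c, d and e alone.

```python
import pandas as pd

a = pd.Series([1, 2, 3, 4, 5], index=['a', 'b', 'c', 'd', 'e'])
b = pd.Series([1, 2, 3, 4, 5], index=['A', 'B', 'C', 'D', 'E'])

# Misaligned: no label of b.lt(3) exists in a, so mask() treats every
# position as a match and replaces all five elements.
misaligned = a.mask(b.lt(3), 'less than 3')
print(misaligned.tolist())
# ['less than 3', 'less than 3', 'less than 3', 'less than 3', 'less than 3']

# Relabelled: give the condition a's labels first, so only the first two
# positions (where b holds 1 and 2) are replaced.
aligned = a.mask(b.lt(3).set_axis(a.index), 'less than 3')
print(aligned.tolist())
# ['less than 3', 'less than 3', 3, 4, 5]
```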

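    In other words, the rule that case_when() implements can be thought of like this: check the conditions in order, and for each element use the replacement from the first condition that is either True for that element's label or has no entry for that label at all; if no condition applies, keep the original value (see footnote 1).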
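    As a concrete version of that mental model, here is a hypothetical helper, case_when_first_match (it is not part of pandas), that walks each label in order and uses the first condition that is either True for that label or has no entry for it. It is deliberately element-by-element and slow; it only exists to reproduce the outputs from the question.

```python
import pandas as pd

def case_when_first_match(s, caselist):
    """Hypothetical helper: first-match model of case_when() with alignment."""
    out = []
    for label, value in s.items():
        result = value  # keep the original value if no case applies
        for cond, replacement in caselist:
            # A label missing from the condition behaves like a match.
            if label not in cond.index or bool(cond[label]):
                result = replacement
                break  # stop at the first matching case
        out.append(result)
    return pd.Series(out, index=s.index, dtype=object)

a = pd.Series([1, 2, 3, 4, 5], index=['a', 'b', 'c', 'd', 'e'])
b = pd.Series([1, 2, 3, 4, 5], index=['A', 'B', 'C', 'D', 'E'])

print(case_when_first_match(a, [(a.gt(3), 'greater than 3'),
                                (b.lt(3), 'less than 3')]).tolist())
# ['less than 3', 'less than 3', 'less than 3', 'greater than 3', 'greater than 3']
```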

    Workaround

    It is also worth mentioning that you can opt out of alignment. NumPy arrays carry no index, so if the conditions are NumPy arrays, pandas cannot align them and applies them by position instead.

    Example:

    a.case_when(
        [(b.gt(3).values, 'greater than 3'),
         (b.lt(3).values, 'less than 3')])
    

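    You are doing something similar in your example with .to_list(). This is another way to opt out of alignment, although it is more expensive than .values because it first builds a Python list. (The pandas documentation recommends .to_numpy() over .values.)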


    1: This is not how case_when() is actually implemented. It loops over the conditions in reverse order and applies every replacement unconditionally, rather than stopping at the first match. When several conditions match the same element, the replacement closest to the beginning of the list is applied last, so it is the one that sticks. Working this way lets each step be a single vectorized mask() call.
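    The reverse-order loop described in this footnote can be mimicked with a few mask() calls. The helper name reverse_mask_case_when is hypothetical, and this is a simplified sketch (it converts to object dtype up front instead of computing a common dtype as pandas does), but it shows why the earliest case wins when conditions overlap.

```python
import pandas as pd

def reverse_mask_case_when(s, caselist):
    """Hypothetical sketch of the reverse-order mask() loop."""
    result = s.astype(object)
    # Walk the cases from last to first; each mask() replaces every matching
    # element, so the first case in the list is applied last and wins.
    for cond, replacement in reversed(caselist):
        result = result.mask(cond, replacement)
    return result

a = pd.Series([1, 2, 3, 4, 5], index=['a', 'b', 'c', 'd', 'e'])

# 'd' and 'e' match both cases; the earlier case should win.
overlapping = [(a.gt(1), 'greater than 1'), (a.gt(3), 'greater than 3')]
print(reverse_mask_case_when(a, overlapping).tolist())
# [1, 'greater than 1', 'greater than 1', 'greater than 1', 'greater than 1']
```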