python expression series python-polars

Evaluate expression inside custom class in polars


I am trying to extend the functionality of polars to manipulate the categories of an Enum. I am following this guide and this section of the documentation.

# Sample frame: an integer index plus an Enum-typed string column.
orig_df = pl.DataFrame(
    {
        'idx': pl.int_range(5, eager=True),
        'orig_series': pl.Series(
            ['Alpha', 'Omega', 'Alpha', 'Beta', 'Gamma'],
            dtype=pl.Enum(['Alpha', 'Beta', 'Gamma', 'Omega']),
        ),
    }
)

@pl.api.register_expr_namespace('fct')
class CustomEnumMethodsCollection:
    """Custom ``fct`` expression namespace (used as ``pl.col(...).fct``)."""

    def __init__(self, expr: pl.Expr):
        # Keep the wrapped expression so namespace methods can build on it.
        self._expr = expr
    
    def rev(self) -> pl.Expr:
        """Attempt to cast the expression to an Enum with changed categories.

        NOTE: this version raises ``TypeError`` when called, because
        ``cats`` is still an unevaluated expression when it is handed to
        ``pl.Enum``.
        """
        # `cats` is an expression, not concrete category values.
        cats = self._expr.cat.get_categories()
        tmp_sr = self._expr.cast(pl.Categorical)
        # NOTE(review): `.str.reverse()` reverses the characters of each
        # string; reversing the category *order* would be `.reverse()` --
        # confirm which was intended.
        return tmp_sr.cast(dtype=pl.Enum(cats.str.reverse()))

# Apply the custom `fct.rev()` namespace method to the Enum column.
orig_df.with_columns(rev_series=pl.col("orig_series").fct.rev())

This errors with TypeError: Series constructor called with unsupported type 'Expr' for the values parameter, because cats is an unevaluated expression rather than the list or series that pl.Enum(categories=...) expects. How do I evaluate cats into an actual list/series so that I can provide the new categories to my cast(pl.Enum) call?


Solution

  • You can use .map_batches()

    @pl.api.register_expr_namespace('fct')
    class CustomEnumMethodsCollection:
        """Custom ``fct`` expression namespace with Enum helpers."""

        def __init__(self, expr: pl.Expr):
            self._expr = expr

        def rev(self) -> pl.Expr:
            """Cast the column to an Enum with its category order reversed."""

            def _reverse_categories(s: pl.Series) -> pl.Series:
                # map_batches hands over a materialized Series, so the
                # categories can be read eagerly and passed to pl.Enum.
                reversed_cats = s.cat.get_categories().reverse()
                return s.cast(pl.Enum(reversed_cats))

            return self._expr.map_batches(_reverse_categories)
    
    # Sample frame: an integer index plus an Enum-typed string column.
    df = pl.DataFrame(
        {
            'idx': pl.int_range(5, eager=True),
            'orig_series': pl.Series(
                ['Alpha', 'Omega', 'Alpha', 'Beta', 'Gamma'],
                dtype=pl.Enum(['Alpha', 'Beta', 'Gamma', 'Omega']),
            ),
        }
    )

    # Inspect the schema after adding the reversed-category column.
    df.with_columns(rev_series=pl.col('orig_series').fct.rev()).schema
    
    Schema([('idx', Int64),
            ('orig_series', Enum(categories=['Alpha', 'Beta', 'Gamma', 'Omega'])),
            ('rev_series', Enum(categories=['Omega', 'Gamma', 'Beta', 'Alpha']))])