Problem converting Pandas-based ATR function to be Polars-based


I have just started converting some stock-trading Python code from pandas to Polars, mainly for the performance gain. I am a complete newbie to Polars and not much better at Python, but I get by. The following is my pandas-based ATR function (with pandas imported as pa):

def ATR(df: pa.DataFrame, window_size: int = 14) -> pa.DataFrame:
    high, low, prev_close = df['high'], df['low'], df['close'].shift()
    tr_all = [high - low, high - prev_close, low - prev_close]
    tr_all = [tr.abs() for tr in tr_all]

    tr = pa.concat(tr_all, axis = 1).max(axis = 1)
    df['ATR'] = tr.ewm(alpha = 1/window_size, min_periods = window_size, adjust = False, ignore_na = True).mean()
    return df

I call this function like this:

raw_dd_df = ATR(raw_dd_df, window_size = slowline_period)

and it produces results similar to the following:

           date     open     high      low  ...  volume     vwap      mid       ATR
300  2024-01-01  1.27300  1.27330  1.26936  ...       0  1.27189  1.27300       NaN
299  2024-01-02  1.27291  1.27597  1.26105  ...  242445  1.26779  1.26707       NaN
298  2024-01-03  1.26123  1.26765  1.26123  ...  296035  1.26414  1.26384       NaN
297  2024-01-04  1.26644  1.27295  1.26565  ...  270883  1.26830  1.26730       NaN
296  2024-01-05  1.26816  1.27710  1.26113  ...  333038  1.26949  1.26987       NaN
..          ...      ...      ...      ...  ...     ...      ...      ...       ...
4    2024-12-26  1.25296  1.25474  1.25005  ...  245898  1.25241  1.25242  0.009308
3    2024-12-27  1.25187  1.25925  1.25046  ...  234639  1.25464  1.25443  0.009282
2    2024-12-29  1.25756  1.25756  1.25756  ...       0  1.25756  1.25756  0.008846
1    2024-12-30  1.25726  1.26070  1.25059  ...  243089  1.25587  1.25610  0.008910
0    2024-12-31  1.25493  1.25688  1.25048  ...  227765  1.25336  1.25303  0.008784

The following is my attempt to rewrite the function to use Polars:

def pl_ATR(df: pl.DataFrame, window_size: int = 14) -> pl.DataFrame:
    high, low, prev_close = df['High'], df['Low'], df['Close'].shift()
    tr_all = [high - low, high - prev_close, low - prev_close]
    tr_all = [tr.abs() for tr in tr_all]

    tr = pl.concat(tr_all, rechunk = True).max()
    df['ATR'] = pl.Expr.ewm_mean(tr, alpha = 1/window_size, min_samples = window_size, adjust = False, ignore_nulls = True)
    return df

I am not sure whether I have got tr = pl.concat(tr_all, rechunk = True).max() right, but I get an error on the next line (the call to pl.Expr.ewm_mean):

Traceback (most recent call last):
  File "/home/stuart/Projects/Python/Trading/Scratches/Polars/demo_polars.py", line 48, in <module>
    df = pl_ATR(df, window_size = 20)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/stuart/Projects/Python/Trading/Scratches/Polars/demo_polars.py", line 37, in pl_ATR
    df['ATR'] = pl.Expr.ewm_mean(tr, alpha = 1/window_size, min_samples = window_size, adjust = False, ignore_nulls = True)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/stuart/Projects/Python/Env/lib/python3.12/site-packages/polars/_utils/deprecation.py", line 119, in wrapper
    return function(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/stuart/Projects/Python/Env/lib/python3.12/site-packages/polars/expr/expr.py", line 9488, in ewm_mean
    return self._from_pyexpr(
           ^^^^^^^^^^^^^^^^^
AttributeError: 'float' object has no attribute '_from_pyexpr'

I have spent hours trying to get it working, but to no avail. Can anybody help please?

Regards, Stuart


Solution

  • Let's start with some data

    from datetime import date
    
    import pandas as pd
    import polars as pl
    
    data = {
        "date": pl.date_range(date(2025, 4, 1), date(2025, 4, 6), eager=True),
        "high": [102, 104, 103, 107, 110, 112],
        "low": [98, 99, 100, 103, 106, 108],
        "close": [101, 103, 102, 106, 109, 111],
    }
    
    df_pl = pl.DataFrame(data)
    df_pd = df_pl.to_pandas()
    
    # shape: (6, 4)
    # ┌────────────┬──────┬─────┬───────┐
    # │ date       ┆ high ┆ low ┆ close │
    # │ ---        ┆ ---  ┆ --- ┆ ---   │
    # │ date       ┆ i64  ┆ i64 ┆ i64   │
    # ╞════════════╪══════╪═════╪═══════╡
    # │ 2025-04-01 ┆ 102  ┆ 98  ┆ 101   │
    # │ 2025-04-02 ┆ 104  ┆ 99  ┆ 103   │
    # │ 2025-04-03 ┆ 103  ┆ 100 ┆ 102   │
    # │ 2025-04-04 ┆ 107  ┆ 103 ┆ 106   │
    # │ 2025-04-05 ┆ 110  ┆ 106 ┆ 109   │
    # │ 2025-04-06 ┆ 112  ┆ 108 ┆ 111   │
    # └────────────┴──────┴─────┴───────┘
    

    Calling your pandas ATR function gives us an output to check against. I use a window size of 3 for simplicity.

    ATR(df_pd, window_size=3)
    #         date  high  low  close       ATR
    # 0 2025-04-01   102   98    101       NaN
    # 1 2025-04-02   104   99    103       NaN
    # 2 2025-04-03   103  100    102  3.888889
    # 3 2025-04-04   107  103    106  4.259259
    # 4 2025-04-05   110  106    109  4.172840
    # 5 2025-04-06   112  108    111  4.115226
    

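    With Polars, when the docs say polars.Expr.ewm_mean, that means you call the ewm_mean method on an expression (e.g., my_expr.ewm_mean(...)) rather than calling pl.Expr.ewm_mean(...) directly. In your attempt, pl.concat(tr_all, rechunk = True) stacks the three Series vertically and .max() reduces the result to a single float; that float is then passed as self to pl.Expr.ewm_mean, hence the AttributeError. Expressions can be created with pl.col("some_column") and are also returned by computations between expressions (e.g., pl.col("high") - pl.col("low")). In practice, that means you need your "true range" calculation as an expression. Read the expressions and contexts section of the user guide for further details.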

    As a general rule in Polars, prefer expressions to accessing columns as Series (df["some_column"]).

    With that said, here is a Polars solution. It is not dissimilar to your pandas code; it just uses expressions.

    def pl_ATR(
        high: str | pl.Expr = "high",
        low: str | pl.Expr = "low",
        close: str | pl.Expr = "close",
        *,
        window_size: int = 14,
    ) -> pl.Expr:
        # If caller passed in strings for column names, convert them to expressions
        # Calling `pl.col("column_name")` gives us a reference to a column
        # and is of type expression (pl.Expr)
        # This isn't required, just makes things nice for the caller
        # and allows different column names to be passed
        if isinstance(high, str): high = pl.col(high)
        if isinstance(low, str): low = pl.col(low)
        if isinstance(close, str): close = pl.col(close)
    
        # Define the previous close as an expression, given that we will reuse it
        prev_close = close.shift()
    
        # `max_horizontal` is like `max(axis=1)` in pandas - it operates horizontally
        # In Polars, horizontal operations generally have a dedicated function
        # rather than an `axis` parameter
        true_range = pl.max_horizontal(
            high - low,  # abs() isn't needed here as this will never be negative
          (high - prev_close).abs(),
          (low - prev_close).abs(),
        )
        # We also return an expression, that can be evaluated in a context
        # (read the user guide link if confused)
        return true_range.ewm_mean(
            alpha=1 / window_size,
            min_samples=window_size,
            adjust=False,
            ignore_nulls=True
        )
    

    Now using this function, we can see it matches the pandas output!

    df_pl.with_columns(ATR=pl_ATR(window_size=3))
    # shape: (6, 5)
    # ┌────────────┬──────┬─────┬───────┬──────────┐
    # │ date       ┆ high ┆ low ┆ close ┆ ATR      │
    # │ ---        ┆ ---  ┆ --- ┆ ---   ┆ ---      │
    # │ date       ┆ i64  ┆ i64 ┆ i64   ┆ f64      │
    # ╞════════════╪══════╪═════╪═══════╪══════════╡
    # │ 2025-04-01 ┆ 102  ┆ 98  ┆ 101   ┆ null     │
    # │ 2025-04-02 ┆ 104  ┆ 99  ┆ 103   ┆ null     │
    # │ 2025-04-03 ┆ 103  ┆ 100 ┆ 102   ┆ 3.888889 │
    # │ 2025-04-04 ┆ 107  ┆ 103 ┆ 106   ┆ 4.259259 │
    # │ 2025-04-05 ┆ 110  ┆ 106 ┆ 109   ┆ 4.17284  │
    # │ 2025-04-06 ┆ 112  ┆ 108 ┆ 111   ┆ 4.115226 │
    # └────────────┴──────┴─────┴───────┴──────────┘
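
    To see where these numbers come from, the recursion can be reproduced by hand. This is a sketch in plain Python, not what Polars or pandas run internally, and the helper name wilder_atr is made up for illustration. With adjust=False the smoothing is y[t] = (1 - alpha) * y[t-1] + alpha * x[t], seeded with the first true range. The first row falls back to high - low because both max(axis=1) and max_horizontal skip the missing previous close:

```python
def wilder_atr(high, low, close, window_size):
    """Hand-rolled ATR for checking results; hypothetical helper, not a library API."""
    alpha = 1 / window_size
    atr, smoothed = [], None
    for i in range(len(high)):
        candidates = [high[i] - low[i]]
        if i > 0:  # the first row has no previous close
            candidates += [abs(high[i] - close[i - 1]), abs(low[i] - close[i - 1])]
        true_range = max(candidates)
        # Recursive smoothing, equivalent to an EWM mean with adjust=False
        if smoothed is None:
            smoothed = true_range
        else:
            smoothed = (1 - alpha) * smoothed + alpha * true_range
        # Mimic min_samples / min_periods: hide values until enough rows are seen
        atr.append(smoothed if i + 1 >= window_size else None)
    return atr


high = [102, 104, 103, 107, 110, 112]
low = [98, 99, 100, 103, 106, 108]
close = [101, 103, 102, 106, 109, 111]

atr = wilder_atr(high, low, close, window_size=3)
print([None if v is None else round(v, 6) for v in atr])
# [None, None, 3.888889, 4.259259, 4.17284, 4.115226]
```

    The true ranges for the six rows are 4, 5, 3, 5, 4, 4, so the third value is (2/3) * ((2/3) * 4 + (1/3) * 5) + (1/3) * 3 = 3.888889, matching both outputs above.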