I would like to code a logger for polars using the Custom Namespace API.
For instance, starting from:
import logging
import polars as pl
penguins_pl = pl.read_csv("https://raw.githubusercontent.com/mwaskom/seaborn-data/master/penguins.csv")
my objective is to have
penguins_pl.mine.startlog().filter(pl.col("species")=="Adelie").mine.endlog()
log "192 rows have been removed".
The plan is to have startlog save the shape of the dataframe in a temporary variable and then reuse that in endlog.
I have tried this:
@pl.api.register_dataframe_namespace("mine")
class MinePolarsDataframeUtils:
    """Custom "mine" namespace: log the row/column delta between startlog() and endlog()."""

    def __init__(self, df: pl.DataFrame):
        # NOTE(review): polars constructs a *fresh* namespace instance on every
        # `df.mine` access, so attributes set here do not survive between calls —
        # this is the root cause of the bug described below.
        self._df = df

    def startlog(self) -> pl.DataFrame:
        # Remember the current shape so endlog() can diff against it.
        # BUG: this attribute lives on a throwaway instance (see __init__),
        # so the later `df.mine.endlog()` call never sees it.
        self._shape = self._df.shape
        return(self._df)

    def endlog(self) -> pl.DataFrame:
        # BUG: on a fresh instance `self._shape` was never assigned, so this
        # line raises AttributeError instead of the intended ValueError.
        if not self._shape:
            raise ValueError("startlog() must be called before endlog()")
        # old-minus-new: positive values actually mean rows/cols *removed*.
        dr = self._shape[0] - self._df.shape[0]
        dc = self._shape[1] - self._df.shape[1]
        logging.getLogger("polars_logger").info(f"Rows added {dr}, cols added {dc}")
        self._shape = None
        return(self._df)
But it doesn't work: polars creates a fresh MinePolarsDataframeUtils instance every time the namespace is accessed, so the class is re-initialized both when startlog and when endlog are called.
That is, when endlog is called, the class starts from scratch, and the value of self._shape saved by startlog is not carried over — self._shape is undefined and an AttributeError is raised.
How can I keep custom variables between calls when extending polars?
Related: Logging operation results in pandas (equivalent of STATA/tidylog)
It may be overkill for some use cases, but I'd recommend using context managers + contextvars:
import contextvars
import logging
import polars as pl
# Default handler/format on the root logger so records are actually emitted.
logging.basicConfig()
logger = logging.getLogger("polars_logging")
logger.setLevel(logging.INFO)  # make the .info() calls below visible
class LogCtx:
    """Context manager holding the DataFrame snapshot used by the `log` namespace.

    On entry it installs itself into the `logging_context` ContextVar; on exit
    it restores the previous value and drops its DataFrame reference.
    """

    def __init__(self) -> None:
        # Annotations are quoted so the class is importable without polars
        # being evaluated at instance-construction time.
        self._df: "pl.DataFrame | None" = None
        self._token: "contextvars.Token[LogCtx] | None" = None

    def __enter__(self) -> "LogCtx":
        self._token = logging_context.set(self)
        # Return self so `with LogCtx() as ctx:` binds the context object —
        # the original returned None, which silently breaks `as` usage.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        if self._token is not None:
            logging_context.reset(self._token)
            self._token = None  # token is single-use; forget it
        self._df = None  # release the snapshot so the frame can be collected


# If you want to enforce having a custom context:
logging_context: contextvars.ContextVar[LogCtx] = contextvars.ContextVar("logging_context")
# If you want to use a global context by default:
# logging_context: contextvars.ContextVar[LogCtx] = contextvars.ContextVar("logging_context", default=LogCtx())
@pl.api.register_dataframe_namespace("log")
class LogCtxUtils:
    """`df.log` namespace: diff the frame's shape against the active LogCtx snapshot."""

    def __init__(self, df: pl.DataFrame):
        self._df = df

    def start(self) -> pl.DataFrame:
        """Snapshot the current frame into the active context, then pass it through."""
        # Raises LookupError when used outside a `with LogCtx():` block
        # (the ContextVar above is declared with no default).
        ctx = logging_context.get()
        ctx._df = self._df
        return self._df

    def step(self) -> pl.DataFrame:
        """Log the row/column delta since the last start()/step(), then re-snapshot."""
        ctx = logging_context.get()
        if ctx._df is None:
            # Explicit raise survives `python -O`, unlike `assert`.
            raise RuntimeError("Must log.start() before using log.step()")
        # Delta is current-minus-snapshot so positive values mean "added";
        # the original subtracted the other way round, inverting the sign.
        dr = self._df.shape[0] - ctx._df.shape[0]
        dc = self._df.shape[1] - ctx._df.shape[1]
        logger.info(f"Rows added {dr}, cols added {dc}")
        ctx._df = self._df  # re-baseline for the next step()
        return self._df
# Example: while the LogCtx is active, each .log.step() reports the shape
# change since the previous start()/step() call on the pipeline.
df = pl.DataFrame({"x": range(5)})
with LogCtx():
    df.log.start().with_columns(y="x").log.step().unpivot().log.step()  # pyright: ignore[reportAttributeAccessIssue]
Alternatively you could avoid contextvars and pass it directly to the methods each time, e.g. def start(self, ctx) + ctx = LogCtx() + df.log.start(ctx)