Is it possible to get the invalid rows of a Pandera-Polars schema validation?
Other questions about PanderaPolars:
I have started exploring Pandera, using Polars dataframe.
I wrote a simple schema example below:
import pandera.polars as pa
import warnings
import polars as pl
from pandera.polars import PolarsData
def is_not_null(data: PolarsData) -> pl.LazyFrame:
    """Return a LazyFrame with one boolean column: True where the value is non-null."""
    non_null_mask = pl.col(data.key).is_not_null()
    return data.lazyframe.select(non_null_mask)
def mean_less_than(data: PolarsData, threshold=0) -> pl.LazyFrame:
    """Return a single-row LazyFrame that is True iff the column mean is below *threshold*."""
    column_mean = pl.col(data.key).mean()
    return data.lazyframe.select(column_mean < threshold)
# Load the CSV; infer_schema_length is raised so dtypes are inferred from
# more rows than the polars default.
df = pl.read_csv('file.csv', separator=',', quote_char='"', infer_schema_length=100000)

# Schema with warning-only checks: raise_warning=True turns each failed check
# into a warning instead of a SchemaError, so validation never raises.
schema = pa.DataFrameSchema({
    "ISIN": pa.Column(
        str,
        checks=[
            pa.Check.str_matches(r'^\w{12}$', ignore_na=True, raise_warning=True,
                                 title="ISIN pattern check",
                                 description="ISIN has to be 12 characters long"),
            pa.Check(is_not_null, ignore_na=True, raise_warning=True,
                     title="ISIN null check", description="ISIN cannot be null"),
        ],
        nullable=True,
    ),
    "ID": pa.Column(
        str,
        checks=[
            pa.Check.str_matches(r'^\w{9}$', ignore_na=True, raise_warning=True,
                                 title="ID pattern check",
                                 description="ID has to be 9 characters long"),
            pa.Check(is_not_null, ignore_na=True, raise_warning=True,
                     title="ID null check", description="ID cannot be null"),
        ],
        nullable=True,
    ),
    "PRICE": pa.Column(
        float,
        checks=[
            pa.Check.greater_than_or_equal_to(0, ignore_na=True, raise_warning=True),
            pa.Check.less_than_or_equal_to(10000, ignore_na=True, raise_warning=True),
            pa.Check.in_range(min_value=0, max_value=10000, include_min=True,
                              include_max=True, ignore_na=True, raise_warning=True),
            # FIX: the check passes when the mean is BELOW the given threshold
            # (97.5 here), so the description must not claim "less than 0".
            pa.Check(mean_less_than, threshold=97.5, ignore_na=True, raise_warning=True,
                     title="PRICE MEAN check",
                     description="MEAN(PRICE) has to be less than the threshold",
                     error="mean above threshold"),
        ],
        nullable=True,
    ),
})

# Record the warnings emitted by the failing checks during validation.
with warnings.catch_warnings(record=True) as caught_warnings:
    warnings.simplefilter("always")
    # FIX: the schema defined above is named `schema`;
    # `schema_with_custom_checks` was an undefined name (NameError).
    validated_df = df.pipe(schema.validate)

for warning in caught_warnings:
    print(warning)
While I was able to run it using raise_warning=True and catch the warning messages, I was not able to return the rows that fail the checks.
I know that in Pandas I managed to do it using some changes (raise_warning=False and capturing the schema errors):
import pandera as pa
import pandas as pd
import warnings

# Load the raw CSV with pandas.
df = pd.read_csv('file.csv')

# Reusable checks. raise_warning=False means a failing check raises an error
# (all failures are collected into SchemaErrors when validating with lazy=True).
check_notnull = pa.Check(lambda s: s.notnull(), ignore_na=False, raise_warning=False)
check_isin = pa.Check.str_matches(r'^\w{12}$', ignore_na=True,raise_warning=False)
check_id = pa.Check.str_matches(r'^\w{9}$', ignore_na=True,raise_warning=False)
check_price_range = pa.Check.in_range(min_value=0, max_value=10000, include_min=True, include_max=True, ignore_na=True,raise_warning=False)
check_price_average = pa.Check(lambda s: s.mean()>100, ignore_na=True, raise_warning=False, error="mean below threshold")

schema = pa.DataFrameSchema({
    "ISIN": pa.Column(str, [check_notnull, check_isin], nullable=True),
    "ID": pa.Column(str, [check_notnull, check_id], nullable=True),
    "PRICE": pa.Column(float, [check_notnull, check_price_range, check_price_average], nullable=True)
})

try:
    # lazy=True collects every failure instead of stopping at the first one.
    schema(df, lazy=True)
except pa.errors.SchemaErrors as exc:
    # exc.failure_cases is a pandas DataFrame whose "index" column holds the
    # row labels of the failing values, so df.index.isin(...) selects the bad rows.
    filtered_df = df[df.index.isin(exc.failure_cases["index"])]
    print(f"filtered df:\n{filtered_df}")
I tried to replicate the same approach from pandas using Polars, as follows:
import pandera.polars as pa
import warnings
import polars as pl
from pandera.polars import PolarsData
def is_not_null(data: PolarsData) -> pl.LazyFrame:
    """Return a LazyFrame with a single boolean column."""
    # data.key is the name of the column currently being validated.
    return data.lazyframe.select(pl.col(data.key).is_not_null())
def mean_less_than(data: PolarsData, threshold=0) -> pl.LazyFrame:
    """Return a LazyFrame with a single boolean column."""
    # A single-row result: True iff the column mean is below the threshold.
    return data.lazyframe.select(pl.col(data.key).mean() < threshold)
# Same schema as above, but with raise_warning=False so that failing checks
# raise pa.errors.SchemaErrors (collected lazily) instead of emitting warnings.
df = pl.read_csv('file.csv', separator=',', quote_char='"', infer_schema_length = 100000)
schema = pa.DataFrameSchema({
    "ISIN": pa.Column(
        str,
        checks=[
            pa.Check.str_matches(r'^\w{12}$', ignore_na=True,raise_warning=False, title="ISIN pattern check", description="ISIN has to be 12 characters long"),
            pa.Check(is_not_null, ignore_na=True,raise_warning=False, title="ISIN null check", description="ISIN cannot be null")
        ], nullable = True
    ),
    "ID": pa.Column(
        str,
        checks=[
            pa.Check.str_matches(r'^\w{9}$', ignore_na=True,raise_warning=False, title="ID pattern check", description="ID has to be 9 characters long"),
            pa.Check(is_not_null, ignore_na=True,raise_warning=False, title="ID null check", description="ID cannot be null")
        ], nullable = True
    ),
    "PRICE": pa.Column(
        float,
        checks=[
            pa.Check.greater_than_or_equal_to(0, ignore_na=True,raise_warning=False),
            pa.Check.less_than_or_equal_to(10000, ignore_na=True,raise_warning=False),
            pa.Check.in_range(min_value=0, max_value=10000, include_min=True, include_max=True, ignore_na=True,raise_warning=False),
            pa.Check(mean_less_than, threshold=97.5, ignore_na=True,raise_warning=False, title="PRICE MEAN check", description="MEAN(PRICE) has to be less than 0", error="mean above threshold")
        ], nullable = True
    )
})
try:
    # lazy=True collects every failure into a single SchemaErrors exception.
    schema(df, lazy=True)
except pa.errors.SchemaErrors as exc:
    # NOTE: this line is pandas syntax; a Polars DataFrame has no .index
    # attribute, which causes the AttributeError described below.
    filtered_df = df[df.index.isin(exc.failure_cases["index"])]
    print(f"filtered df:\n{filtered_df}")
but I got AttributeError: 'DataFrame' object has no attribute 'index'
My question is if it's possible to return the invalid rows using PanderaPolars.
The error is expected as you are using Pandas syntax:
df[df.index.isin(exc.failure_cases["index"])]
`exc.failure_cases` appears to be a Polars DataFrame:
>>> type(exc.failure_cases)
<class 'polars.dataframe.frame.DataFrame'>
`index` is a column of type `i32` that contains (potentially) duplicates and nulls.
>>> exc.failure_cases
shape: (7, 6)
┌──────────────┬────────────────┬────────┬─────────────────────────┬──────────────┬───────┐
│ failure_case ┆ schema_context ┆ column ┆ check ┆ check_number ┆ index │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ str ┆ str ┆ str ┆ str ┆ i32 ┆ i32 │
╞══════════════╪════════════════╪════════╪═════════════════════════╪══════════════╪═══════╡
│ Int64 ┆ Column ┆ ISIN ┆ dtype('String') ┆ null ┆ null │
│ 1 ┆ Column ┆ ISIN ┆ str_matches('^\w{12}$') ┆ 0 ┆ 0 │
│ 4 ┆ Column ┆ ISIN ┆ str_matches('^\w{12}$') ┆ 0 ┆ 1 │
│ Int64 ┆ Column ┆ ID ┆ dtype('String') ┆ null ┆ null │
│ 2 ┆ Column ┆ ID ┆ str_matches('^\w{9}$') ┆ 0 ┆ 0 │
│ 5 ┆ Column ┆ ID ┆ str_matches('^\w{9}$') ┆ 0 ┆ 1 │
│ Int64 ┆ Column ┆ PRICE ┆ dtype('Float64') ┆ null ┆ null │
└──────────────┴────────────────┴────────┴─────────────────────────┴──────────────┴───────┘
The question then becomes: how to find the corresponding row numbers in `df`?
A direct translation of the Pandas code is probably:
filtered_df = df.filter(pl.int_range(pl.len()).is_in(exc.failure_cases["index"]))
But there are a few different ways you could write it.