Tags: python, python-polars

Polars pl.read_csv_batched -> batch_size is not respected at all


When I call pl.read_csv_batched() with batch_size=n, the batches come back with no regard for batch_size whatsoever. I am using Polars 1.29.0.

What's going on here? Can I use Polars to import a large CSV file in fixed-size batches without going the manual route? Why does batch_size not constrain the batch sizes at all?
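
Roughly what I'm doing (the file name and sizes here are just placeholders):

    import polars as pl

    reader = pl.read_csv_batched("big.csv", batch_size=10_000)
    batches = reader.next_batches(5)
    # I would expect five DataFrames of ~10_000 rows each, but the heights
    # of the returned frames bear no relation to batch_size.
    print([df.height for df in batches])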


Solution

  • You've found a known bug.

    Here's a workaround class that wraps pl.read_csv_batched and re-slices its output into batches of exactly batch_size rows; use it in place of pl.read_csv_batched until the bug is fixed.

    from math import ceil

    import polars as pl


    class BatchedPolarsCSV:
        """Wraps pl.read_csv_batched and re-slices its output into batches of
        exactly `batch_size` rows (only the final batch may be shorter)."""

        def __init__(self, source: str, batch_size: int, **kwargs):
            # batch_size is handled by this wrapper and deliberately not forwarded:
            # pl.read_csv_batched does not respect it, so we enforce it when slicing.
            self._reader = pl.read_csv_batched(source, **kwargs)
            self._batch_size = batch_size
            self._native_batch_size = None  # batch size the reader actually produces
            self._cached_frames: list[pl.DataFrame] = []  # rows read but not yet returned
            self._exhausted = False

        def _have_rows(self) -> int:
            return sum(df.height for df in self._cached_frames)

        def _batches_to_get(self, target_rows: int) -> int:
            # Estimate how many native batches are needed to reach target_rows.
            need_rows = target_rows - self._have_rows()
            if self._native_batch_size:
                return ceil(need_rows / self._native_batch_size)
            return 1  # no estimate yet: read one batch to learn the native size

        def next_batches(self, n: int) -> list[pl.DataFrame] | None:
            if self._exhausted:
                return None
            target_rows = n * self._batch_size
            # Pull native batches until we have enough rows or the file is drained.
            while self._have_rows() < target_rows:
                chunk = self._reader.next_batches(self._batches_to_get(target_rows))
                if chunk is None:
                    break
                self._native_batch_size = chunk[0].height
                self._cached_frames.extend(chunk)
            if not self._cached_frames:
                self._exhausted = True
                return None
            df = pl.concat(self._cached_frames)
            return_part = df.slice(0, target_rows)
            cache_part = df.slice(target_rows)  # overshoot is kept for the next call
            self._cached_frames = [cache_part] if cache_part.height > 0 else []
            # Split the returned rows into DataFrames of batch_size rows each.
            key = "__partition_by_col"
            return return_part.with_columns(
                (pl.int_range(pl.len()) // self._batch_size).alias(key)
            ).partition_by(key, include_key=False)
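
    A minimal usage sketch (the file name and sizes are placeholders): each call to next_batches(n) returns up to n DataFrames of exactly batch_size rows, with a shorter batch only at the very end of the file.

    reader = BatchedPolarsCSV("big.csv", batch_size=10_000)
    while (batches := reader.next_batches(5)) is not None:
        for df in batches:
            ...  # process a DataFrame of at most 10_000 rows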