python dask

dask map_partitions strange behaviour


When I create a dask dataframe from pandas with 1 partition and then call map_partitions() on it, my function seems to be called twice. With 5 partitions, it is called 6 times. In general, the function is called 1 extra time, with some records I don't recognize. Those records don't show up in the final result.

The extra call seems to come first, with a partition I don't recognize that has 2 records. This causes other problems, but I won't mention them for now to keep the description succinct.

Details:

python 3.9.18
dask 2024.8.0
pandas 2.0.3

import pandas as pd
import dask.dataframe as dd

df = pd.DataFrame({
    'a': list(range(100))
})
ddf = dd.from_pandas(df, npartitions=1)

def some_func(df):
    print(df.shape)
    print(df.head())
    return df

ddf = ddf.map_partitions(some_func)
print(ddf.compute().shape)

Output:

(2, 1)
  a
0 1
1 1
(100, 1)
  a
0 0
1 1
2 2
3 3
4 4
(100, 1)

How can I avoid the extra call to the function with the records I don't recognize?


Solution

  • The answer can be found in the map_partitions documentation:

    By default, dask tries to infer the output metadata by running your provided function on some fake data. This works well in many cases, but can sometimes be expensive, or even fail. To avoid this, you can manually specify the output metadata with the meta keyword.

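    Since you don't provide the meta kwarg, Dask tries to infer the output columns and dtypes by calling the function once on fake data. That is the extra call you see: the 2-row frame is the fake data, not one of your partitions, which is why it never appears in the result. Passing meta explicitly skips that call.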