apache-spark, pyspark

How to properly checkpoint a dataframe in PySpark


Suppose I'm reading a very (VERY) large table stored at PATH. After filtering the table and selecting a couple of columns to make it compatible with df2, I join the result with df2 on a newly created column ('id'). This join is super expensive, and afterwards I run a ton of aggregations in a groupBy statement (also a really expensive stage), so I'd like to break the logical plan right after the join.

What's the correct way to checkpoint the query to break the logical plan after joining both tables together?

Option 1: Mid-query

df = (
    spark.read.parquet(PATH)
    .where(FILTERS)
    .select(COLUMNS)
    .join(
        other=df2,
        on='id',
        how='inner'
    )
    .checkpoint()
    .groupBy('id')
    .agg(AGGREGATES)
)

Option 2: Separate queries

df = (
    spark.read.parquet(PATH)
    .where(FILTERS)
    .select(COLUMNS)
    .join(
        other=df2,
        on='id',
        how='inner'
    )
    .checkpoint()
)

# Resume in separate query
df = (
    df
    .groupBy('id')
    .agg(AGGREGATES)
)

Solution

  • The two are equivalent: in both options the groupBy operates on a checkpointed dataframe, so pick whichever reads better. The second option is more flexible, though, because you can reuse the checkpointed dataframe for other work as well, e.g.:

    # Reuse the checkpointed dataframe in multiple downstream queries
    grouped1 = (
        df
        .groupBy('id')
        .agg(AGGREGATES)
    )
    grouped2 = (
        df
        .groupBy('id2')
        .agg(AGGREGATES)
    )