Recently I encountered an issue while running one of our PySpark jobs. Analyzing the stages in the Spark UI, I noticed that the longest-running stage takes 1.2 hours out of the total 2.5 hours the entire process takes to run.
A look at the stage details made it clear that I was facing severe data skew: a single task ran for the entire 1.2 hours while all the other tasks finished within 23 seconds.
The DAG showed that this stage involves window functions, which helped me quickly narrow the problematic area down to a few queries and find the root cause: the account column used in Window.partitionBy("account") had 25% null values.
I have no interest in calculating the sum for the null accounts, but I do need those rows for further calculations, so I can't filter them out before applying the window function.
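To see the imbalance for yourself, a grouped count over the partition column is enough. The following is just a sketch of that kind of check; it only assumes the sales_df DataFrame and the account column used in the query below:
from pyspark.sql.functions import col, count, desc
# rows per account - a skewed key (here, null) will dominate the top of the list
(sales_df
    .groupBy("account")
    .agg(count("*").alias("rows_per_account"))
    .orderBy(desc("rows_per_account"))
    .show(10, truncate=False))
# share of rows with a null account (roughly 25% in my case)
total_rows = sales_df.count()
null_rows = sales_df.where(col("account").isNull()).count()
print(f"null account ratio: {null_rows / total_rows:.2%}")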
Here is my window function query:
from pyspark.sql.functions import col, sum
from pyspark.sql.window import Window
problematic_account_window = Window.partitionBy("account")
sales_with_account_total_df = sales_df.withColumn("sum_sales_per_account", sum(col("price")).over(problematic_account_window))
So we found the culprit. What can we do now? How can we resolve the skew and the performance issue?
We basically have two solutions for this issue:
Solution 1:
account_window = Window.partitionBy("account")
# split to null and non null
non_null_accounts_df = sales_df.where(col("account").isNotNull())
only_null_accounts_df = sales_df.where(col("account").isNull())
# calculate the sum for the non null
sales_with_non_null_accounts_df = non_null_accounts_df.withColumn("sum_sales_per_account", sum(col("price")).over(account_window))
# union the calculated result with the null-account rows to get the final result
sales_with_account_total_df = sales_with_non_null_accounts_df.unionByName(only_null_accounts_df, allowMissingColumns=True)
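A quick sanity check for this approach (just a sketch, assuming the DataFrames defined above): the row count should be preserved, and thanks to allowMissingColumns=True the null-account rows simply carry a null sum_sales_per_account:
# no rows should be lost by the split/union
assert sales_with_account_total_df.count() == sales_df.count()
# null-account rows keep their data but have no account total
(sales_with_account_total_df
    .where(col("account").isNull())
    .select("account", "price", "sum_sales_per_account")
    .show(5, truncate=False))
The downside, as noted below, is the extra DataFrames and the split/union bookkeeping just for this one calculation.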
Solution 2:
from pyspark.sql.functions import ceil, coalesce, rand
SPARK_SHUFFLE_PARTITIONS = int(spark.conf.get("spark.sql.shuffle.partitions"))
modified_sales_df = (sales_df
    # create a random partition value that spans the number of shuffle partitions
    .withColumn("random_salt_partition", ceil(rand() * SPARK_SHUFFLE_PARTITIONS))
    # use the random partition value only when the account value is null
    .withColumn("salted_account", coalesce(col("account"), col("random_salt_partition")))
)
# modify the partition to use the salted account
salted_account_window = Window.partitionBy("salted_account")
# use the salted account window to calculate the sum of sales
sales_with_account_total_df = modified_sales_df.withColumn("sum_sales_per_account", sum(col("price")).over(salted_account_window))
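One caveat of the salting approach: the null-account rows now get a sum per salt bucket rather than no sum at all. Since those sums aren't needed anyway, an optional cleanup (again just a sketch, assuming only the column names above) can blank them out and drop the helper columns:
from pyspark.sql.functions import lit, when
sales_with_account_total_df = (sales_with_account_total_df
    # the per-bucket sums for null accounts are meaningless, so null them out
    .withColumn("sum_sales_per_account",
                when(col("account").isNull(), lit(None)).otherwise(col("sum_sales_per_account")))
    # the salting columns were only needed for the window partitioning
    .drop("random_salt_partition", "salted_account")
)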
Eventually I decided to go with solution 2, since it didn't force me to create more DataFrames just for the sake of the calculation, and here is the result:
As seen above, the salting technique resolved the skew. The exact same stage now runs in a total of 5.5 minutes instead of 1.2 hours, and the only modification in the code was the salting column in the partitionBy. The comparison is based on the exact same cluster, node count, and cluster configuration.