apache-spark pyspark

Performance Degradation with mapInPandas in Spark 3.5.*


After upgrading to Spark 3.5.*, I noticed a significant performance degradation when using mapInPandas for computationally intensive tasks, in this case computing SHAP values in parallel. Performance was consistent across Spark 3.1 through 3.4, but execution time increased substantially on 3.5.

Minimal Reproducible Example

I've created a minimal reproducible example to isolate the issue as much as I could. Below are the execution times per SHAP iteration using this code:

Model   Size (MB)   Spark 3.4.4 (s/it)   Spark 3.5.0 (s/it)
lgb-s   20          1                    5
lgb-m   52          2.5                  13
lgb-l   110         5                    40

As shown, execution time has increased by approximately 5-8x after upgrading to Spark 3.5.

import time
import os
import sys
import findspark
import pandas as pd
import shap
import lightgbm as lgb
import requests
from typing import Iterable
from sklearn.model_selection import train_test_split

findspark.init()
os.environ["PYSPARK_PYTHON"] = sys.executable

import pyspark.sql
import pyspark.sql.types as T


def explain(df, model, background_data):
    def compute_shap(iterable: Iterable[pd.DataFrame]) -> Iterable[pd.DataFrame]:
        for i, batch in enumerate(iterable):
            if i > 0:
                break

            explainer.shap_values(batch, silent=False)
            yield pd.DataFrame(columns=["dummy"])

    explainer = shap.KernelExplainer(
        model=model.predict,
        data=background_data,
        keep_index=True,
        link="identity",
    )

    print("Computing shap values")
    t1 = time.time()

    schema = T.StructType([T.StructField("dummy", T.IntegerType())])
    shap_values = df.mapInPandas(compute_shap, schema=schema)
    shap_values.collect()

    t2 = time.time()
    print(f"Elapsed time: {round(t2 - t1, 2)} seconds")


conf = pyspark.SparkConf().setAppName("bug")
# Set maxRecordsPerBatch to 1 since we are interested in a single iteration
conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "1")
spark = pyspark.sql.SparkSession.builder.config(conf=conf).getOrCreate()

# NOTE: set size to train lgb model with different number of estimators
# s: n_estimators=1000, m: n_estimators=2500, l: n_estimators=5000
size = "s"

# Download the dataset if it doesn't exist
url = "https://raw.githubusercontent.com/saul-chirinos/Electricity-demand-forecasting-in-Panama/master/Data/continuous%20dataset.csv"
filename = "panama.csv"

if not os.path.isfile(filename):
    response = requests.get(url)
    response.raise_for_status()
    with open(filename, "wb") as file:
        file.write(response.content)

# Load data
data = pd.read_csv(filename).drop(columns=["datetime", "QV2M_san", "T2M_san", "T2M_toc"])
X, y = data.drop(columns=["nat_demand"]), data["nat_demand"]
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Train model
params = {"n_estimators": 1000 if size == "s" else 2500 if size == "m" else 5000, "num_leaves": 256}
train, test = lgb.Dataset(X_train, label=y_train), lgb.Dataset(X_test, label=y_test)
predictor = lgb.train(params=params, train_set=train, valid_sets=[test])
predictor.save_model(f"lgb-{size}.txt")

# NOTE: use this for multiple runs to avoid retraining
# 
# Load model
# predictor = lgb.Booster(model_file=f"lgb-{size}.txt")

print(f"lgb-{size}: {os.path.getsize(f'lgb-{size}.txt') / (1024 * 1024):.2f} MB")

# Select samples for background data and to be explained
background_data = X_train.iloc[:10]
df = spark.createDataFrame(X_test.iloc[:100]).coalesce(1)

print(f"{pyspark.__version__=}")
explain(df, predictor, background_data)

What I Tried

Questions

Updates

  1. Tested with Spark 4.0.0-preview2 – the issue persists
  2. Finally identified the root cause: since Spark 3.5, SHAP computations inside mapInPandas are using only one core, whereas previously all cores were utilized. The question is: why?
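One way to check how many threads a Python worker is allowed to use is to inspect its environment from inside the mapInPandas function. The sketch below is an illustration only: the helper name thread_limits and the list of environment variables are my own assumptions, not part of the original code, and which variable actually matters depends on the native library doing the work (OMP_NUM_THREADS is the OpenMP one).

```python
import os

# Environment variables commonly used to cap native thread pools.
# Assumption: one of these is what limits the worker; OMP_NUM_THREADS is the OpenMP one.
THREAD_ENV_VARS = ("OMP_NUM_THREADS", "OPENBLAS_NUM_THREADS", "MKL_NUM_THREADS")


def thread_limits():
    """Return the CPU count and thread-related env vars seen by this process."""
    report = {"cpu_count": os.cpu_count()}
    for name in THREAD_ENV_VARS:
        report[name] = os.environ.get(name)  # None means "not set"
    return report


# Hypothetical use inside compute_shap, before calling explainer.shap_values:
#     print(thread_limits())
```

Comparing the reported values between Spark 3.4 and 3.5 should show whether the worker environment differs between the two versions.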

Explanation 🧐

I've finally figured it out!

There was nothing wrong with Spark 3.5 itself. The faster behavior I saw on earlier versions was actually the result of a bug (SPARK-42613), which was fixed in version 3.5.0. As far as I understand it, that fix makes the Python worker default its OpenMP thread count (OMP_NUM_THREADS) to spark.task.cpus rather than to the executor's cores. Since I wasn't explicitly setting spark.task.cpus, it defaulted to 1, which in turn limited the number of cores available to SHAP. After setting spark.task.cpus to the number of cores, the execution time per iteration matched the expected performance!
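Setting spark.task.cpus as described above could look roughly like this. It is a minimal sketch, not the exact code I ran: the helper names task_cpu_settings and build_session are made up for illustration, and using os.cpu_count() as "the number of cores" assumes a local, single-machine session like the one in the example.

```python
import os


def task_cpu_settings(num_cores=None):
    """Build the Spark setting that gives each task num_cores CPUs.

    Falls back to the machine's CPU count (assumption: local mode).
    """
    n = num_cores or os.cpu_count() or 1
    return {"spark.task.cpus": str(n)}


def build_session(app_name="bug", num_cores=None):
    """Create a SparkSession with spark.task.cpus set explicitly."""
    # Imported lazily so task_cpu_settings stays usable without pyspark installed.
    import pyspark
    import pyspark.sql

    conf = pyspark.SparkConf().setAppName(app_name)
    conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "1")
    for key, value in task_cpu_settings(num_cores).items():
        conf.set(key, value)
    return pyspark.sql.SparkSession.builder.config(conf=conf).getOrCreate()
```

Note the trade-off: with spark.task.cpus equal to all available cores, Spark runs only one task at a time per executor. That is fine here because the DataFrame is coalesced to a single partition anyway, but the value should not exceed the cores available to an executor.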

Thanks!

