python, apache-spark, pyspark, pyarrow

PySpark ArrayType usage in transformWithStateInPandas state causes java.lang.IllegalArgumentException


I have the following Python code that uses PySpark to mock a credit card fraud detection system:

from pyspark.sql import SparkSession
from pyspark.sql.functions import from_json, col, unix_timestamp
from pyspark.sql.types import (
    StructType,
    StructField,
    StringType,
    DoubleType,
    TimestampType,
    ArrayType,
    LongType,
)
from pyspark.sql.streaming import StatefulProcessor, StatefulProcessorHandle

import pandas as pd


class MySP(StatefulProcessor):
    def init(self, handle: StatefulProcessorHandle):
        self._state = handle.getValueState(
            "state",
            StructType([
                StructField("locations", ArrayType(StringType())),
                StructField("timestamps", ArrayType(LongType())),
            ])
        )

    def handleInputRows(self, key, rows, timerValues):
        if not self._state.exists():
            current_state = {"locations": [], "timestamps": []}
        else:
            current_state = {"locations": self._state.get()[0], "timestamps": self._state.get()[1]}

        new_locations = []
        new_timestamps = []
        for pdf in rows:
            new_locations.extend(pdf["location"].tolist())
            new_timestamps.extend(pdf["unix_timestamp"].tolist())

        current_state["locations"].extend(new_locations)
        current_state["timestamps"].extend(new_timestamps)

        self._state.update((current_state["locations"], current_state["timestamps"]))

        yield pd.DataFrame()


def main():
    spark = (
        SparkSession.builder.appName("RealTimeFraudDetector")
        .config(
            "spark.sql.streaming.stateStore.providerClass",
            "org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider",
        )
        .getOrCreate()
    )
    spark.sparkContext.setLogLevel("WARN")

    schema = StructType(
        [
            StructField("transaction_id", StringType(), True),
            StructField("card_number", StringType(), True),
            StructField("card_holder", StringType(), True),
            StructField("amount", DoubleType(), True),
            StructField("currency", StringType(), True),
            StructField("location", StringType(), True),
            StructField("timestamp", TimestampType(), True),
        ]
    )

    output_schema = StructType(
        [
            StructField("card_number", StringType(), True),
            StructField("is_fraud", StringType(), True),
            StructField("message", StringType(), True),
        ]
    )

    kafka_df = (
        spark.readStream.format("kafka")
        .option("kafka.bootstrap.servers", "broker:29092")
        .option("subscribe", "transaction")
        .load()
    )

    transaction_df = (
        kafka_df.select(from_json(col("value").cast("string"), schema).alias("data"))
        .select("data.*")
        .withColumn("unix_timestamp", unix_timestamp(col("timestamp")))
    )

    filtered_df = (
        transaction_df.withWatermark("timestamp", "10 minutes")
        .groupBy("card_number")
        .transformWithStateInPandas(
            MySP(), outputStructType=output_schema, outputMode="append", timeMode="None"
        )
    )

    query = filtered_df.writeStream.outputMode("append").format("console").start()

    query.awaitTermination()


if __name__ == "__main__":
    main()

The first batch is processed fine, but as soon as the second batch starts processing I get the following error:

consumer  | 25/09/05 07:27:40 WARN TaskSetManager: Lost task 128.0 in stage 5.0 (TID 530) (987409f72424 executor driver): TaskKilled (Stage cancelled: Job aborted due to stage failure: Task 120 in stage 5.0 failed 1 times, most recent failure: Lost task 120.0 in stage 5.0 (TID 522) (987409f72424 executor driver): org.apache.spark.api.python.PythonException: Traceback (most recent call last):
consumer  |   File "/opt/bitnami/spark/python/lib/pyspark.zip/pyspark/worker.py", line 2044, in main
consumer  |     process()
consumer  |   File "/opt/bitnami/spark/python/lib/pyspark.zip/pyspark/worker.py", line 2036, in process
consumer  |     serializer.dump_stream(out_iter, outfile)
consumer  |   File "/opt/bitnami/spark/python/lib/pyspark.zip/pyspark/sql/pandas/serializers.py", line 1236, in dump_stream
consumer  |     super().dump_stream(flatten_iterator(), stream)
consumer  |   File "/opt/bitnami/spark/python/lib/pyspark.zip/pyspark/sql/pandas/serializers.py", line 564, in dump_stream
consumer  |     return ArrowStreamSerializer.dump_stream(self, init_stream_yield_batches(), stream)
consumer  |            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
consumer  |   File "/opt/bitnami/spark/python/lib/pyspark.zip/pyspark/sql/pandas/serializers.py", line 115, in dump_stream
consumer  |     for batch in iterator:
consumer  |                  ^^^^^^^^
consumer  |   File "/opt/bitnami/spark/python/lib/pyspark.zip/pyspark/sql/pandas/serializers.py", line 557, in init_stream_yield_batches
consumer  |     for series in iterator:
consumer  |                   ^^^^^^^^
consumer  |   File "/opt/bitnami/spark/python/lib/pyspark.zip/pyspark/sql/pandas/serializers.py", line 1233, in flatten_iterator
consumer  |     for pdf in iter_pdf:
consumer  |                ^^^^^^^^
consumer  |   File "/app/detector.py", line 41, in handleInputRows
consumer  |     current_state = {"locations": self._state.get()[0], "timestamps": self._state.get()[1]}
consumer  |                                   ^^^^^^^^^^^^^^^^^
consumer  |   File "/opt/bitnami/spark/python/lib/pyspark.zip/pyspark/sql/streaming/stateful_processor.py", line 62, in get
consumer  |     return self._valueStateClient.get(self._stateName)
consumer  |            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
consumer  |   File "/opt/bitnami/spark/python/lib/pyspark.zip/pyspark/sql/streaming/value_state_client.py", line 78, in get
consumer  |     raise PySparkRuntimeError(f"Error getting value state: " f"{response_message[1]}")
consumer  | pyspark.errors.exceptions.base.PySparkRuntimeError: Error getting value state: couldn't introspect javabean: java.lang.IllegalArgumentException: wrong number of arguments
consumer  | 
consumer  |     at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:581)
consumer  |     at org.apache.spark.sql.execution.python.PythonArrowOutput$$anon$1.read(PythonArrowOutput.scala:117)
consumer  |     at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:532)
consumer  |     at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
consumer  |     at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:601)
consumer  |     at scala.collection.Iterator$$anon$9.hasNext(Iterator.scala:583)
consumer  |     at org.apache.spark.util.CompletionIterator.hasNext(CompletionIterator.scala:31)
consumer  |     at scala.collection.Iterator$$anon$9.hasNext(Iterator.scala:583)
consumer  |     at org.apache.spark.sql.execution.datasources.v2.WritingSparkTask$IteratorWithMetrics.hasNext(WriteToDataSourceV2Exec.scala:545)
consumer  |     at org.apache.spark.sql.connector.write.DataWriter.writeAll(DataWriter.java:107)
consumer  |     at org.apache.spark.sql.execution.streaming.sources.PackedRowDataWriter.writeAll(PackedRowWriterFactory.scala:53)
consumer  |     at org.apache.spark.sql.execution.datasources.v2.DataWritingSparkTask$.write(WriteToDataSourceV2Exec.scala:587)
consumer  |     at org.apache.spark.sql.execution.datasources.v2.WritingSparkTask.$anonfun$run$5(WriteToDataSourceV2Exec.scala:483)
consumer  |     at org.apache.spark.util.Utils$.tryWithSafeFinallyAndFailureCallbacks(Utils.scala:1323)
consumer  |     at org.apache.spark.sql.execution.datasources.v2.WritingSparkTask.run(WriteToDataSourceV2Exec.scala:535)
consumer  |     at org.apache.spark.sql.execution.datasources.v2.WritingSparkTask.run$(WriteToDataSourceV2Exec.scala:466)
consumer  |     at org.apache.spark.sql.execution.datasources.v2.DataWritingSparkTask$.run(WriteToDataSourceV2Exec.scala:584)
consumer  |     at org.apache.spark.sql.execution.datasources.v2.V2TableWriteExec.$anonfun$writeWithV2$2(WriteToDataSourceV2Exec.scala:427)
consumer  |     at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:93)
consumer  |     at org.apache.spark.TaskContext.runTaskWithListeners(TaskContext.scala:171)
consumer  |     at org.apache.spark.scheduler.Task.run(Task.scala:147)
consumer  |     at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$5(Executor.scala:647)
consumer  |     at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally(SparkErrorUtils.scala:80)
consumer  |     at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally$(SparkErrorUtils.scala:77)
consumer  |     at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:99)
consumer  |     at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:650)
consumer  |     at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(Unknown Source)
consumer  |     at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(Unknown Source)
consumer  |     at java.base/java.lang.Thread.run(Unknown Source)

There must be an error in how I save the data into the state, since it seems that PyArrow cannot deserialize it, but according to this article it should be possible to store ArrayType of LongType and StringType in the state. What could be the cause? I'm pretty new to PySpark (I'm building this project to learn it), and I've been stuck on this for days, trying multiple solutions to no avail.
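
For reference, the value I pass to self._state.update() for a given card_number is just a tuple of two plain Python lists, something like this (made-up sample values):

state_value = (
    ["Berlin", "Madrid"],        # locations accumulated for this card
    [1757050000, 1757050060],    # corresponding unix timestamps
)
self._state.update(state_value)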


Solution

  • I've resorted to implementing a workaround using two list states:

    class MySP(StatefulProcessor):
        def init(self, handle: StatefulProcessorHandle):
            list_timestamp_schema = StructType([StructField("timestamp", LongType(), True)])
            list_location_schema = StructType([StructField("location", StringType(), True)])
            self._timestamp_state = handle.getListState(stateName="timestampState", schema=list_timestamp_schema)
            self._location_state = handle.getListState(stateName="locationState", schema=list_location_schema)
    

    This way I can save and load the state without any deserialization errors.
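
    Roughly, handleInputRows then appends to and reads back these two list states as shown below. This is only a sketch based on the ListState appendList()/get() API, with the actual fraud checks omitted:

        def handleInputRows(self, key, rows, timerValues):
            new_locations = []
            new_timestamps = []
            for pdf in rows:
                new_locations.extend(pdf["location"].tolist())
                new_timestamps.extend(pdf["unix_timestamp"].tolist())

            # Each list state stores single-column rows, so wrap every value in a 1-tuple
            self._location_state.appendList([(loc,) for loc in new_locations])
            self._timestamp_state.appendList([(ts,) for ts in new_timestamps])

            # get() yields the stored rows; unwrap them back into plain Python lists
            all_locations = [row[0] for row in self._location_state.get()]
            all_timestamps = [row[0] for row in self._timestamp_state.get()]

            # ... fraud checks over all_locations / all_timestamps go here ...

            yield pd.DataFrame()

    Keeping each state to a single scalar column avoids nesting ArrayType inside the value state schema, which seems to be what triggers the introspection error above.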