I have the following python code that uses PySpark to mock a fraud detection system for credit cards:
from pyspark.sql import SparkSession
from pyspark.sql.functions import from_json, col, unix_timestamp
from pyspark.sql.types import (
    StructType,
    StructField,
    StringType,
    DoubleType,
    TimestampType,
    ArrayType,
    LongType,
)
from pyspark.sql.streaming import StatefulProcessor, StatefulProcessorHandle
import pandas as pd
class MySP(StatefulProcessor):
    def init(self, handle: StatefulProcessorHandle):
        # Value state holding every location and event timestamp seen so far for a card.
        self._state = handle.getValueState(
            "state",
            StructType([
                StructField("locations", ArrayType(StringType())),
                StructField("timestamps", ArrayType(LongType())),
            ]),
        )

    def handleInputRows(self, key, rows, timerValues):
        # Load the existing state for this card, or start with empty lists.
        if not self._state.exists():
            current_state = {"locations": [], "timestamps": []}
        else:
            current_state = {"locations": self._state.get()[0], "timestamps": self._state.get()[1]}
        new_locations = []
        new_timestamps = []
        for pdf in rows:
            new_locations.extend(pdf["location"].tolist())
            new_timestamps.extend(pdf["unix_timestamp"].tolist())
        current_state["locations"].extend(new_locations)
        current_state["timestamps"].extend(new_timestamps)
        # Persist the updated lists back into the value state.
        self._state.update((current_state["locations"], current_state["timestamps"]))
        yield pd.DataFrame()
def main():
    spark = (
        SparkSession.builder.appName("RealTimeFraudDetector")
        .config(
            "spark.sql.streaming.stateStore.providerClass",
            "org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider",
        )
        .getOrCreate()
    )
    spark.sparkContext.setLogLevel("WARN")

    # Schema of the incoming transaction JSON messages.
    schema = StructType(
        [
            StructField("transaction_id", StringType(), True),
            StructField("card_number", StringType(), True),
            StructField("card_holder", StringType(), True),
            StructField("amount", DoubleType(), True),
            StructField("currency", StringType(), True),
            StructField("location", StringType(), True),
            StructField("timestamp", TimestampType(), True),
        ]
    )

    # Schema of the rows emitted by the stateful processor.
    output_schema = StructType(
        [
            StructField("card_number", StringType(), True),
            StructField("is_fraud", StringType(), True),
            StructField("message", StringType(), True),
        ]
    )

    kafka_df = (
        spark.readStream.format("kafka")
        .option("kafka.bootstrap.servers", "broker:29092")
        .option("subscribe", "transaction")
        .load()
    )

    transaction_df = (
        kafka_df.select(from_json(col("value").cast("string"), schema).alias("data"))
        .select("data.*")
        .withColumn("unix_timestamp", unix_timestamp(col("timestamp")))
    )

    filtered_df = (
        transaction_df.withWatermark("timestamp", "10 minutes")
        .groupBy("card_number")
        .transformWithStateInPandas(
            MySP(), outputStructType=output_schema, outputMode="append", timeMode="None"
        )
    )

    query = filtered_df.writeStream.outputMode("append").format("console").start()
    query.awaitTermination()


if __name__ == "__main__":
    main()
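For reference, the JSON payloads I parse look roughly like this. This is a made-up example (the values are invented), but the field names follow the schema above:

# Hypothetical example of one message on the "transaction" topic (values are made up).
# The producer serializes a dict like this to JSON; from_json() parses it with the schema above.
sample_transaction = {
    "transaction_id": "tx-0001",
    "card_number": "4111111111111111",
    "card_holder": "Jane Doe",
    "amount": 129.99,
    "currency": "EUR",
    "location": "Berlin",
    "timestamp": "2025-09-05 07:27:00",
}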
The first batch is processed fine, but as soon as the second batch starts processing, I get the following error:
consumer | 25/09/05 07:27:40 WARN TaskSetManager: Lost task 128.0 in stage 5.0 (TID 530) (987409f72424 executor driver): TaskKilled (Stage cancelled: Job aborted due to stage failure: Task 120 in stage 5.0 failed 1 times, most recent failure: Lost task 120.0 in stage 5.0 (TID 522) (987409f72424 executor driver): org.apache.spark.api.python.PythonException: Traceback (most recent call last):
consumer | File "/opt/bitnami/spark/python/lib/pyspark.zip/pyspark/worker.py", line 2044, in main
consumer | process()
consumer | File "/opt/bitnami/spark/python/lib/pyspark.zip/pyspark/worker.py", line 2036, in process
consumer | serializer.dump_stream(out_iter, outfile)
consumer | File "/opt/bitnami/spark/python/lib/pyspark.zip/pyspark/sql/pandas/serializers.py", line 1236, in dump_stream
consumer | super().dump_stream(flatten_iterator(), stream)
consumer | File "/opt/bitnami/spark/python/lib/pyspark.zip/pyspark/sql/pandas/serializers.py", line 564, in dump_stream
consumer | return ArrowStreamSerializer.dump_stream(self, init_stream_yield_batches(), stream)
consumer | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
consumer | File "/opt/bitnami/spark/python/lib/pyspark.zip/pyspark/sql/pandas/serializers.py", line 115, in dump_stream
consumer | for batch in iterator:
consumer | ^^^^^^^^
consumer | File "/opt/bitnami/spark/python/lib/pyspark.zip/pyspark/sql/pandas/serializers.py", line 557, in init_stream_yield_batches
consumer | for series in iterator:
consumer | ^^^^^^^^
consumer | File "/opt/bitnami/spark/python/lib/pyspark.zip/pyspark/sql/pandas/serializers.py", line 1233, in flatten_iterator
consumer | for pdf in iter_pdf:
consumer | ^^^^^^^^
consumer | File "/app/detector.py", line 41, in handleInputRows
consumer | current_state = {"locations": self._state.get()[0], "timestamps": self._state.get()[1]}
consumer | ^^^^^^^^^^^^^^^^^
consumer | File "/opt/bitnami/spark/python/lib/pyspark.zip/pyspark/sql/streaming/stateful_processor.py", line 62, in get
consumer | return self._valueStateClient.get(self._stateName)
consumer | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
consumer | File "/opt/bitnami/spark/python/lib/pyspark.zip/pyspark/sql/streaming/value_state_client.py", line 78, in get
consumer | raise PySparkRuntimeError(f"Error getting value state: " f"{response_message[1]}")
consumer | pyspark.errors.exceptions.base.PySparkRuntimeError: Error getting value state: couldn't introspect javabean: java.lang.IllegalArgumentException: wrong number of arguments
consumer |
consumer | at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:581)
consumer | at org.apache.spark.sql.execution.python.PythonArrowOutput$$anon$1.read(PythonArrowOutput.scala:117)
consumer | at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:532)
consumer | at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
consumer | at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:601)
consumer | at scala.collection.Iterator$$anon$9.hasNext(Iterator.scala:583)
consumer | at org.apache.spark.util.CompletionIterator.hasNext(CompletionIterator.scala:31)
consumer | at scala.collection.Iterator$$anon$9.hasNext(Iterator.scala:583)
consumer | at org.apache.spark.sql.execution.datasources.v2.WritingSparkTask$IteratorWithMetrics.hasNext(WriteToDataSourceV2Exec.scala:545)
consumer | at org.apache.spark.sql.connector.write.DataWriter.writeAll(DataWriter.java:107)
consumer | at org.apache.spark.sql.execution.streaming.sources.PackedRowDataWriter.writeAll(PackedRowWriterFactory.scala:53)
consumer | at org.apache.spark.sql.execution.datasources.v2.DataWritingSparkTask$.write(WriteToDataSourceV2Exec.scala:587)
consumer | at org.apache.spark.sql.execution.datasources.v2.WritingSparkTask.$anonfun$run$5(WriteToDataSourceV2Exec.scala:483)
consumer | at org.apache.spark.util.Utils$.tryWithSafeFinallyAndFailureCallbacks(Utils.scala:1323)
consumer | at org.apache.spark.sql.execution.datasources.v2.WritingSparkTask.run(WriteToDataSourceV2Exec.scala:535)
consumer | at org.apache.spark.sql.execution.datasources.v2.WritingSparkTask.run$(WriteToDataSourceV2Exec.scala:466)
consumer | at org.apache.spark.sql.execution.datasources.v2.DataWritingSparkTask$.run(WriteToDataSourceV2Exec.scala:584)
consumer | at org.apache.spark.sql.execution.datasources.v2.V2TableWriteExec.$anonfun$writeWithV2$2(WriteToDataSourceV2Exec.scala:427)
consumer | at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:93)
consumer | at org.apache.spark.TaskContext.runTaskWithListeners(TaskContext.scala:171)
consumer | at org.apache.spark.scheduler.Task.run(Task.scala:147)
consumer | at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$5(Executor.scala:647)
consumer | at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally(SparkErrorUtils.scala:80)
consumer | at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally$(SparkErrorUtils.scala:77)
consumer | at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:99)
consumer | at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:650)
consumer | at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(Unknown Source)
consumer | at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(Unknown Source)
consumer | at java.base/java.lang.Thread.run(Unknown Source)
There must be an error in how I save the data into the state, since it seems that PyArrow cannot deserialize it, but from this article I understand that ArrayType of LongType and StringType can be stored in value state. What could be the cause? I'm pretty new to PySpark and built this project to learn it, but I've been stuck on this for days, trying multiple solutions to no avail.
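To make the failing interaction easier to see, these are the exact state calls involved, extracted from the class above: the schema registered in init(), the update() that runs at the end of the first batch, and the get() that fails on the second batch.

# State schema registered in init():
self._state = handle.getValueState(
    "state",
    StructType([
        StructField("locations", ArrayType(StringType())),
        StructField("timestamps", ArrayType(LongType())),
    ]),
)

# Written at the end of each handleInputRows call (succeeds during the first batch):
self._state.update((current_state["locations"], current_state["timestamps"]))

# Fails when the state is read back during the second batch:
current_state = {"locations": self._state.get()[0], "timestamps": self._state.get()[1]}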
I've resorted to implementing a workaround that uses two list states instead:
class MySP(StatefulProcessor):
    def init(self, handle: StatefulProcessorHandle):
        list_timestamp_schema = StructType([StructField("timestamp", LongType(), True)])
        list_location_schema = StructType([StructField("location", StringType(), True)])
        self._timestamp_state = handle.getListState(stateName="timestampState", schema=list_timestamp_schema)
        self._location_state = handle.getListState(stateName="locationState", schema=list_location_schema)
This way I can save and load the state without deserialization errors.
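The rest of the processor then looks roughly like this. It is a simplified sketch of how I use the two list states, assuming ListState.get() yields single-column rows and ListState.appendList() takes a list of one-element tuples; the actual fraud-check logic is omitted:

class MySP(StatefulProcessor):
    # init() as shown above, plus:
    def handleInputRows(self, key, rows, timerValues):
        # Read the existing state back as plain Python lists (each stored row has one field).
        locations = [row[0] for row in self._location_state.get()] if self._location_state.exists() else []
        timestamps = [row[0] for row in self._timestamp_state.get()] if self._timestamp_state.exists() else []

        new_locations = []
        new_timestamps = []
        for pdf in rows:
            new_locations.extend(pdf["location"].tolist())
            new_timestamps.extend(pdf["unix_timestamp"].tolist())

        # Append the new values as one-element tuples so they match the single-column schemas.
        if new_locations:
            self._location_state.appendList([(loc,) for loc in new_locations])
        if new_timestamps:
            self._timestamp_state.appendList([(ts,) for ts in new_timestamps])

        # (Fraud-check logic on locations/timestamps omitted here.)
        yield pd.DataFrame()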