Tags: python, multithreading, pyspark, python-multithreading, spark-structured-streaming

How to Gracefully Stop a Thread Inside a Spark foreachBatch Callback


I'm using a thread from the threading package to launch a function that runs a Spark Structured Streaming query. I want to stop that thread from inside the process function (the foreachBatch callback) when a condition is met.

import threading
import asyncio
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()


async def process(df, df_id):
    if df_id == 2:
        # I want to stop the thread here
        pass


async def streaming_to_consumer():
    df = spark.readStream \
        .format("iceberg") \
        .load("local.db.table")

    query = df \
        .writeStream \
        .outputMode("append") \
        .foreachBatch(process) \
        .trigger(processingTime="0.5 seconds") \
        .start()

    query.awaitTermination(2)


# streaming_to_consumer is a coroutine, so run it via asyncio.run in the thread
threading.Thread(target=lambda: asyncio.run(streaming_to_consumer())).start()

Solution

  • I resolved the problem by creating an event = threading.Event() and passing event as an extra parameter to process. The thread that launches the streaming query then sits in a while loop waiting for the event. Once df_id == 2, process sets the event, the loop exits, and the query is stopped with query.stop().

    import threading
    import asyncio
    import functools
    import time
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()


    async def process(df, df_id, event):
        if df_id == 2:
            event.set()  # signal the launcher thread that the query can be stopped
            return


    def process_wrapper(df, df_id, event):
        # foreachBatch expects a plain callable, so run the coroutine here
        asyncio.run(process(df, df_id, event))


    async def streaming_to_consumer():
        df = spark.readStream \
            .format("iceberg") \
            .load("local.db.table")

        event = threading.Event()

        query = df \
            .writeStream \
            .outputMode("append") \
            .foreachBatch(functools.partial(process_wrapper, event=event)) \
            .trigger(processingTime="0.5 seconds") \
            .start()

        query.awaitTermination(2)

        # Block until process() sets the event, then stop the streaming query
        while not event.is_set():
            time.sleep(1)

        query.stop()


    # streaming_to_consumer is a coroutine, so run it via asyncio.run in the thread
    threading.Thread(target=lambda: asyncio.run(streaming_to_consumer())).start()

    Note: I used functools.partial and a small wrapper to pass event as an extra argument to process; since foreachBatch only supplies the batch DataFrame and the batch id, event is bound as a keyword argument.
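
    Outside of Spark, the binding works like this (a minimal sketch; the fake batch values below are just placeholders for what foreachBatch would pass):

    import functools
    import threading

    def process_wrapper(df, df_id, event):
        # Mirrors the callback above: set the event once the target batch arrives
        print(f"batch {df_id}: {df}")
        if df_id == 2:
            event.set()

    event = threading.Event()
    callback = functools.partial(process_wrapper, event=event)

    # foreachBatch would call callback(batch_df, batch_id); simulate three micro-batches
    for batch_id, batch_df in enumerate(["rows-0", "rows-1", "rows-2"]):
        callback(batch_df, batch_id)

    print(event.is_set())  # True, because the batch with id 2 set the event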