pyspark, spark-structured-streaming, moving-average

Calculating a moving average column using pyspark structured streaming


I'm using pyspark to process some streaming data coming in and I want to add a new column to my data frame with a 50-second moving average.

I tried using a Window spec with rangeBetween:

from pyspark.sql import functions as F
import pyspark.sql.window as W

w = (W.Window
     .partitionBy(F.col("sender"))
     .orderBy(F.col("event_time").cast("long"))
     .rangeBetween(-50, 0))
df2 = df.withColumn("rolling_average", F.avg("fr").over(w))

But this gives me an error, as structured streaming requires a time-based window (probably to manage state):

AnalysisException: Non-time-based windows are not supported on streaming DataFrames/Datasets
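
(For reference, the same rangeBetween spec works as expected on a static DataFrame, so the problem is specific to streaming. A minimal batch sketch with made-up sample data:)

from pyspark.sql import SparkSession, functions as F, Window

spark = SparkSession.builder.getOrCreate()

# hypothetical static sample with the same columns as the stream
static_df = spark.createDataFrame(
    [("A", "2021-04-12 09:57:35", 0.1),
     ("A", "2021-04-12 09:58:10", 0.3),
     ("B", "2021-04-12 09:57:40", 0.2)],
    ["sender", "event_time", "fr"],
).withColumn("event_time", F.to_timestamp("event_time"))

w = (Window.partitionBy("sender")
     .orderBy(F.col("event_time").cast("long"))
     .rangeBetween(-50, 0))

static_df.withColumn("rolling_average", F.avg("fr").over(w)).show()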

Using the sql.functions.window function I can also calculate a moving average, but this gives me the results grouped by a window (and the unique id key called sender) that is tumbling (or hopping):

df.select('sender', 'event_time', 'fr').groupBy("sender", window("event_time", "50 second")).avg().alias('avg_fr')
| sender                               | window                                                                         | avg(fr)              |
|--------------------------------------|--------------------------------------------------------------------------------|----------------------|
| 59834cfd-6cb2-4ece-8353-0a9b20389656 | {"start":"2021-04-12T09:57:30.000+0000","end":"2021-04-12T09:58:20.000+0000"} | 0.17443667352199554  |
| 8b5d90b9-65d9-4dd2-b742-31c4f0ce37d6 | {"start":"2021-04-12T09:57:30.000+0000","end":"2021-04-12T09:58:20.000+0000"} | 0.010564474388957024 |
| a74204f3-e25d-4737-a302-9206cd69e90a | {"start":"2021-04-12T09:57:30.000+0000","end":"2021-04-12T09:58:20.000+0000"} | 0.16375258564949036  |
| db16426d-a9ba-449b-9777-3bdfadf0e0d9 | {"start":"2021-04-12T09:57:30.000+0000","end":"2021-04-12T09:58:20.000+0000"} | 0.17516431212425232  |

The tumbling window is obviously not what I want and I would need to somehow join this to the original table again. I'm not sure how to define a sliding window based on the irregular event timestamps coming in.
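
(Note: window() also accepts a slide duration, which turns the tumbling window into a hopping one; the solution below ends up using exactly that. A sketch:)

from pyspark.sql.functions import window

# 50-second windows evaluated every 10 seconds, averaged per sender;
# still one row per (sender, window) group rather than one per event
hopping = (df.select("sender", "event_time", "fr")
             .groupBy("sender", window("event_time", "50 seconds", "10 seconds"))
             .avg())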

Right now I'm thinking about writing a stateful function that keeps the previously received records in state and updates it for each new data point coming in. But this seems quite elaborate for such a common task, which I expect can be done in an easier way.

EDIT: the current version of Spark (3.1.1) only allows arbitrary stateful functions to be written in Java or Scala, not Python, because of the conversion to the JVM.

Any thoughts if this is actually the correct way to go?


Solution

  • For the streaming version I'm doing basically the same thing as in the solution I posted for the non-streaming case.

    from pyspark.sql.types import StructType, StructField, TimestampType, StringType, LongType, FloatType
    from pyspark.sql.functions import window, expr

    schema = StructType([
        StructField("ts", TimestampType(), True),
        StructField("sender", StringType(), True),
        StructField("value1", LongType(), True),
        StructField("value2", FloatType(), True),
    ])
    
    df = spark.readStream.schema(schema).format("csv").load("dbfs:/FileStore/shared_uploads/joris.vanagtmaal@wartsilaazure.com/raw_data*")
    
    # round_time is a UDF from the non-streaming solution referenced above; it
    # aligns each event timestamp to a window boundary so it can be joined to a
    # window end further down
    df = df.withWatermark("ts", "2 minutes").select("*", round_time(df["ts"]).alias("trunct_time"))

    # 50-second windows, sliding every 10 seconds, averaged per sender
    avgDF = (df.withWatermark("ts", "2 minutes")
               .select("value1", "sender", "ts")
               .groupBy("sender", window("ts", "50 second", "10 second"))
               .avg())

    # expose the window end and rename the key so they don't clash with df's columns in the join
    avgDF = (avgDF.withColumn("window_end", avgDF.window.end)
                  .withColumnRenamed("sender", "sender2")
                  .withWatermark("window_end", "2 minutes"))
    
    joined_stream = df.join(
      avgDF,
      expr("""
        trunct_time = window_end AND
        sender = sender2
        """),
      "leftOuter"
    )
    
    query = (
      joined_stream
        .writeStream
        .format("memory")        # memory = store in-memory table (for testing only)
        .queryName("joined")     # joined = name of the in-memory table
        .outputMode("append")  # append = allows stream on stream joins
        .start()
    )
    

    This results in the following error:

    AnalysisException: Detected pattern of possible 'correctness' issue due to global watermark. The query contains stateful operation which can emit rows older than the current watermark plus allowed late record delay, which are "late rows" in downstream stateful operations and these rows can be discarded. Please refer the programming guide doc for more details. If you understand the possible risk of correctness issue and still need to run the query, you can disable this check by setting the config `spark.sql.streaming.statefulOperator.checkCorrectness.enabled` to false.;
    

    The documentation mentions:

    Any of the stateful operation(s) after any of below stateful operations can have this issue:

    streaming aggregation in Append mode or stream-stream outer join

    There’s a known workaround: split your streaming query into multiple queries per stateful operator, and ensure end-to-end exactly once per query. Ensuring end-to-end exactly once for the last query is optional.

    But this is quite a cryptic description of how to solve the issue. Based on https://issues.apache.org/jira/browse/SPARK-28074:

    It means splitting the query into multiple steps with one stateful operation each and persisting the intermediate results to topics. This produces mostly reproducible results, but of course it increases the overall delay of the messages passing through.
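
    A hedged sketch of what that split could look like: the first query persists the windowed aggregation (the first stateful operation) to an intermediate sink, and a second query reads it back and performs the stream-stream join (the second stateful operation). Here a Delta table stands in for the "topic", and all paths and checkpoint locations are made up:

    # query 1: only the streaming aggregation, persisted to an intermediate sink
    agg_query = (avgDF
        .writeStream
        .format("delta")
        .option("checkpointLocation", "dbfs:/tmp/checkpoints/avg")
        .outputMode("append")
        .start("dbfs:/tmp/intermediate/avg"))

    # query 2: read the aggregation back and do the join (the only stateful op here)
    avg_from_sink = (spark.readStream.format("delta")
        .load("dbfs:/tmp/intermediate/avg")
        .withWatermark("window_end", "2 minutes"))

    joined_query = (df
        .join(avg_from_sink,
              expr("trunct_time = window_end AND sender = sender2"),
              "leftOuter")
        .writeStream
        .format("memory")
        .queryName("joined_split")
        .outputMode("append")
        .start())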

    Depending on your setup this may or may not be the correct solution, but for this example I decided to set the correctness check parameter to false, so it no longer throws an exception and only writes a warning to the logs.

    %sql set spark.sql.streaming.statefulOperator.checkCorrectness.enabled=False
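
    The same flag can also be set from Python instead of the %sql magic:

    spark.conf.set("spark.sql.streaming.statefulOperator.checkCorrectness.enabled", "false")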
    

    Now it gives me the result I wanted:

    %sql select * from joined
    | ts                           | sender | value1 | value2 | avg(value1) | 
    |------------------------------|--------|--------|--------|-------------| 
    | 2021-04-15T14:33:16.000+0000 | B      | 10     | 200    | 10.5        | 
    | 2021-04-15T14:32:47.000+0000 | B      | 11     | -500   | 9.5         | 
    | 2021-04-15T14:31:45.000+0000 | A      | 1      | 4      | 1           | 
    | 2021-04-15T14:32:16.000+0000 | A      | 2      | -3     | 1.5         | 
    | 2021-04-15T14:32:46.000+0000 | A      | 3      | 5      | 2.5         | 
    | 2021-04-15T14:33:17.000+0000 | A      | 0      | 2      | 1.5         | 
    | 2021-04-15T14:32:16.000+0000 | B      | 8      | 100    | 8           | 
    

    One more caveat: these results only become visible once they are followed by a new data point that moves the watermark beyond the threshold (here 2 minutes). In a streaming application this would not be an issue, but for this example I added an 8th data point a couple of minutes later, which of course isn't visible in the output for the same reason.