pandas apache-spark pyspark windowing

Clean way to identify runs on a PySpark DF ArrayType column


Given a PySpark DataFrame of the form:

+----+--------+
|time|messages|
+----+--------+
| t01|    [m1]|
| t03|[m1, m2]|
| t04|    [m2]|
| t06|    [m3]|
| t07|[m3, m1]|
| t08|    [m1]|
| t11|    [m2]|
| t13|[m2, m4]|
| t15|    [m2]|
| t20|    [m4]|
| t21|      []|
| t22|[m1, m4]|
+----+--------+

I'd like to refactor it to compress runs containing the same message (the order of the output doesn't matter much, but it is sorted here for clarity):

+----------+--------+-------+
|start_time|end_time|message|
+----------+--------+-------+
|       t01|     t03|     m1|
|       t07|     t08|     m1|
|       t22|     t22|     m1|
|       t03|     t04|     m2|
|       t11|     t15|     m2|
|       t06|     t07|     m3|
|       t13|     t13|     m4|
|       t20|     t20|     m4|
|       t22|     t22|     m4|
+----------+--------+-------+

(i.e. treat the messages column as a sequence and identify the start and end of "runs" for each message).

Is there a clean way to make this transformation in Spark? Currently, I'm dumping this as a 6 GB TSV and processing it imperatively.

I'm open to the possibility of toPandas-ing this and accumulating on the driver if Pandas has a clean way to do this aggregation.
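
For context, the rough shape of the pandas logic I have in mind after a toPandas() is something like the sketch below (just a sketch on the sample data; `pdf` stands in for the collected frame):

import pandas as pd

pdf = pd.DataFrame({
    'time': ['t01', 't03', 't04', 't06', 't07', 't08',
             't11', 't13', 't15', 't20', 't21', 't22'],
    'messages': [['m1'], ['m1', 'm2'], ['m2'], ['m3'], ['m3', 'm1'], ['m1'],
                 ['m2'], ['m2', 'm4'], ['m2'], ['m4'], [], ['m1', 'm4']],
})

exploded = (
    pdf.assign(pos=range(len(pdf)))            # positional index of each original row
       .explode('messages')                    # one row per (time, message)
       .dropna(subset=['messages'])            # empty arrays explode to NaN; drop them
       .rename(columns={'messages': 'message'})
       .reset_index(drop=True)
)

# a run breaks whenever a message skips a row, i.e. the positional gap is > 1
exploded['run'] = (exploded.groupby('message')['pos']
                           .transform(lambda s: (s.diff() != 1).cumsum()))

runs = (exploded.groupby(['message', 'run'])['time']
                .agg(start_time='first', end_time='last')
                .reset_index()
                .drop(columns='run'))
print(runs)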

(see my answer below for a naïve baseline implementation).


Solution

  • You can try the following method using forward-filling (Spark 2.4+ is not required):

    Step-1: do the following:

    1. for each row ordered by time, find prev_messages and next_messages
    2. explode messages into individual message
    3. for each message, if prev_messages is NULL or message is not in prev_messages, then set start=time, using the SQL syntax below:

      IF(prev_messages is NULL or !array_contains(prev_messages, message),time,NULL)
      

      which can be simplified to:

      IF(array_contains(prev_messages, message),NULL,time)
      
    4. similarly, if next_messages is NULL or message is not in next_messages, then set end=time

    Code below:

    from pyspark.sql import Window, functions as F
    
    # rows is defined in your own post
    df = spark.createDataFrame(rows, ['time', 'messages'])
    
    w1 = Window.partitionBy().orderBy('time')
    
    df1 = df.withColumn('prev_messages', F.lag('messages').over(w1)) \
        .withColumn('next_messages', F.lead('messages').over(w1)) \
        .withColumn('message', F.explode('messages')) \
        .withColumn('start', F.expr("IF(array_contains(prev_messages, message),NULL,time)")) \
        .withColumn('end', F.expr("IF(array_contains(next_messages, message),NULL,time)"))
    
    df1.show()
    #+----+--------+-------------+-------------+-------+-----+----+
    #|time|messages|prev_messages|next_messages|message|start| end|
    #+----+--------+-------------+-------------+-------+-----+----+
    #| t01|    [m1]|         null|     [m1, m2]|     m1|  t01|null|
    #| t03|[m1, m2]|         [m1]|         [m2]|     m1| null| t03|
    #| t03|[m1, m2]|         [m1]|         [m2]|     m2|  t03|null|
    #| t04|    [m2]|     [m1, m2]|         [m3]|     m2| null| t04|
    #| t06|    [m3]|         [m2]|     [m3, m1]|     m3|  t06|null|
    #| t07|[m3, m1]|         [m3]|         [m1]|     m3| null| t07|
    #| t07|[m3, m1]|         [m3]|         [m1]|     m1|  t07|null|
    #| t08|    [m1]|     [m3, m1]|         [m2]|     m1| null| t08|
    #| t11|    [m2]|         [m1]|     [m2, m4]|     m2|  t11|null|
    #| t13|[m2, m4]|         [m2]|         [m2]|     m2| null|null|
    #| t13|[m2, m4]|         [m2]|         [m2]|     m4|  t13| t13|
    #| t15|    [m2]|     [m2, m4]|         [m4]|     m2| null| t15|
    #| t20|    [m4]|         [m2]|           []|     m4|  t20| t20|
    #| t22|[m1, m4]|           []|         null|     m1|  t22| t22|
    #| t22|[m1, m4]|           []|         null|     m4|  t22| t22|
    #+----+--------+-------------+-------------+-------+-----+----+
    

    Step-2: create a WindowSpec partitioned by message and forward-fill the start column.

    w2 = Window.partitionBy('message').orderBy('time')
    
    # for illustration purposes, I used a different column name so that we can
    # compare the `start` column before and after the forward-fill (ffill)
    df2 = df1.withColumn('start_new', F.last('start', True).over(w2))
    df2.show()
    #+----+--------+-------------+-------------+-------+-----+----+---------+
    #|time|messages|prev_messages|next_messages|message|start| end|start_new|
    #+----+--------+-------------+-------------+-------+-----+----+---------+
    #| t01|    [m1]|         null|     [m1, m2]|     m1|  t01|null|      t01|
    #| t03|[m1, m2]|         [m1]|         [m2]|     m1| null| t03|      t01|
    #| t07|[m3, m1]|         [m3]|         [m1]|     m1|  t07|null|      t07|
    #| t08|    [m1]|     [m3, m1]|         [m2]|     m1| null| t08|      t07|
    #| t22|[m1, m4]|           []|         null|     m1|  t22| t22|      t22|
    #| t03|[m1, m2]|         [m1]|         [m2]|     m2|  t03|null|      t03|
    #| t04|    [m2]|     [m1, m2]|         [m3]|     m2| null| t04|      t03|
    #| t11|    [m2]|         [m1]|     [m2, m4]|     m2|  t11|null|      t11|
    #| t13|[m2, m4]|         [m2]|         [m2]|     m2| null|null|      t11|
    #| t15|    [m2]|     [m2, m4]|         [m4]|     m2| null| t15|      t11|
    #| t06|    [m3]|         [m2]|     [m3, m1]|     m3|  t06|null|      t06|
    #| t07|[m3, m1]|         [m3]|         [m1]|     m3| null| t07|      t06|
    #| t13|[m2, m4]|         [m2]|         [m2]|     m4|  t13| t13|      t13|
    #| t20|    [m4]|         [m2]|           []|     m4|  t20| t20|      t20|
    #| t22|[m1, m4]|           []|         null|     m4|  t22| t22|      t22|
    #+----+--------+-------------+-------------+-------+-----+----+---------+
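
    Note that this forward-fill relies on the default frame of an ordered window, which runs from the start of the partition to the current row. If you prefer to make the frame explicit, an equivalent spelling (rows-based, which matches here because the time values are unique) is:

    w2_explicit = Window.partitionBy('message').orderBy('time') \
        .rowsBetween(Window.unboundedPreceding, Window.currentRow)

    df2 = df1.withColumn('start_new', F.last('start', ignorenulls=True).over(w2_explicit))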
    

    Step-3: remove rows where end is NULL and then select only the required columns:

    df2.selectExpr("message", "start_new as start", "end") \
        .filter("end is not NULL") \
        .orderBy("message","start").show()
    #+-------+-----+---+
    #|message|start|end|
    #+-------+-----+---+
    #|     m1|  t01|t03|
    #|     m1|  t07|t08|
    #|     m1|  t22|t22|
    #|     m2|  t03|t04|
    #|     m2|  t11|t15|
    #|     m3|  t06|t07|
    #|     m4|  t13|t13|
    #|     m4|  t20|t20|
    #|     m4|  t22|t22|
    #+-------+-----+---+
    

    To summarize the above steps, we have the following:

    from pyspark.sql import Window, functions as F
    
    # define two Window Specs
    w1 = Window.partitionBy().orderBy('time')
    w2 = Window.partitionBy('message').orderBy('time')
    
    df_new = df \
        .withColumn('prev_messages', F.lag('messages').over(w1)) \
        .withColumn('next_messages', F.lead('messages').over(w1)) \
        .withColumn('message', F.explode('messages')) \
        .withColumn('start', F.expr("IF(array_contains(prev_messages, message),NULL,time)")) \
        .withColumn('end', F.expr("IF(array_contains(next_messages, message),NULL,time)")) \
        .withColumn('start', F.last('start', True).over(w2)) \
        .select("message", "start", "end") \
        .filter("end is not NULL")
    
    df_new.orderBy("start").show()
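
    If you want the exact column names asked for in the question, a final projection is all that's left (`result` is just an illustrative name):

    result = df_new.select(
        F.col('start').alias('start_time'),
        F.col('end').alias('end_time'),
        'message',
    )
    result.orderBy('message', 'start_time').show()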