apache-spark, date, pyspark, apache-spark-sql, window-functions

Spark Window Functions - rangeBetween dates


I have a Spark SQL DataFrame with a date column, and what I'm trying to get is all the rows preceding the current row within a given date range. So for example I want to have all the rows from 7 days back preceding the given row. I figured out that I need to use a Window Function like:

Window \
    .partitionBy('id') \
    .orderBy('start')

I want to have a rangeBetween of 7 days, but there is nothing in the Spark docs I could find on this. Does Spark even provide such an option? For now I'm just getting all the preceding rows with:

.rowsBetween(-sys.maxsize, 0)

but I would like to achieve something like:

.rangeBetween("7 days", 0)

Solution

  • Spark >= 2.3

    Since Spark 2.3 it is possible to use interval objects through the SQL API, but DataFrame API support is still a work in progress.

    df.createOrReplaceTempView("df")
    
    spark.sql(
        """SELECT *, mean(some_value) OVER (
            PARTITION BY id 
            ORDER BY CAST(start AS timestamp) 
            RANGE BETWEEN INTERVAL 7 DAYS PRECEDING AND CURRENT ROW
         ) AS mean FROM df""").show()
    
    ## +---+----------+----------+------------------+       
    ## | id|     start|some_value|              mean|
    ## +---+----------+----------+------------------+
    ## |  1|2015-01-01|      20.0|              20.0|
    ## |  1|2015-01-06|      10.0|              15.0|
    ## |  1|2015-01-07|      25.0|18.333333333333332|
    ## |  1|2015-01-12|      30.0|21.666666666666668|
    ## |  2|2015-01-01|       5.0|               5.0|
    ## |  2|2015-01-03|      30.0|              17.5|
    ## |  2|2015-02-01|      20.0|              20.0|
    ## +---+----------+----------+------------------+
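
    If you prefer to stay with the DataFrame API on Spark >= 2.3, the interval frame can still be written as a SQL expression and wrapped in expr (a minimal sketch, assuming the same df as built in the Spark < 2.3 section below):

    from pyspark.sql.functions import expr
    
    # The frame is written in SQL because the Python rangeBetween
    # still expects numeric boundaries.
    df.withColumn("mean", expr("""
        mean(some_value) OVER (
            PARTITION BY id
            ORDER BY CAST(start AS timestamp)
            RANGE BETWEEN INTERVAL 7 DAYS PRECEDING AND CURRENT ROW
        )""")).show()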
    

  • Spark < 2.3

    As far as I know it is not possible directly, either in Spark or in Hive. Both require the ORDER BY clause used with RANGE to be numeric. The closest thing I found is converting to a timestamp and operating on seconds. Assuming the start column contains a date type:

    from pyspark.sql import Row
    from pyspark.sql.functions import col
    
    row = Row("id", "start", "some_value")
    df = sc.parallelize([
        row(1, "2015-01-01", 20.0),
        row(1, "2015-01-06", 10.0),
        row(1, "2015-01-07", 25.0),
        row(1, "2015-01-12", 30.0),
        row(2, "2015-01-01", 5.0),
        row(2, "2015-01-03", 30.0),
        row(2, "2015-02-01", 20.0)
    ]).toDF().withColumn("start", col("start").cast("date"))
    

    A small helper and window definition:

    from pyspark.sql.window import Window
    from pyspark.sql.functions import mean, col
    
    
    # Hive timestamp is interpreted as UNIX timestamp in seconds*
    days = lambda i: i * 86400 
    

    Finally, the query:

    w = (Window()
       .partitionBy(col("id"))
       .orderBy(col("start").cast("timestamp").cast("long"))
       .rangeBetween(-days(7), 0))
    
    df.select(col("*"), mean("some_value").over(w).alias("mean")).show()
    
    ## +---+----------+----------+------------------+
    ## | id|     start|some_value|              mean|
    ## +---+----------+----------+------------------+
    ## |  1|2015-01-01|      20.0|              20.0|
    ## |  1|2015-01-06|      10.0|              15.0|
    ## |  1|2015-01-07|      25.0|18.333333333333332|
    ## |  1|2015-01-12|      30.0|21.666666666666668|
    ## |  2|2015-01-01|       5.0|               5.0|
    ## |  2|2015-01-03|      30.0|              17.5|
    ## |  2|2015-02-01|      20.0|              20.0|
    ## +---+----------+----------+------------------+
    

    Far from pretty, but it works.
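
    An alternative I find slightly cleaner, assuming start holds plain dates with no time component, is to order by the number of days since an arbitrary reference date (1970-01-01 here, purely for illustration) and use a plain numeric frame:

    from pyspark.sql.functions import datediff, lit
    
    # datediff yields whole days since the reference date,
    # so a 7 day window is simply rangeBetween(-7, 0)
    w_days = (Window()
        .partitionBy(col("id"))
        .orderBy(datediff(col("start"), lit("1970-01-01")))
        .rangeBetween(-7, 0))
    
    df.select(col("*"), mean("some_value").over(w_days).alias("mean")).show()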


    * Hive Language Manual, Types