pyspark

How to find the most recent value (by date) for many people and many columns in PySpark?


This is the data I have, covering many people with different ids and columns date, price1, price2, ..., pricex, ...

(Note: the date column is not sorted in my real data; I sorted this example so it is easier to follow.)

from datetime import date
import findspark
findspark.init()
import pyspark
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.window import Window

spark = SparkSession.builder.master("local").appName('practice').getOrCreate()

rdd = spark.sparkContext.parallelize([
    [1, date(2016, 1, 7) , 10.0 , None ],
    [1, date(2016, 1, 8),  14.50, None],
    [1, date(2016, 1, 16), 14.50, None],
    [1, date(2016, 2, 7) , 13.90, None],
    [1, date(2016, 3, 12),  None, None],

    #
    [2, date(2016, 1, 9) ,  None, 23.0],
    [2, date(2016, 1, 17) , None, 21.0],
    [2, date(2016, 1, 28),  5.50, None],
    [2, date(2016, 1, 28),  None, None],
    #
    [3, date(2016, 1, 5) ,  12.0  , None],
    [3, date(2016, 1, 6) , None , 13.9]
])

df = rdd.toDF(['id','date','price1', 'price2'])
df.show()
+---+----------+------+------+
| id|      date|price1|price2|
+---+----------+------+------+
|  1|2016-01-07|  10.0|  null|
|  1|2016-01-08|  14.5|  null|
|  1|2016-01-16|  14.5|  null|
|  1|2016-02-07|  13.9|  null|
|  1|2016-03-12|  null|  null|
|  2|2016-01-09|  null|  23.0|
|  2|2016-01-17|  null|  21.0|
|  2|2016-01-28|   5.5|  null|
|  2|2016-01-28|  null|  null|
|  3|2016-01-05|  12.0|  null|
|  3|2016-01-06|  null|  13.9|
+---+----------+------+------+

I need the most recent available value of price1, price2, ..., pricex for each id. To be more specific, this is what I need:

+---+------+------+
| id|price1|price2|
+---+------+------+
|  1|  13.9|  null|    # most recent price1 is from 2016-02-07; no price2 value available for id=1
|  2|   5.5|  21.0|    # most recent price1 is from 2016-01-28; most recent price2 is from 2016-01-17
|  3|  12.0|  13.9|
+---+------+------+

I have tried sorting the data by date within each id, but that only gives me the most recent row and its values:

df.withColumn("row_number",F.row_number().over(Window.partitionBy(df.id).orderBy(df.date.desc()))).filter(F.col("row_number")==1).show()

This is what I get, but it is NOT what I want: the most recent row per id often has null prices, so the earlier non-null values are lost.

+---+----------+------+------+----------+
| id|      date|price1|price2|row_number|
+---+----------+------+------+----------+
|  1|2016-03-12|  null|  null|         1|
|  2|2016-01-28|   5.5|  null|         1|
|  3|2016-01-06|  null|  13.9|         1|
+---+----------+------+------+----------+

Could you please help me with this problem? My data has many price1, price2, ..., pricex columns. Is there a way to do this without too much code? Thanks.


Solution

# Forward-fill window (start of the partition up to the current row) and
# backward-fill window (current row to the end of the partition)
w1 = Window.partitionBy('id').orderBy('date').rowsBetween(Window.unboundedPreceding, Window.currentRow)
w2 = Window.partitionBy('id').orderBy('date').rowsBetween(Window.currentRow, Window.unboundedFollowing)

final_df = (
    # Forward fill each price column; fall back to a backward fill when no earlier value exists
    df.select('id', 'date',
              *[F.coalesce(F.last(x, True).over(w1), F.first(x, True).over(w2)).alias(x)
                for x in df.select('price1', 'price2').columns])
    .withColumn('x', F.max('date').over(Window.partitionBy('id')))  # max date per id
    .where(F.col('date') == F.col('x'))                             # keep only the rows on that max date
    .dropDuplicates()                                                # collapse duplicate rows on the max date
    .drop('x')                                                       # drop the helper column
)
final_df.show()
    
    
+---+----------+------+------+
| id|      date|price1|price2|
+---+----------+------+------+
|  1|2016-03-12|  13.9|  null|
|  2|2016-01-28|   5.5|  21.0|
|  3|2016-01-06|  12.0|  13.9|
+---+----------+------+------+
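
Since the question mentions many price1, price2, ..., pricex columns, the same pattern can be written once over a list of column names instead of hard-coding 'price1' and 'price2'. A minimal sketch, assuming every column other than id and date is a price column to fill (price_cols, max_date and the w1/w2 window names are just illustrative):

# Assumption: every column except 'id' and 'date' is a pricex column to fill
price_cols = [c for c in df.columns if c not in ('id', 'date')]

# Forward-fill and backward-fill windows per id, ordered by date
w1 = Window.partitionBy('id').orderBy('date').rowsBetween(Window.unboundedPreceding, Window.currentRow)
w2 = Window.partitionBy('id').orderBy('date').rowsBetween(Window.currentRow, Window.unboundedFollowing)

final_df = (
    df.select('id', 'date',
              *[F.coalesce(F.last(c, True).over(w1), F.first(c, True).over(w2)).alias(c)
                for c in price_cols])
    .withColumn('max_date', F.max('date').over(Window.partitionBy('id')))
    .where(F.col('date') == F.col('max_date'))
    .dropDuplicates()
    .drop('max_date', 'date')   # drop 'date' too, since only id and the prices are needed
)
final_df.show()

As an alternative that is not part of the answer above, the same result can be computed without window functions: for each price column, take the maximum of a (date, value) struct over the rows where the value is non-null. Structs compare field by field, so the maximum is the value belonging to the latest non-null date:

# Per-column "value at the latest non-null date" via max over a (date, value) struct
result = df.groupBy('id').agg(
    *[F.max(F.when(F.col(c).isNotNull(), F.struct('date', c))).getField(c).alias(c)
      for c in price_cols]
)
result.show()

Both sketches produce one row per id containing only the id and price columns, matching the desired output in the question.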