sql, pyspark, apache-spark-sql, amazon-redshift

Calculate running sum in Spark SQL


I am working on logic where I need to calculate TotalScan, Last5DayScan, and Month2DayScan from a daily scan count. Today I sum the daily scan count over the full history every day, but the data volume is making that expensive to compute. As a new approach I am thinking of keeping a running sum, but I cannot figure out how to calculate the running sum for TotalScan, i.e. today's TotalScan = last TotalScan value + today's scan count (where the last TotalScan may be from one or two months back).

ProcessName         DailyScan    Date
NewInsurance        8000         04/12/2024
InsuranceRenewal    4500         04/12/2024
FraudDetection      28           04/12/2024
PolicyWithdrawn     100          04/01/2024
NewInsurance        2100         04/13/2024
NewInsurance        400          04/14/2024
InsuranceRenewal    500          04/14/2024
InsuranceRenewal    500          04/18/2024
NewInsurance        500          04/18/2024

Required output, assuming I execute the query on 04/18/2024:

ProcessName     TotalScan   Last5DayScan   Month2DayScan   DailyScan   Date
NewInsurance    8000        8000           8000            8000        04/12/2024
NewInsurance    10100       10100          10100           2100        04/13/2024
NewInsurance    10500       10500          10500           400         04/14/2024
NewInsurance    11000       900            11000           500         04/18/2024

I am doing sum(dailyscan) every day over the entire dataset (after joining the source table with a calendar table and grouping by ProcessName and CalendarDate) to get TotalScan. This gets me the output, but I am sure there is a better and more efficient way to do this. Any thoughts?


Solution

  • For this, partition the data by the "ProcessName" column using Window.partitionBy, order by "Date", and then sum "DailyScan" over this partitioned window to get the running TotalScan.

    Make sure to convert the string date to a proper date type before these operations.

    from pyspark.sql import Window
    from pyspark.sql import functions as F
    
    columns=["ProcessName", "DailyScan", "Date"]
    data=[
        ("NewInsurance","8000","04/12/2024"),
        ("InsuranceRenewal","4500","04/12/2024"),
        ("FraudDetection","28","04/12/2024"),
        ("PolicyWithdrawn","100","04/01/2024"),
        ("NewInsurance","2100","04/13/2024"),
        ("NewInsurance","400","04/14/2024"),
        ("InsuranceRenewal","500","04/14/2024"),
        ("InsuranceRenewal","500","04/18/2024"),
        ("NewInsurance","500","04/18/2024"),
    ]
    
    df = spark.createDataFrame(data, columns)
    
    df = df.withColumn("Date", F.to_date("Date", "MM/dd/yyyy"))
    
    w = Window.partitionBy("ProcessName").orderBy("Date")
    
    df = df.withColumn("DailyScan", F.sum("DailyScan").over(w))
    

    Output:

    +----------------+---------+----------+---------+
    |ProcessName     |DailyScan|Date      |TotalScan|
    +----------------+---------+----------+---------+
    |FraudDetection  |28       |2024-04-12|28       |
    |InsuranceRenewal|4500     |2024-04-12|4500     |
    |InsuranceRenewal|500      |2024-04-14|5000     |
    |InsuranceRenewal|500      |2024-04-18|5500     |
    |NewInsurance    |8000     |2024-04-12|8000     |
    |NewInsurance    |2100     |2024-04-13|10100    |
    |NewInsurance    |400      |2024-04-14|10500    |
    |NewInsurance    |500      |2024-04-18|11000    |
    |PolicyWithdrawn |100      |2024-04-01|100      |
    +----------------+---------+----------+---------+
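
    The question also asks for Last5DayScan and Month2DayScan. As a sketch of one way to extend the same idea, building on the df, Window and F already defined above (the day-offset trick and the month truncation are my own choices, not the only option): use a rangeBetween frame over a day number for the 5-day window, and add the truncated month to the partition key for the month-to-date sum.

    # Day number since an arbitrary epoch, so that rangeBetween counts
    # calendar days rather than rows (gaps between scan dates are respected)
    day_nr = F.datediff(F.col("Date"), F.to_date(F.lit("1970-01-01")))
    
    # Current day plus the 4 preceding calendar days = last 5 days
    w_5d = (Window.partitionBy("ProcessName")
            .orderBy(day_nr)
            .rangeBetween(-4, Window.currentRow))
    
    # Month-to-date: restart the running sum at every month boundary
    w_mtd = Window.partitionBy("ProcessName", F.trunc("Date", "month")).orderBy("Date")
    
    df = (df
          .withColumn("Last5DayScan", F.sum("DailyScan").over(w_5d))
          .withColumn("Month2DayScan", F.sum("DailyScan").over(w_mtd)))

    On the sample data this matches the required output for NewInsurance, e.g. Last5DayScan = 900 on 04/18/2024 (400 from 04/14 plus 500 from 04/18).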
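
    Since the title mentions Spark SQL, the same running total can also be written as a SQL window function. A minimal sketch, assuming the DataFrame is registered as a temporary view named scans (a name chosen only for this example):

    df.createOrReplaceTempView("scans")
    
    total = spark.sql("""
        SELECT ProcessName,
               DailyScan,
               `Date`,
               SUM(DailyScan) OVER (
                   PARTITION BY ProcessName
                   ORDER BY `Date`
                   ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
               ) AS TotalScan
        FROM scans
    """)

    The explicit ROWS frame gives one cumulative value per row; the DataFrame version above uses Spark's default RANGE frame, which only differs when a process has several rows with the same Date.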