apache-spark, pyspark

Pick a row based on a date or a default


I have two DataFrames:

        df_rates                        df_trades
(rate from currency -> USD)   
+--------+----------+----+    +---+--------+------+----------+    
|currency| rate_date|rate|    | id|currency|amount|trade_date|    
+--------+----------+----+    +---+--------+------+----------+    
|     EUR|2025-01-09|1.19|    |  1|     EUR|  1000|2025-01-09| # exact rate available
|     EUR|2025-01-08|1.18|    |  2|     CAD|  1000|2025-01-09| # 1 day prior rate available
|     CAD|2025-01-08|0.78|    |  3|     AUD|  1000|2025-01-09| # no applicable rate available
|     CAD|2025-01-07|0.77|    |  4|     HKD|  1000|2025-01-09| # no rate available at all
|     AUD|2025-02-09|1.39|    +---+--------+------+----------+    
|     AUD|2025-02-08|1.38|                                              
+--------+----------+----+                                              

For every trade I need to calculate usd_amount by applying the appropriate exchange rate. The rate is picked like this:

  • if a rate exists for the trade's currency on trade_date, use it;
  • otherwise use the most recent rate for that currency from the 7 days before trade_date.

If no rate is found in this way then usd_amount = null.

I have the following code that works, but I'm not sure it will scale, specifically for the case of trade id = 3 (rates exist for the currency, but none fall in the allowed date range), because in practice the rates table has thousands of rates going back 5-7 years. The part of the code in question is marked "This PART" below.

Is there some other logic with which this can be achieved more efficiently?

from pyspark.sql import SparkSession, Window
from pyspark.sql.functions import col, row_number, date_diff, when

def log_dataframe(df, msg):
    print(msg)
    df.show()

def calc_usd_amount(df_trades, df_rates):
    # join each trade to every rate for its currency and compute how many days
    # before the trade_date each rate_date falls
    df = df_trades.join(df_rates, how='left_outer', on='currency').withColumn('date_diff', date_diff('trade_date', 'rate_date'))
    date_diff_no_good = (col('date_diff') < 0) | (col('date_diff') > 7)
    # This PART: null out rate_date/rate for rows outside the allowed window,
    # then collapse the resulting duplicate (id, NULL, NULL) rows
    df = (
        df.withColumns({
            'rate_date': when(date_diff_no_good, None).otherwise(col('rate_date')),
            'rate': when(date_diff_no_good, None).otherwise(col('rate')),
        })
        .drop_duplicates(['id', 'rate_date', 'rate'])
    )
    # most recent usable rate first; NULL rate_date sorts last with desc()
    w_spec = row_number().over(Window.partitionBy(col('id'), col('currency')).orderBy(col('rate_date').desc()))
    df = (
        df.filter('rate_date IS NULL OR (rate_date <= trade_date AND rate_date > (trade_date - 7))')
        .withColumn('rate_row_num', w_spec).filter('rate_row_num == 1')
        .withColumn('usd_amount', col('rate') * col('amount'))
    )
    return df.drop('date_diff', 'rate_row_num')

from pyspark.sql import SparkSession
from datetime import date

spark = SparkSession.builder.getOrCreate()
dt = date.fromisoformat
df_trades = spark.createDataFrame(
    data = [
        (1, 'EUR', 1000, dt('2025-01-09')),  # trade date rate available
        (2, 'CAD', 1000, dt('2025-01-09')),  # trade date -1d, rate available
        (3, 'AUD', 1000, dt('2025-01-09')),  # no applicable rate available
        (4, 'HKD', 1000, dt('2025-01-09')),  # no rate available at all
    ],
    schema=['id', 'currency', 'amount', 'trade_date'],
)
df_rates = spark.createDataFrame(
    data = [
        ('EUR', dt('2025-01-09'), 1.19),  # trade date rate available
        ('EUR', dt('2025-01-08'), 1.18),
        ('CAD', dt('2025-01-08'), 0.78),  # trade date -1d, rate available
        ('CAD', dt('2025-01-07'), 0.77),
        ('AUD', dt('2025-02-09'), 1.39),  # no applicable rate available
        ('AUD', dt('2025-02-08'), 1.38),
    ],
    schema=['currency', 'rate_date', 'rate']
)

df_out = calc_usd_amount(df_trades, df_rates)
log_dataframe(df_out, 'df_out')

prints:

df_out
+--------+---+------+----------+----------+----+----------+
|currency| id|amount|trade_date| rate_date|rate|usd_amount|
+--------+---+------+----------+----------+----+----------+
|     EUR|  1|  1000|2025-01-09|2025-01-09|1.19|    1190.0|
|     CAD|  2|  1000|2025-01-09|2025-01-08|0.78|     780.0|
|     AUD|  3|  1000|2025-01-09|      NULL|NULL|      NULL|
|     HKD|  4|  1000|2025-01-09|      NULL|NULL|      NULL|
+--------+---+------+----------+----------+----+----------+

Solution

  • I think we can use a conditional join on trade_date / rate_date together with a window function to achieve the same goal (a note on why this should scale better follows the snippet):

    from pyspark.sql import Window
    import pyspark.sql.functions as func

    df_out = df_trades.alias("trades").join(
        df_rates.alias("rates"),
        on=[
            func.col("trades.currency")==func.col("rates.currency"),
            func.col("trades.trade_date")>=func.col("rates.rate_date"),
            func.date_add(func.col("trades.trade_date"), -7)<=func.col("rates.rate_date")
        ],
        how="left"
    ).withColumn(
        "rank", func.rank().over(Window.partitionBy("trades.id").orderBy(func.desc("rates.rate_date")))
    ).filter(
        func.col("rank") == 1
    ).select(
        "trades.id",
        "trades.amount",
        "trades.trade_date",
        "trades.currency",
        "rates.rate_date",
        "rates.rate",
        (func.col("trades.amount")*func.col("rates.rate")).alias("usd_amount")
    )
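
  • The range-join conditions already restrict each trade to at most the rates from its own 7-day window, so the rank() window only ever sees a handful of candidate rows per trade instead of the full multi-year rate history. If df_rates is still very large, one optional refinement is to prune it to the overall span of trade dates before joining. This is a minimal sketch, assuming df_trades and df_rates are the DataFrames defined above; df_rates_pruned is just a name introduced here for illustration:

    import pyspark.sql.functions as func

    # find the overall trade-date span (assumes computing this aggregation
    # over df_trades as an extra small job is acceptable)
    bounds = df_trades.agg(
        func.min("trade_date").alias("min_td"),
        func.max("trade_date").alias("max_td"),
    ).first()

    # keep only rates that could possibly match some trade: no later than the
    # latest trade_date, and no more than 7 days before the earliest trade_date
    df_rates_pruned = df_rates.filter(
        (func.col("rate_date") <= func.lit(bounds["max_td"]))
        & (func.col("rate_date") >= func.date_sub(func.lit(bounds["min_td"]), 7))
    )

    # df_rates_pruned can then be used in place of df_rates in the join above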