I have two DataFrames:
df_rates df_trades
(rate from currency -> USD)
+--------+----------+----+ +---+--------+------+----------+
|currency| rate_date|rate| | id|currency|amount|trade_date|
+--------+----------+----+ +---+--------+------+----------+
| EUR|2025-01-09|1.19| | 1| EUR| 1000|2025-01-09| # exact rate available
| EUR|2025-01-08|1.18| | 2| CAD| 1000|2025-01-09| # 1 day prior rate available
| CAD|2025-01-08|0.78| | 3| AUD| 1000|2025-01-09| # no applicable rate available
| CAD|2025-01-07|0.77| | 4| HKD| 1000|2025-01-09| # no rate available at all
| AUD|2025-02-09|1.39| +---+--------+------+----------+
| AUD|2025-02-08|1.38|
+--------+----------+----+
For every trade I need to calculate usd_amount by applying appropriate exchange rate. The way to pick rate is:
rate
on the trade_date
trade_date
then go back up to 7 daysIf no rate is found in this way then usd_amount = null
I have following code that works. But I'm not sure if it'll scale. Specifically for the case of trade_id = 3
(when rates are available but none are in correct date range), because in practice the rates table has 1000s of rates (going back 5-7 years). Part of code marked This PART
below.
from pyspark.sql import SparkSession, Window
from pyspark.sql.functions import col, row_number, date_diff, when
def log_dataframe(df, msg):
print(msg)
df.show()
def calc_usd_amount(df_trades, df_rates):
df = df_trades.join(df_rates, how='left_outer', on='currency').withColumn('date_diff', date_diff('trade_date', 'rate_date'))
date_diff_no_good = (col('date_diff') < 0) | (col('date_diff') > 7)
# This PART
df = (
df.withColumns({
'rate_date': when(date_diff_no_good, None).otherwise(col('rate_date')),
'rate': when(date_diff_no_good, None).otherwise(col('rate')),
})
.drop_duplicates(['id', 'rate_date', 'rate'])
)
w_spec = row_number().over(Window.partitionBy(col('id'), col('currency')).orderBy(col('rate_date').desc()))
df = (
df.filter('rate_date IS NULL OR (rate_date <= trade_date AND rate_date > (trade_date - 7))')
.withColumn('rate_row_num', w_spec).filter('rate_row_num == 1')
.withColumn('usd_amount', col('rate') * col('amount'))
)
return df.drop('date_diff', 'rate_row_num')
from pyspark.sql import SparkSession
from datetime import date
spark = SparkSession.builder.getOrCreate()
dt = date.fromisoformat
df_trades = spark.createDataFrame(
data = [
(1, 'EUR', 1000, dt('2025-01-09')), # trade date rate available
(2, 'CAD', 1000, dt('2025-01-09')), # trade date -1d, rate available
(3, 'AUD', 1000, dt('2025-01-09')), # no applicable rate available
(4, 'HKD', 1000, dt('2025-01-09')), # no rate available at all
],
schema=['id', 'currency', 'amount', 'trade_date'],
)
df_rates = spark.createDataFrame(
data = [
('EUR', dt('2025-01-09'), 1.19), # trade date rate available
('EUR', dt('2025-01-08'), 1.18),
('CAD', dt('2025-01-08'), 0.78), # trade date -1d, rate available
('CAD', dt('2025-01-07'), 0.77),
('AUD', dt('2025-02-09'), 1.39), # no applicable rate available
('AUD', dt('2025-02-08'), 1.38),
],
schema=['currency', 'rate_date', 'rate']
)
df_out = calc_usd_amount(df_trades, df_rates)
log_dataframe(df_out, 'df_out')
prints:
df_out
+--------+---+------+----------+----------+----+----------+
|currency| id|amount|trade_date| rate_date|rate|usd_amount|
+--------+---+------+----------+----------+----+----------+
| EUR| 1| 1000|2025-01-09|2025-01-09|1.19| 1190.0|
| CAD| 2| 1000|2025-01-09|2025-01-08|0.78| 780.0|
| AUD| 3| 1000|2025-01-09| NULL|NULL| NULL|
| HKD| 4| 1000|2025-01-09| NULL|NULL| NULL|
+--------+---+------+----------+----------+----+----------+
I think we can use conditional joining with trade_date / date_date and a window function to achieve the same goal:
df_out = df_trades.alias("trades").join(
df_rates.alias("rates"),
on=[
func.col("trades.currency")==func.col("rates.currency"),
func.col("trades.trade_date")>=func.col("rates.rate_date"),
func.date_add(func.col("trades.trade_date"), -7)<=func.col("rates.rate_date")
],
how="left"
).withColumn(
"rank", func.rank().over(Window.partitionBy("trades.id").orderBy(func.desc("rates.rate_date")))
).filter(
func.col("rank") == 1
).select(
"trades.id",
"trades.amount",
"trades.trade_date",
"trades.currency",
"rates.rate_date",
"rates.rate",
(func.col("trades.amount")*func.col("rates.rate")).alias("usd_amount")
)