I have the below dataframe created in PySpark code:
+---------------+-------------+---------------+------+
|TransactionDate|AccountNumber|TransactionType|Amount|
+---------------+-------------+---------------+------+
| 2023-01-01| 100| Credit| 1000|
| 2023-01-02| 100| Credit| 1500|
| 2023-01-03| 100| Debit| 1000|
| 2023-01-02| 200| Credit| 3500|
| 2023-01-03| 200| Debit| 2000|
| 2023-01-04| 200| Credit| 3500|
| 2023-01-13| 300| Credit| 4000|
| 2023-01-14| 300| Debit| 4500|
| 2023-01-15| 300| Credit| 5000|
+---------------+-------------+---------------+------+
I need to add another column, CurrentBalance, holding the running balance per account.
Expected output:
+---------------+-------------+---------------+------+--------------+
|TransactionDate|AccountNumber|TransactionType|Amount|CurrentBalance|
+---------------+-------------+---------------+------+--------------+
| 2023-01-01| 100| Credit| 1000| 1000|
| 2023-01-02| 100| Credit| 1500| 2500|
| 2023-01-03| 100| Debit| 1000| 1500|
| 2023-01-02| 200| Credit| 3500| 3500|
| 2023-01-03| 200| Debit| 2000| 1500|
| 2023-01-04| 200| Credit| 3500| 5000|
| 2023-01-13| 300| Credit| 4000| 4000|
| 2023-01-14| 300| Debit| 4500| -500|
|     2023-01-15|          300|         Credit|  5000|          4500|
+---------------+-------------+---------------+------+--------------+
I have tried finding the minimum date per account and using it in a when condition to calculate the credits and debits, but it doesn't work.
# Find minimum date in TransactionDate column, grouped by AccountNumber column
df_new.groupBy('AccountNumber').agg(f.min('TransactionDate').alias('min_date'))
You will need a window function for this. A window function computes a value for every row within a partition (group); here, you need a running sum, row after row.
A plain sum is not enough either, because all the amounts are positive: the debits first have to be negated, using the TransactionType column.
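Before the Spark version, the signed running-sum idea can be sketched in plain Python (no Spark needed), just to confirm the arithmetic; the transaction lists below mirror two of the example accounts:

```python
from itertools import accumulate

# (TransactionType, Amount) pairs, already in date order per account
txns = {
    '100': [('Credit', 1000), ('Credit', 1500), ('Debit', 1000)],
    '300': [('Credit', 4000), ('Debit', 4500), ('Credit', 5000)],
}

def running_balance(rows):
    # Negate debits, then take the cumulative sum row after row
    signed = [amt if kind == 'Credit' else -amt for kind, amt in rows]
    return list(accumulate(signed))

print(running_balance(txns['100']))  # [1000, 2500, 1500]
print(running_balance(txns['300']))  # [4000, -500, 4500]
```

This is exactly what the sign-then-cumulative-sum window does per partition below.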
Example data:
from pyspark.sql import functions as F, Window as W
df = spark.createDataFrame(
    [('2023-01-01', '100', 'Credit', 1000),
     ('2023-01-02', '100', 'Credit', 1500),
     ('2023-01-03', '100', 'Debit', 1000),
     ('2023-01-02', '200', 'Credit', 3500),
     ('2023-01-03', '200', 'Debit', 2000),
     ('2023-01-04', '200', 'Credit', 3500),
     ('2023-01-13', '300', 'Credit', 4000),
     ('2023-01-14', '300', 'Debit', 4500),
     ('2023-01-15', '300', 'Credit', 5000)],
    ['TransactionDate', 'AccountNumber', 'TransactionType', 'Amount'])
Script:
# Debits count as negative, credits as positive
sign = F.when(F.col('TransactionType') == 'Debit', -1).otherwise(1)
amount = sign * F.col('Amount')
# Running (cumulative) sum per account, in date order
window = W.partitionBy('AccountNumber').orderBy('TransactionDate')
df = df.withColumn('CurrentBalance', F.sum(amount).over(window))
df.show()
# +---------------+-------------+---------------+------+--------------+
# |TransactionDate|AccountNumber|TransactionType|Amount|CurrentBalance|
# +---------------+-------------+---------------+------+--------------+
# | 2023-01-01| 100| Credit| 1000| 1000|
# | 2023-01-02| 100| Credit| 1500| 2500|
# | 2023-01-03| 100| Debit| 1000| 1500|
# | 2023-01-02| 200| Credit| 3500| 3500|
# | 2023-01-03| 200| Debit| 2000| 1500|
# | 2023-01-04| 200| Credit| 3500| 5000|
# | 2023-01-13| 300| Credit| 4000| 4000|
# | 2023-01-14| 300| Debit| 4500| -500|
# | 2023-01-15| 300| Credit| 5000| 4500|
# +---------------+-------------+---------------+------+--------------+
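One caveat, hedged because the example data happens to have unique dates per account: with `orderBy` alone, Spark's default window frame is `RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW`, so two transactions on the same date would both receive the combined balance for that date. Appending `.rowsBetween(W.unboundedPreceding, W.currentRow)` to the window switches to strict row-by-row accumulation. The difference can be illustrated in plain Python with a hypothetical duplicate date:

```python
from itertools import accumulate

# Signed amounts for one account, with a duplicated date (hypothetical data)
rows = [('2023-01-01', 1000), ('2023-01-02', 1500), ('2023-01-02', -1000)]

# ROWS frame: strict row-by-row cumulative sum
rows_frame = list(accumulate(a for _, a in rows))

# RANGE frame: every row sharing a date gets the total up to that date
range_frame = [sum(a for d2, a in rows if d2 <= d) for d, _ in rows]

print(rows_frame)   # [1000, 2500, 1500]
print(range_frame)  # [1000, 1500, 1500]
```

If transaction order within a day matters, add a tie-breaking column (e.g. a transaction id) to the `orderBy` as well.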