Inspired by this answer, I want to find the row-wise minimum across several date columns and return the name of the column that holds it.
I'm getting unexpected results when a row contains NULLs, which I thought least ignored; see rows 2-5 in this toy example:
import datetime as dt
from pyspark.sql import Row
from pyspark.sql import functions as F
from pyspark.sql.types import StructField, StructType, DateType
schema = StructType([
    StructField("date1", DateType(), True),
    StructField("date2", DateType(), True),
    StructField("date3", DateType(), True)
])
row1 = Row(dt.date(2024, 1, 1), dt.date(2024, 1, 2), dt.date(2024, 1, 3))
row2 = Row(None, None, dt.date(2024, 1, 3))
row3 = Row(None, dt.date(2024, 1, 1), dt.date(2024, 1, 2))
row4 = Row(None, None, None)
row5 = Row(dt.date(2024, 1, 1), None, None)
df = spark.createDataFrame([row1, row2, row3, row4, row5], schema)
def row_min(*cols):
    # Pair each date with its column name so the name can be recovered later
    cols_ = [F.struct(F.col(c).alias("value"), F.lit(c).alias("col")) for c in cols]
    return F.least(*cols_)

df.withColumn("output", row_min('date1', 'date2', 'date3').col).show()
This returns:
+----------+----------+----------+------+
| date1| date2| date3|output|
+----------+----------+----------+------+
|2024-01-01|2024-01-02|2024-01-03| date1|
| NULL| NULL|2024-01-03| date1|
| NULL|2024-01-01|2024-01-02| date1|
| NULL| NULL| NULL| date1|
|2024-01-01| NULL| NULL| date2|
+----------+----------+----------+------+
but the desired output is:
+----------+----------+----------+------+
| date1| date2| date3|output|
+----------+----------+----------+------+
|2024-01-01|2024-01-02|2024-01-03| date1|
| NULL| NULL|2024-01-03| date3|
| NULL|2024-01-01|2024-01-02| date2|
| NULL| NULL| NULL| NULL|
|2024-01-01| NULL| NULL| date1|
+----------+----------+----------+------+
You are comparing values of type struct<value:date,col:string>, in which the value field may be NULL. The least function ignores an argument only when the whole struct is NULL, not when one of its fields is.
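To see this, reuse df and the original row_min from the question (illustrative only):

df.select(
    F.least("date1", "date2", "date3").alias("least_date"),    # skips NULLs
    row_min("date1", "date2", "date3").alias("least_struct"),  # does not
).show(truncate=False)

For the second row, least_date correctly comes back as 2024-01-03, while least_struct still carries a NULL value with col = date1, because a NULL value field sorts before any non-NULL date.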
In Spark SQL sorting, NULL values by default appear first in ascending order and last in descending order, and struct comparison follows the same rule field by field. So one quick fix is to negate the date, for example by turning it into a date interval with (F.expr("date'1970'") - F.col(c)).alias("value"), and then apply the greatest function instead:
def row_min(*cols):
    # A NULL interval sorts lowest, so greatest() prefers any non-NULL date
    cols_ = [F.struct((F.expr("date'1970'") - F.col(c)).alias("value"),
                      F.lit(c).alias("col")) for c in cols]
    return F.greatest(*cols_)

row_least = row_min('date1', 'date2', 'date3')
# value is NULL only when every date in the row is NULL
df.withColumn("output", F.when(F.isnull(row_least.value), None)
                         .otherwise(row_least.col)).show()
+----------+----------+----------+------+
|     date1|     date2|     date3|output|
+----------+----------+----------+------+
|2024-01-01|2024-01-02|2024-01-03| date1|
|      NULL|      NULL|2024-01-03| date3|
|      NULL|2024-01-01|2024-01-02| date2|
|      NULL|      NULL|      NULL|  NULL|
|2024-01-01|      NULL|      NULL| date1|
+----------+----------+----------+------+
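Why this works: subtracting each date from a fixed date reverses the ordering, so the earliest date becomes the greatest interval, and a struct whose value field is NULL compares lowest, so greatest only falls back to it when every date in the row is NULL. A minimal sanity check, assuming Spark 3.2+ where date subtraction yields an orderable day-time interval:

# An earlier date maps to a greater interval, so min over dates
# becomes max over intervals
spark.range(1).select(
    ((F.expr("date'1970'") - F.lit("2024-01-01").cast("date"))
     > (F.expr("date'1970'") - F.lit("2024-01-02").cast("date")))
    .alias("earlier_is_greater")
).show()  # shows true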