The code snippet below creates a 'rank' column conditionally. I want to compute the rank over only a subset of the partition, so I wrap rank() in a when clause that checks category == 'Y'. However, the result below is not what I expected: where I expected rank=1, I get rank=2.
How can I compute a rank over only a subset of a partition?
import pyspark.sql.functions as F
from pyspark.sql.window import Window
from pyspark.sql import Row
data = [
Row(id=1, code=14, category='N'),
Row(id=1, code=20, category='Y'),
Row(id=1, code=19, category='Y'),
Row(id=1, code=22, category='Y'),
Row(id=1, code=15, category='Y'),
]
ps_df = spark.createDataFrame(data)
window = Window.partitionBy('id').orderBy('code')
ps_df = ps_df.withColumn('rank', F.when(F.col('category')=='Y', F.rank().over(window)))
ps_df.show()
+---+----+--------+----+
| id|code|category|rank|
+---+----+--------+----+
|  1|  14|       N|NULL|
|  1|  15|       Y|   2|
|  1|  19|       Y|   3|
|  1|  20|       Y|   4|
|  1|  22|       Y|   5|
+---+----+--------+----+
The problem is that rank() is evaluated over the entire partition, including the category == 'N' row; the when() only masks the result afterwards, so the 'N' row still consumes rank 1. I think we can use an alternative approach that doesn't need rank() to achieve the same goal:
import pyspark.sql.functions as func
from pyspark.sql.window import Window

ps_df = ps_df.withColumn(
    # flag the rows that belong to the subset we want to rank
    "flag", func.when(func.col("category") == "Y", func.lit(1)).otherwise(func.lit(0))
).withColumn(
    # cumulative sum of the flag within the partition acts as the rank
    "cumsum", func.sum("flag").over(Window.partitionBy("id").orderBy("code"))
).withColumn(
    # keep the rank only for the subset, NULL elsewhere
    "rank", func.when(func.col("category") == "Y", func.col("cumsum")).otherwise(func.lit(None))
).select(
    "id", "code", "category", "rank"
)
First, create a flag equal to 1 for the rows in the group you want to rank (here category == 'Y') and 0 otherwise. Then use sum() as a window function to take the cumulative sum of that flag within the partition, which produces the ranking, and finally keep the value only for the flagged rows.
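With the sample data from the question, this should produce the following, where the category 'N' row no longer consumes a rank and the 'Y' rows are ranked 1 through 4:
+---+----+--------+----+
| id|code|category|rank|
+---+----+--------+----+
|  1|  14|       N|NULL|
|  1|  15|       Y|   1|
|  1|  19|       Y|   2|
|  1|  20|       Y|   3|
|  1|  22|       Y|   4|
+---+----+--------+----+
Note that ties in 'code' would each advance the cumulative sum, so this behaves like row_number() within the subset rather than rank(); for this data the two are equivalent.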