Tags: pyspark, aggregate

Sum of CASE WHEN in PySpark


I am trying to convert an HQL script into PySpark. I am struggling with how to achieve a sum of CASE WHEN statements in an aggregation after a groupBy clause, e.g.:

dataframe1 = dataframe0.groupby(col0).agg(
            SUM(f.when((col1 == 'ABC' | col2 == 'XYZ'), 1).otherwise(0)))

Is this possible in PySpark? I am getting an error while executing such a statement. Thanks


Solution

  • You can use withColumn to create a column with the values you want to be summed, then aggregate on that. (Your snippet fails because the aggregate function is F.sum, not SUM, and because | binds more tightly than ==, so each comparison needs its own parentheses.) For example:

    from pyspark.sql import functions as F, types as T

    schema = T.StructType([
        T.StructField('key', T.IntegerType(), True),
        T.StructField('col1', T.StringType(), True),
        T.StructField('col2', T.StringType(), True)
    ])

    data = [
        (1, 'ABC', 'DEF'),
        (1, 'DEF', 'XYZ'),
        (1, 'DEF', 'GHI')
    ]

    # sc and sqlContext are assumed to already exist, e.g. in the pyspark shell
    rdd = sc.parallelize(data)
    df = sqlContext.createDataFrame(rdd, schema)

    # Flag each row that matches the condition with 1 (otherwise 0),
    # then sum the flags per key
    result = df.withColumn('value', F.when((df.col1 == 'ABC') | (df.col2 == 'XYZ'), 1).otherwise(0)) \
               .groupBy('key') \
               .agg(F.sum('value').alias('sum'))

    result.show(100, False)
    

    Which prints out this result:

    +---+---+
    |key|sum|
    +---+---+
    |1  |2  |
    +---+---+