
Get the max of one column per group of two other columns, plus the most frequent value of a fourth column


I have this DataFrame:

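from pyspark.sql import SparkSession
from pyspark.sql import functions as F

# reuse the active session if one exists (e.g. in the pyspark shell)
spark = SparkSession.builder.getOrCreate()

df1 = spark.createDataFrame([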
    ('c', 'd', 3.0, 4),
    ('c', 'd', 7.3, 8),
    ('c', 'd', 7.3, 2),
    ('c', 'd', 7.3, 8),
    ('e', 'f', 6.0, 3),
    ('e', 'f', 6.0, 8),
    ('e', 'f', 6.0, 3),
    ('c', 'j', 4.2, 3),
    ('c', 'j', 4.3, 9),
], ['a', 'b', 'c', 'd'])
df1.show()
+---+---+---+---+
|  a|  b|  c|  d|
+---+---+---+---+
|  c|  d|3.0|  4|
|  c|  d|7.3|  8|
|  c|  d|7.3|  2|
|  c|  d|7.3|  8|
|  e|  f|6.0|  3|
|  e|  f|6.0|  8|
|  e|  f|6.0|  3|
|  c|  j|4.2|  3|
|  c|  j|4.3|  9|
+---+---+---+---+

I did this to get the max of c for each (a, b) pair:

df2 = df1.groupBy('a', 'b').agg(F.max('c').alias('c_max')).select(
        F.col('a'),
        F.col('b'),
        F.col('c_max').alias('c')
    )
df2.show()
+---+---+---+
|  a|  b|  c|
+---+---+---+
|  e|  f|6.0|
|  c|  d|7.3|
|  c|  j|4.3|
+---+---+---+

Now I also need d, the most frequent value within each (a, b) group. The expected result is:

+---+---+---+---+
|  a|  b|  c|  d|
+---+---+---+---+
|  c|  d|7.3|  8|
|  e|  f|6.0|  3|
|  c|  j|4.3|  9|
+---+---+---+---+

I tried an inner join between df1 and df2, but that didn't work: it keeps every row whose c equals the group maximum, duplicates included, and repeats the join columns:

condition = [df1.a ==  df2.a, df1.b ==  df2.b, df1.c ==  df2.c]
df3 = df1.join(df2,condition,"inner")
df3.show()
+---+---+---+---+---+---+---+
|  a|  b|  c|  d|  a|  b|  c|
+---+---+---+---+---+---+---+
|  c|  d|7.3|  8|  c|  d|7.3|
|  c|  d|7.3|  8|  c|  d|7.3|
|  c|  d|7.3|  2|  c|  d|7.3|
|  e|  f|6.0|  3|  e|  f|6.0|
|  e|  f|6.0|  8|  e|  f|6.0|
|  e|  f|6.0|  3|  e|  f|6.0|
|  c|  j|4.3|  9|  c|  j|4.3|
+---+---+---+---+---+---+---+

I'm a beginner in PySpark, so any help figuring this out would be appreciated.


Solution

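  • You can "zip" d with its count: count each d per (a, b), put the count and d in an array, and take the max of that array per (a, b). Arrays compare element by element, so the pair with the highest count wins; if counts tie, the larger d is chosen.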

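    df3 = (df1
        # count how often each d occurs within every (a, b) group
        .groupBy('a', 'b', 'd')
        .agg(F.count('*').alias('d_count'))
        # per (a, b), keep the [d_count, d] pair with the highest count
        .groupBy('a', 'b')
        .agg(F.max(F.array('d_count', 'd')).alias('d_freq'))
        # index 1 of the winning pair is d itself
        .select('a', 'b', F.col('d_freq')[1].alias('d'))
    )
    df3.show()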
    +---+---+---+
    |  a|  b|  d|
    +---+---+---+
    |  c|  d|  8|
    |  c|  j|  9|
    |  e|  f|  3|
    +---+---+---+
    

    Joining your df2 with this new df3 on a and b then gives your desired output.

    df2.join(df3, on=['a', 'b']).show()
    +---+---+---+---+
    |  a|  b|  c|  d|
    +---+---+---+---+
    |  c|  d|7.3|  8|
    |  c|  j|4.3|  9|
    |  e|  f|6.0|  3|
    +---+---+---+---+
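
To see why taking the max of a [count, value] pair picks the most frequent d, here is a rough plain-Python sketch of the same logic, with no Spark involved. The function name `summarize` and the `rows` list are only illustrative; the data is copied from the question. Python tuples compare element by element, much like the array in the Spark version, so the count decides first and the value of d only breaks ties.

```python
from collections import Counter, defaultdict

# Same rows as df1 in the question: (a, b, c, d)
rows = [
    ('c', 'd', 3.0, 4),
    ('c', 'd', 7.3, 8),
    ('c', 'd', 7.3, 2),
    ('c', 'd', 7.3, 8),
    ('e', 'f', 6.0, 3),
    ('e', 'f', 6.0, 8),
    ('e', 'f', 6.0, 3),
    ('c', 'j', 4.2, 3),
    ('c', 'j', 4.3, 9),
]


def summarize(rows):
    """Return {(a, b): (max c, most frequent d)} for the given rows."""
    c_max = {}
    d_counts = defaultdict(Counter)
    for a, b, c, d in rows:
        key = (a, b)
        # running maximum of c per (a, b)
        c_max[key] = max(c, c_max.get(key, c))
        # frequency of each d per (a, b)
        d_counts[key][d] += 1

    result = {}
    for key, counts in d_counts.items():
        # (count, d) pairs compare by count first, then by d,
        # so ties in count resolve to the larger d
        _, best_d = max((n, d) for d, n in counts.items())
        result[key] = (c_max[key], best_d)
    return result


summary = summarize(rows)
for (a, b), (c, d) in sorted(summary.items()):
    print(a, b, c, d)
```

Note the (c, j) group: both 3 and 9 occur once, and the tie resolves to 9 only because it is the larger value. If a different tie-break is needed (for example the smallest d), the pair would have to be built differently.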