apache-spark pyspark

Count entries for all possible categories


I have data that looks like this:

id | id_2  | ... | label |
-- | ----- | --- | ----- |
1  | x     | ... | a     |
1  | x     | ... | a     |
1  | x     | ... | b     |
1  | y     | ... | a     |
1  | z     | ... | c     |
1  | z     | ... | b     |
1  | z     | ... | c     |
2  | x     | ... | a     |
2  | x     | ... | a     |
2  | x     | ... | b     |
2  | x     | ... | b     |
2  | x     | ... | b     |
2  | x     | ... | b     |

I need to create an output DataFrame that, for each combination of id and id_2 and for each possible label (I know the possible values in advance), contains a count, with the addition that if a label does not occur for a given id and id_2, the count should be 0:

id | id_2 | label | count |
-- | ---- | ----- | ----- |
1  | x    | a     | 2     |
1  | x    | b     | 1     |
1  | x    | c     | 0     |
1  | y    | a     | 1     |
1  | y    | b     | 0     |
1  | y    | c     | 0     |
1  | z    | a     | 0     |
1  | z    | b     | 1     |
1  | z    | c     | 2     |
2  | x    | a     | 2     |
2  | x    | b     | 4     |
2  | x    | c     | 0     |

If I do this with a "simple" groupBy(id, id_2, label).agg(count), I do not get the "zero counts", obviously.
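For reference, the "simple" aggregation mentioned above presumably looks like this (assuming the DataFrame is called df); each (id, id_2, label) combination that never occurs simply produces no row:

    # Only combinations that actually appear in the data yield a row,
    # so the missing labels never show up with a count of 0.
    observed = df.groupBy("id", "id_2", "label").count()
    observed.show()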

I would be very thankful for ideas!


Solution

  • You could create the Cartesian product of all id combinations and labels and then left join the counts (illustrated here on a small example with columns id1, id2, label):

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame(
        data=[(1, 2, 'a'), (1, 2, 'b'), (1, 3, 'a'), (2, 4, 'c')],
        schema=["id1", "id2", "label"])
    # Counts for the (id1, id2, label) combinations that actually occur
    counts = df.groupby(['id1', 'id2', 'label']).count()
    labels = df.select('label').distinct()
    ids = df.select('id1', 'id2').distinct()
    # Cartesian product of id pairs and labels, then left join the counts;
    # combinations without a match get a null count, which is filled with 0
    combos = ids.crossJoin(labels)
    final = combos.join(counts, on=['id1', 'id2', 'label'], how='left').na.fill(0)
    final.show()
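
  • Since the possible label values are known in advance, the label dimension can also be built from that list instead of from distinct(), which additionally covers labels that never occur in the data at all. A minimal sketch using the column names from the question (assuming a DataFrame df with columns id, id_2, label, an active SparkSession spark, and an example label list):

    # Known label values (example); build a one-column DataFrame from them
    known_labels = ["a", "b", "c"]
    labels = spark.createDataFrame([(lab,) for lab in known_labels], ["label"])

    ids = df.select("id", "id_2").distinct()
    counts = df.groupby(["id", "id_2", "label"]).count()

    # Cross join all (id, id_2) pairs with all labels, then left join the
    # observed counts and replace the resulting nulls with 0
    result = (ids.crossJoin(labels)
                 .join(counts, on=["id", "id_2", "label"], how="left")
                 .na.fill(0, subset=["count"])
                 .orderBy("id", "id_2", "label"))
    result.show()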