apache-spark, pyspark, databricks, databricks-sql, scala-spark

Is there a way to partition/group data so that the sum of column values in each group stays under a limit?


I want to partition/group rows so that each group's total size is <= a limit.

For example, if I have:

+--------+----------+
|      id|      size|
+--------+----------+
|       1|         3|
|       2|         6|
|       3|         8|
|       4|         5|
|       5|         7|
|       6|         7|
+--------+----------+

and I want to group rows so that each group's total size is <= 10, the result would be:

+--------+----------+----------+
|      id|      size|     group|
+--------+----------+----------+
|       1|         3|         0|
|       2|         6|         0|
|       3|         8|         1|
|       4|         5|         2|
|       5|         7|         3|
|       6|         7|         4|
+--------+----------+----------+

Another example, with each group's total size <= 13:

+--------+----------+----------+
|      id|      size|     group|
+--------+----------+----------+
|       1|         3|         0|
|       2|         6|         0|
|       3|         8|         1|
|       4|         5|         1|
|       5|         7|         2|
|       6|         7|         3|
+--------+----------+----------+

I'm not even quite sure where to start; I have looked into window functions, reduce, user-defined aggregate functions, and adding extra columns (e.g. a cumulative sum).

The original task was to batch request payloads so that, while staying under a size limit, they can be combined into a single request.
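
Conceptually, the grouping I'm after is a sequential greedy pass: start a new group whenever adding the next row would push the running total over the limit. A plain-Python sketch of that logic (single machine, just to pin down the expected output shown above):

    def greedy_groups(sizes, limit):
        # Start a new group whenever adding the next size would exceed the limit;
        # a single size larger than the limit still gets its own group.
        group_ids, running, group_id = [], 0, 0
        for s in sizes:
            if running > 0 and running + s > limit:
                group_id += 1
                running = 0
            running += s
            group_ids.append(group_id)
        return group_ids

    greedy_groups([3, 6, 8, 5, 7, 7], 10)  # [0, 0, 1, 2, 3, 4]
    greedy_groups([3, 6, 8, 5, 7, 7], 13)  # [0, 0, 1, 1, 2, 3]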


Solution

  • Here's an example. Since exact correctness is not required for your application, we can settle for approximate correctness.

    First we split the rows into random groups of a suitable size. Within each of those groups, we use a grouped-map pandas_udf to form sub-groups whose total size stays within payload_limit.

    Here's a possible example.

    import math
    from pyspark.sql.functions import avg, floor, rand, pandas_udf, PandasUDFType
    from pyspark.sql.functions import col, sum, row_number, monotonically_increasing_id, count
    from pyspark.sql import SparkSession
    from pyspark.sql.types import StructType, StructField, IntegerType
    
    spark = SparkSession.builder \
        .appName("Example") \
        .getOrCreate()
    
    data = [
        (1, 3),
        (2, 6),
        (3, 8),
        (4, 5),
        (5, 7),
        (6, 7),
        (11, 3),
        (12, 6),
        (13, 8),
        (14, 5),
        (15, 7),
        (16, 7),
        (21, 3),
        (22, 6),
        (23, 8),
        (24, 5),
        (25, 7),
        (26, 7)
    ]
    df = spark.createDataFrame(data, ["id", "size"])
    
    # Find the average size
    avg_size = df.select(avg("size")).collect()[0][0]
    
    payload_limit = 20
    rows_per_group = math.floor(payload_limit / avg_size)
    print(f"{avg_size=}")
    print(f"{rows_per_group=}")
    print(f"{df.count()=}")
    nums_of_group = math.ceil(df.count() / rows_per_group)
    print(f"{nums_of_group=}")
    
    # Scatter rows into random coarse groups so each group holds roughly rows_per_group rows
    df = df.withColumn("random_group_id", floor(rand() * nums_of_group))
    distinct_group_ids = df.select(col("random_group_id")).distinct()
    distinct_group_ids.show(n=100, truncate=False)
    print(f"{distinct_group_ids.count()}")
    
    grouped_counts = df.groupby(col("random_group_id")).agg(count("*"))
    grouped_counts.show(n=100, truncate=False)
    
    df.show(n=100, truncate=False)
    
    result_schema = StructType([
        StructField("id", IntegerType()),
        StructField("size", IntegerType()),
        StructField("random_group_id", IntegerType()),
        StructField("sub_group", IntegerType()),
    ])
    
    
    @pandas_udf(result_schema, PandasUDFType.GROUPED_MAP)
    def group_by_limit(pdf):
        limit = payload_limit
        group_col = "sub_group"
    
        print("before")
        print(pdf)
    
        # Calculate the cumulative sum of sizes within each random group
        pdf["cum_size"] = pdf.groupby("random_group_id")["size"].cumsum()
    
        # Assign group numbers based on the cumulative sum and limit
        pdf[group_col] = pdf["cum_size"] // limit

        # Drop the cumulative sum column
        pdf = pdf.drop("cum_size", axis=1)
    
        print("after")
        print(pdf)
    
        return pdf
    
    
    # Apply the pandas UDF to the DataFrame
    grouped_df = df.groupby("random_group_id").apply(group_by_limit)
    
    
    grouped_df.show()
    
    ## Verify correctness of our algorithm.
    result_grouped = grouped_df.groupBy("random_group_id", "sub_group").agg(sum("size"))
    result_grouped.orderBy("random_group_id", "sub_group").show(n=100, truncate=False)
    

    Output:

    avg_size=6.0
    rows_per_group=3
    df.count()=18
    nums_of_group=6
    +---------------+
    |random_group_id|
    +---------------+
    |5              |
    |3              |
    |1              |
    |2              |
    |4              |
    +---------------+
    
    5
    +---------------+--------+
    |random_group_id|count(1)|
    +---------------+--------+
    |5              |3       |
    |3              |4       |
    |1              |5       |
    |2              |2       |
    |4              |4       |
    +---------------+--------+
    
    +---+----+---------------+
    |id |size|random_group_id|
    +---+----+---------------+
    |1  |3   |5              |
    |2  |6   |3              |
    |3  |8   |1              |
    |4  |5   |5              |
    |5  |7   |1              |
    |6  |7   |1              |
    |11 |3   |3              |
    |12 |6   |3              |
    |13 |8   |5              |
    |14 |5   |2              |
    |15 |7   |4              |
    |16 |7   |1              |
    |21 |3   |4              |
    |22 |6   |2              |
    |23 |8   |3              |
    |24 |5   |1              |
    |25 |7   |4              |
    |26 |7   |4              |
    +---+----+---------------+
    
    +---+----+---------------+---------+
    | id|size|random_group_id|sub_group|
    +---+----+---------------+---------+
    |  3|   8|              1|        0|
    |  5|   7|              1|        0|
    |  6|   7|              1|        1|
    | 16|   7|              1|        1|
    | 24|   5|              1|        1|
    | 14|   5|              2|        0|
    | 22|   6|              2|        0|
    |  2|   6|              3|        0|
    | 11|   3|              3|        0|
    | 12|   6|              3|        0|
    | 23|   8|              3|        1|
    | 15|   7|              4|        0|
    | 21|   3|              4|        0|
    | 25|   7|              4|        0|
    | 26|   7|              4|        1|
    |  1|   3|              5|        0|
    |  4|   5|              5|        0|
    | 13|   8|              5|        0|
    +---+----+---------------+---------+
    
    +---------------+---------+---------+
    |random_group_id|sub_group|sum(size)|
    +---------------+---------+---------+
    |1              |0        |15       |
    |1              |1        |19       |
    |2              |0        |11       |
    |3              |0        |15       |
    |3              |1        |8        |
    |4              |0        |17       |
    |4              |1        |7        |
    |5              |0        |16       |
    +---------------+---------+---------+
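
    As a side note, PandasUDFType.GROUPED_MAP is deprecated in Spark 3.x; the same grouped-map step can be written with GroupedData.applyInPandas. A minimal sketch, assuming the same df (already carrying random_group_id), payload_limit and result_schema as above:

    def group_by_limit_pdf(pdf):
        # All rows in pdf share one random_group_id, so a plain cumulative sum
        # is enough; bucket rows by how many multiples of the limit the running total spans.
        pdf["cum_size"] = pdf["size"].cumsum()
        pdf["sub_group"] = pdf["cum_size"] // payload_limit
        return pdf.drop("cum_size", axis=1)

    grouped_df = (
        df.groupby("random_group_id")
          .applyInPandas(group_by_limit_pdf, schema=result_schema)
    )
    grouped_df.show()

    The verification query at the end of the example works unchanged against this grouped_df.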