I want to partition/group rows so that each group's total size is <= a limit.
For example, if I have:
+--------+----------+
|      id|      size|
+--------+----------+
|       1|         3|
|       2|         6|
|       3|         8|
|       4|         5|
|       5|         7|
|       6|         7|
+--------+----------+
If I want to group rows so that each group's total size is <= 10, the result would be:
+--------+----------+----------+
|      id|      size|     group|
+--------+----------+----------+
|       1|         3|         0|
|       2|         6|         0|
|       3|         8|         1|
|       4|         5|         2|
|       5|         7|         3|
|       6|         7|         4|
+--------+----------+----------+
Another example, with a limit of 13:
+--------+----------+----------+
|      id|      size|     group|
+--------+----------+----------+
|       1|         3|         0|
|       2|         6|         0|
|       3|         8|         1|
|       4|         5|         1|
|       5|         7|         2|
|       6|         7|         3|
+--------+----------+----------+
I'm not even quite sure where to start; I have looked into window functions, reduce, user-defined aggregate functions, and adding extra columns (e.g. a cumulative sum).
The original task was to batch request payloads so that, while staying under a size limit, they can be combined into a single request.
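For reference, this is the greedy logic the tables above follow, sketched in plain Python just to pin down the expected semantics (the real data is a Spark DataFrame, so this is only an illustration):

def assign_groups(sizes, limit):
    # Greedy pass: keep a running total and start a new group whenever
    # adding the next row would push the total past the limit.
    groups, current, running = [], 0, 0
    for s in sizes:
        if running + s > limit:
            current += 1
            running = 0
        running += s
        groups.append(current)
    return groups

print(assign_groups([3, 6, 8, 5, 7, 7], 10))  # [0, 0, 1, 2, 3, 4]
print(assign_groups([3, 6, 8, 5, 7, 7], 13))  # [0, 0, 1, 1, 2, 3]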
Here's one approach. Since exact correctness is not required for your application, we can settle for approximate correctness.
First we split the rows into randomly assigned groups of roughly the right size. Within each of those groups, we use a grouped-map pandas_udf to find subgroups that pack close to the optimal number of rows while staying under payload_limit.
Here's a possible example.
import math
from pyspark.sql.functions import avg, col, count, floor, pandas_udf, rand, sum, PandasUDFType
from pyspark.sql import SparkSession
from pyspark.sql.types import *
spark = SparkSession.builder \
    .appName("Example") \
    .getOrCreate()
data = [
    (1, 3),
    (2, 6),
    (3, 8),
    (4, 5),
    (5, 7),
    (6, 7),
    (11, 3),
    (12, 6),
    (13, 8),
    (14, 5),
    (15, 7),
    (16, 7),
    (21, 3),
    (22, 6),
    (23, 8),
    (24, 5),
    (25, 7),
    (26, 7)
]
df = spark.createDataFrame(data, ["id", "size"])
# Find the average size
avg_size = df.select(avg("size")).collect()[0][0]
payload_limit = 20
rows_per_group = math.floor(payload_limit / avg_size)
print(f"{avg_size=}")
print(f"{rows_per_group=}")
print(f"{df.count()=}")
nums_of_group = math.ceil(df.count() / rows_per_group)
print(f"{nums_of_group=}")
df = df.withColumn("random_group_id", floor(rand() * nums_of_group))
distinct_group_ids = df.select(col("random_group_id")).distinct()
distinct_group_ids.show(n=100, truncate=False)
print(f"{distinct_group_ids.count()}")
grouped_counts = df.groupby(col("random_group_id")).agg(count("*"))
grouped_counts.show(n=100, truncate=False)
df.show(n=100, truncate=False)
result_schema = StructType([
    StructField("id", IntegerType()),
    StructField("size", IntegerType()),
    StructField("random_group_id", IntegerType()),
    StructField("sub_group", IntegerType()),
])
@pandas_udf(result_schema, PandasUDFType.GROUPED_MAP)
def group_by_limit(pdf):
    limit = payload_limit
    group_col = "sub_group"
    print("before")
    print(pdf)
    # Calculate the cumulative sum of sizes within each random group
    pdf["cum_size"] = pdf.groupby("random_group_id")["size"].cumsum()
    # Assign sub-group numbers based on the cumulative sum and the limit.
    # This is approximate: a sub-group can overshoot the limit slightly,
    # since the running total is not reset at each boundary.
    pdf[group_col] = pdf["cum_size"] // limit
    # Drop the helper column before returning
    pdf = pdf.drop("cum_size", axis=1)
    print("after")
    print(pdf)
    return pdf
# Apply the pandas UDF to the DataFrame
grouped_df = df.groupby("random_group_id").apply(group_by_limit)
grouped_df.show()
## Verify correctness of our algorithm.
result_grouped = grouped_df.groupBy("random_group_id", "sub_group").agg(sum("size"))
result_grouped.orderBy("random_group_id", "sub_group").show(n=100, truncate=False)
Output:
avg_size=6.0
rows_per_group=3
df.count()=18
nums_of_group=6
+---------------+
|random_group_id|
+---------------+
|5              |
|3              |
|1              |
|2              |
|4              |
+---------------+
5
+---------------+--------+
|random_group_id|count(1)|
+---------------+--------+
|5              |3       |
|3              |4       |
|1              |5       |
|2              |2       |
|4              |4       |
+---------------+--------+
+---+----+---------------+
|id |size|random_group_id|
+---+----+---------------+
|1  |3   |5              |
|2  |6   |3              |
|3  |8   |1              |
|4  |5   |5              |
|5  |7   |1              |
|6  |7   |1              |
|11 |3   |3              |
|12 |6   |3              |
|13 |8   |5              |
|14 |5   |2              |
|15 |7   |4              |
|16 |7   |1              |
|21 |3   |4              |
|22 |6   |2              |
|23 |8   |3              |
|24 |5   |1              |
|25 |7   |4              |
|26 |7   |4              |
+---+----+---------------+
+---+----+---------------+---------+
| id|size|random_group_id|sub_group|
+---+----+---------------+---------+
|  3|   8|              1|        0|
|  5|   7|              1|        0|
|  6|   7|              1|        1|
| 16|   7|              1|        1|
| 24|   5|              1|        1|
| 14|   5|              2|        0|
| 22|   6|              2|        0|
|  2|   6|              3|        0|
| 11|   3|              3|        0|
| 12|   6|              3|        0|
| 23|   8|              3|        1|
| 15|   7|              4|        0|
| 21|   3|              4|        0|
| 25|   7|              4|        0|
| 26|   7|              4|        1|
|  1|   3|              5|        0|
|  4|   5|              5|        0|
| 13|   8|              5|        0|
+---+----+---------------+---------+
+---------------+---------+---------+
|random_group_id|sub_group|sum(size)|
+---------------+---------+---------+
|1              |0        |15       |
|1              |1        |19       |
|2              |0        |11       |
|3              |0        |15       |
|3              |1        |8        |
|4              |0        |17       |
|4              |1        |7        |
|5              |0        |16       |
+---------------+---------+---------+
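Note that the PandasUDFType.GROUPED_MAP style above is deprecated as of Spark 3.0. The same logic can be written with groupby(...).applyInPandas; a minimal sketch, assuming the same df, payload_limit and result_schema as above:

def group_by_limit_fn(pdf):
    # Each pdf holds the rows of one random_group_id, so a plain cumulative
    # sum matches the groupby-cumsum used in the pandas UDF above.
    pdf = pdf.sort_values("id")
    pdf["sub_group"] = (pdf["size"].cumsum() // payload_limit).astype("int32")
    return pdf

grouped_df = df.groupby("random_group_id").applyInPandas(group_by_limit_fn, schema=result_schema)
grouped_df.show()

Also keep in mind that sub_group = cumsum // limit is only approximate: a subgroup's total can overshoot payload_limit by up to roughly one row's size, because the running total is never reset at a boundary. If the limit must be a hard guarantee, replace that line with the greedy reset loop from the question.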