In my PySpark job, I have a huge DataFrame with more than 6,000 columns in the following format:
id_ a1 a2 a3 a4 a5 .... a6250
u1827s True False True False False .... False
...
The majority of the columns a1, a2, a3, ..., a6250 are of boolean type. I need to group this data by all of these columns and count the number of distinct ids for each combination, e.g.
df = df.groupby(list_of_cols).agg(F.countDistinct("id_"))
where list_of_cols = ["a1", "a2", ..., "a6250"]. When running this PySpark job, it fails with a java.lang.StackOverflowError. I am aware that I can increase the stack size (as per https://rangareddy.github.io/SparkStackOverflow/), but I'd prefer a more elegant solution that would also produce a more convenient output.
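(For reference, raising the stack size along the lines of that page would look roughly like the sketch below; the -Xss value is only an example, and depending on the deploy mode the driver option may have to be passed to spark-submit rather than set in code.)

from pyspark.sql import SparkSession

# Example only: raise the JVM thread stack size on the driver and executors
spark = (
    SparkSession.builder
    .config("spark.driver.extraJavaOptions", "-Xss64m")
    .config("spark.executor.extraJavaOptions", "-Xss64m")
    .getOrCreate()
)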
I have two ideas for what to do before grouping:
1. Encode the combination of the a1, a2, ..., a6250 columns into a single binary column, such as a binary number with 6250 bits where the bit at position k encodes the True/False value of column a_k. In the example above the value would be 10100...0 (a1 is true, a2 is false, a3 is true, a4 is false, a5 is false, ..., a6250 is false).
2. Collect these values into a boolean array, i.e. have a single column like array(True, False, True, False, False, ..., False). A minimal sketch of what I have in mind is below.
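For the second idea, this is roughly what I imagine (untested at this scale, and assuming id_ is the only non-boolean column):

import pyspark.sql.functions as F

# Collect all boolean columns into a single array column,
# then group by that one column instead of 6,000+ columns
list_of_cols = [c for c in df.columns if c != "id_"]
df = df.withColumn("pattern", F.array(*list_of_cols))
df = df.groupby("pattern").agg(F.countDistinct("id_"))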
Which way is better: to increase the stack size and keep working with 6,000+ columns, to use a single binary column, or to use an array of boolean values?
I believe the first approach (encoding into a single column) would be more practical for finding the distinct count of ids for each existing combination. Here's some code I've tried to reproduce this solution given your data schema:
from pyspark.sql import Row
import pyspark.sql.functions as F

data = [
    ("A", True, True, True),    # Counted
    ("A", True, True, True),    # Not counted (same id and pattern as the row above)
    ("A", False, False, False), # Counted
    ("B", False, False, False), # Counted
    ("B", False, False, False), # Not counted (same id and pattern as the row above)
    ("B", True, True, True),    # Counted
]
columns = ["id", "a0", "a1", "a2"]
df = spark.createDataFrame(data, columns)

# Broadcast the list of boolean columns (excluding the id) so that
# each executor can access it when building the patterns
dynamic_column_list = spark.sparkContext.broadcast([c for c in df.columns if c != "id"])

def extract_pattern(row):
    """
    Build a bit-string pattern from the row's boolean columns:
    True  => "1"
    False => "0"
    """
    pattern = ""
    for column in dynamic_column_list.value:
        bit = ""
        if row[column] is True:
            bit = "1"
        elif row[column] is False:
            bit = "0"
        pattern += bit
    return Row(id=row["id"], pattern=pattern)

extracted_pattern = df.rdd.map(extract_pattern).toDF()

# Count the distinct ids associated with each pattern
counted_by_pattern = extracted_pattern.groupBy("pattern").agg(F.countDistinct("id"))
counted_by_pattern.show()
The final groupBy("pattern").agg(F.countDistinct("id")) gives the distinct count of ids associated with each pattern. I haven't tried to replicate the 6,000+ columns, but please share the results after you review and try this code.
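If you want to avoid the RDD round trip, the same bit-string pattern can probably be built directly with DataFrame functions. This is a minimal sketch only; I haven't tested it with 6,000+ columns, so I can't promise the resulting expression tree avoids the same StackOverflowError:

import pyspark.sql.functions as F

bool_cols = [c for c in df.columns if c != "id"]

# Cast each boolean column to 0/1, then concatenate into one bit-string column
df_with_pattern = df.withColumn(
    "pattern",
    F.concat_ws("", *[F.col(c).cast("int").cast("string") for c in bool_cols])
)

df_with_pattern.groupBy("pattern").agg(F.countDistinct("id")).show()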