pyspark stack-overflow

Collapsing many binary columns into a single column in pyspark


In my pyspark job, I have a huge DataFrame with more than 6,000 columns in the following format:

id_        a1   a2     a3   a4     a5   .... a6250
u1827s    True False  True False False .... False
...

The majority of the columns a1,a2,a3,...,a6250 are of binary (boolean) type. I need to group this data by all these columns and count the number of distinct ids for each combination, e.g.

df = df.groupby(list_of_cols).agg(F.countDistinct("id_"))

where list_of_cols = [a1,a2,...,a6250]. When I run this pyspark job, I get a java.lang.StackOverflowError. I am aware that I can increase the stack size (as per https://rangareddy.github.io/SparkStackOverflow/); however, I'd prefer a more elegant solution that would also give a more convenient output.

I have two ideas for what to do before grouping:

  1. Encode the combination of the a1,a2,...,a6250 columns into a single binary column, such as a binary number with 6250 bits where the bit at position k encodes the True or False value of column a_k. In the example above the value would be 10100...0 (a1 is true, a2 is false, a3 is true, a4 is false, a5 is false, ..., a6250 is false).

  2. Collect these values into a binary array, i.e. a single column such as array(True,False,True,False,False,....,False).
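
To make the two ideas concrete, here is a minimal plain-Python sketch (not Spark code) of what each encoding would produce for a single row. The helper names `pack_bits` and `as_array` are hypothetical and only illustrate the layouts described above; the bit order (first column in the most significant bit, zero padding at the end) is an assumption.

```python
def pack_bits(flags):
    """Idea 1: pack booleans into bytes, first flag in the most significant bit."""
    value = 0
    for flag in flags:
        value = (value << 1) | (1 if flag else 0)
    # Assumption: pad with zero bits on the right up to a whole number of bytes
    pad = (-len(flags)) % 8
    value <<= pad
    return value.to_bytes((len(flags) + pad) // 8, "big")


def as_array(flags):
    """Idea 2: keep the values as one ordered, hashable sequence of booleans."""
    return tuple(bool(flag) for flag in flags)


row = [True, False, True, False, False]  # a1..a5 from the example row
packed = pack_bits(row)    # bits 10100 followed by three padding zeros
array_key = as_array(row)
```

With this layout, 6250 flags fit into 782 bytes, whereas the array form keeps one element per column.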

Which way is better: increasing the stack size and dealing with 6000+ columns, using a single binary column, or using an array of binary values?


Solution

  • I believe your first idea (a single encoded column) is the more practical way to find the distinct count for every existing combination. Here is some code that reproduces it on a small example with your data schema:

    from pyspark.sql import Row, SparkSession
    import pyspark.sql.functions as F
    
    # Reuse the active session, or start one when running as a standalone script
    spark = SparkSession.builder.getOrCreate()
    
    data = [
        ("A", True, True, True),  # Counted
        ("A", True, True, True),  # Not Counted
        ("A", False, False, False),  # Counted
        ("B", False, False, False),  # Counted
        ("B", False, False, False),  # Not Counted
        ("B", True, True, True),  # Counted
    ]
    
    columns = ["id", "a0", "a1", "a2"]
    df = spark.createDataFrame(data, columns)
    
    # Broadcast the flag columns (everything except "id") so that each
    # executor can read the column list
    dynamic_column_list = spark.sparkContext.broadcast(
        [c for c in df.columns if c != "id"]
    )
    
    
    def extract_pattern(row):
        """
        Encode the flag columns of a row as a string of bits.
    
        True => "1"
        False => "0"
        Any other value (e.g. null) adds no character to the pattern.
        """
        bits = []
        for column in dynamic_column_list.value:
            if row[column] is True:
                bits.append("1")
            elif row[column] is False:
                bits.append("0")
        return Row(id=row["id"], pattern="".join(bits))
    
    
    # One row per input row: the id plus its encoded pattern
    extracted_pattern = df.rdd.map(extract_pattern).toDF()
    # Group by the single pattern column instead of thousands of columns
    counted_by_pattern = extracted_pattern.groupBy("pattern").agg(F.countDistinct("id"))
    counted_by_pattern.show()
    

    The final groupBy is the step that gives the distinct count of IDs associated with each pattern:

    counted_by_pattern = extracted_pattern.groupBy("pattern").agg(F.countDistinct("id"))
    

    I haven't tested this against the 6K+ columns, so please share the results once you have reviewed and tried this code.
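
As a sanity check on what the job above should return for the sample data, the same logic can be sketched in plain Python without Spark. This is only an illustration of the expected grouping, assuming the same True => "1" / False => "0" mapping; the names `to_pattern` and `distinct_counts` are hypothetical.

```python
from collections import defaultdict

data = [
    ("A", True, True, True),
    ("A", True, True, True),
    ("A", False, False, False),
    ("B", False, False, False),
    ("B", False, False, False),
    ("B", True, True, True),
]


def to_pattern(flags):
    """Map each boolean to "1" or "0" and join them into one string."""
    return "".join("1" if flag else "0" for flag in flags)


# Collect the set of ids seen for each pattern, then count them
ids_by_pattern = defaultdict(set)
for id_, *flags in data:
    ids_by_pattern[to_pattern(flags)].add(id_)

distinct_counts = {pattern: len(ids) for pattern, ids in ids_by_pattern.items()}
print(distinct_counts)
```

Both patterns, 111 and 000, occur for ids A and B, so each should come out with a distinct count of 2.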