Tags: python, arrays, apache-spark, pyspark, combinations

Pair combinations of array column values in PySpark


Similar to this question (which is for Scala), but I need pair combinations of an array column in PySpark.

Example input:

df = spark.createDataFrame(
    [([0, 1],),
     ([2, 3, 4],),
     ([5, 6, 7, 8],)],
    ['array_col'])

Expected output:

+------------+------------------------------------------------+
|array_col   |out                                             |
+------------+------------------------------------------------+
|[0, 1]      |[[0, 1]]                                        |
|[2, 3, 4]   |[[2, 3], [2, 4], [3, 4]]                        |
|[5, 6, 7, 8]|[[5, 6], [5, 7], [5, 8], [6, 7], [6, 8], [7, 8]]|
+------------+------------------------------------------------+

Solution

  • pandas_udf is an efficient and concise approach in PySpark: it processes Arrow-serialized batches of rows rather than one row at a time, so it is typically faster than a plain Python UDF.

    from pyspark.sql import functions as F
    import pandas as pd
    from itertools import combinations
    
    @F.pandas_udf('array<array<int>>')
    def pudf(c: pd.Series) -> pd.Series:
        # build every 2-element combination within each row's array
        return c.apply(lambda x: list(combinations(x, 2)))

    df = df.withColumn('out', pudf('array_col'))
    df.show(truncate=0)
    # +------------+------------------------------------------------+
    # |array_col   |out                                             |
    # +------------+------------------------------------------------+
    # |[0, 1]      |[[0, 1]]                                        |
    # |[2, 3, 4]   |[[2, 3], [2, 4], [3, 4]]                        |
    # |[5, 6, 7, 8]|[[5, 6], [5, 7], [5, 8], [6, 7], [6, 8], [7, 8]]|
    # +------------+------------------------------------------------+
    

    Note: depending on your environment or PySpark version, instead of the DDL string 'array<array<int>>' you may need to build the return type from pyspark.sql.types, e.g. ArrayType(ArrayType(IntegerType())), as sketched below.
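
    A minimal sketch of that variant (only the return type declaration changes; IntegerType mirrors the 'int' in the DDL string):

    from pyspark.sql.types import ArrayType, IntegerType

    @F.pandas_udf(ArrayType(ArrayType(IntegerType())))
    def pudf(c: pd.Series) -> pd.Series:
        # identical logic; only the schema is given as DataType objects
        return c.apply(lambda x: list(combinations(x, 2)))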
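
  • If you want to stay entirely in native Spark SQL (no Python workers), the pairs can also be built with the higher-order functions transform, slice, and flatten (available since Spark 2.4). A sketch of that alternative, assuming the same df as above:

    from pyspark.sql import functions as F

    # For each element x at 0-based index i, pair x with every later
    # element. slice() is 1-based, hence the start position i + 2.
    df_native = df.withColumn('out', F.expr("""
        flatten(
          transform(array_col, (x, i) ->
            transform(slice(array_col, i + 2, size(array_col) - i - 1),
                      y -> array(x, y))))
    """))
    df_native.show(truncate=0)

    Note that the result type here is array<array<bigint>> (Spark's default for Python ints) rather than array<array<int>>.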