python, dataframe, pyspark, optimization

Optimizing PySpark code that does pairwise comparisons of rows


I want to compare every pair of values in one column of a PySpark dataframe and find whether the corresponding rows share any value in another column. For example, I have the dataframe (df) below.

Column1  Column2
abc      111
def      666
def      111
tyu      777
abc      777
def      222
tyu      333
ewq      888

The output I want is

abc,def,CommonRow  <-- because of 111
abc,ewq,NoCommonRow
abc,tyu,CommonRow  <-- because of 777
def,ewq,NoCommonRow
def,tyu,NoCommonRow
ewq,tyu,NoCommonRow

The PySpark code that I'm currently using to do this is

# "value_list" contains the unique list of values in Column 1
index = 0
for col1 in value_list:
    index += 1
    df_col1 = df.filter(df.Column1 == col1)
    for col2 in value_list[index:]:
        df_col2 = df.filter(df.Column1 == col2)

        df_join = df_col1.join(df_col2, on=(df_col1.Column2 == df_col2.Column2), how="inner")
        if df_join.limit(1).count() == 0:   # No common row
            print(col1,col2,"NoCommonRow")
        else:
            print(col1,col2,"CommonRow")

However, I found that this takes a very long time to run (df has millions of rows). Is there any way to optimize it to run faster, or is there a better way to do the comparisons?


Solution

  • You can do this without loops by using a self join, as follows:

    from pyspark.sql import SparkSession
    from pyspark.sql import functions as F
    
    
    spark = SparkSession.builder.getOrCreate()
    
    data = [("abc", 111), ("def", 666), ("def", 111), ("tyu", 777),
            ("abc", 777), ("def", 222), ("tyu", 333), ("ewq", 888)]
    df = spark.createDataFrame(data, ["Column1", "Column2"])
    
    result_df = (
        df.alias("a")
        .join(df.alias("b"), F.col("a.Column1") < F.col("b.Column1"), "left") # this is more efficient than the "!=" approach used in other answer because using "<" creates less rows after join (almost half) 
        .join(df.alias("c"), (F.col("a.Column2") == F.col("c.Column2")) & (F.col("b.Column2") == F.col("c.Column2")), "left")
        .where(F.col("a.Column1").isNotNull() & F.col("b.Column1").isNotNull())
        .groupBy("a.Column1", "b.Column1")
        .agg(
            F.when(F.count("c.Column2") > 0, "CommonRow").otherwise("NoCommonRow").alias("CommonStatus")
        )
        .orderBy("a.Column1", "b.Column1") # remove this if not required
    )
    result_df.show(truncate=False)
    
    # Output:
    # +-------+-------+------------+
    # |Column1|Column1|CommonStatus|
    # +-------+-------+------------+
    # |abc    |def    |CommonRow   |
    # |abc    |ewq    |NoCommonRow |
    # |abc    |tyu    |CommonRow   |
    # |def    |ewq    |NoCommonRow |
    # |def    |tyu    |NoCommonRow |
    # |ewq    |tyu    |NoCommonRow |
    # +-------+-------+------------+
    

    Also, please avoid the orderBy when you run this on your million-row df; use it only if you actually need the output sorted.
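
    If you need the same line-by-line output as the original loop, one option (a minimal sketch, reusing the result_df built above) is to collect the aggregated result; it has only one row per pair of Column1 values, so it stays small even when df has millions of rows:

    # result_df has one row per Column1 pair, so collecting it is cheap
    # even for a large df; print it in the same format as the loop did.
    for row in result_df.collect():
        print(row[0], row[1], row["CommonStatus"])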