I want to iteratively compare 2 sets of rows in a PySpark dataframe, and find the common values in another column. For example, I have the dataframe (df) below.
Column1 Column2
abc     111
def     666
def     111
tyu     777
abc     777
def     222
tyu     333
ewq     888
The output I want is
abc,def,CommonRow <-- because of 111
abc,ewq,NoCommonRow
abc,tyu,CommonRow <-- because of 777
def,ewq,NoCommonRow
def,tyu,NoCommonRow
ewq,tyu,NoCommonRow
The PySpark code that I'm currently using to do this is
# "value_list" contains the unique list of values in Column 1
index = 0
for col1 in value_list:
index += 1
df_col1 = df.filter(df.Column1 == col1)
for col2 in value_list[index:]:
df_col2 = df.filter(df.Column1 == col2)
df_join = df_col1.join(df_col2, on=(df_col1.Column2 == df_col2.Column2), how="inner")
if df_join.limit(1).count() == 0: # No common row
print(col1,col2,"NoCommonRow")
else:
print(col1,col2,"CommonRow")
However, I found that this takes a very long time to run (df has millions of rows). Is there any way to optimize it to run faster, or is there a better way to do the comparisons?
You can do this without loops using a self join, as follows:
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
spark = SparkSession.builder.getOrCreate()
data = [("abc", 111), ("def", 666), ("def", 111), ("tyu", 777),
("abc", 777), ("def", 222), ("tyu", 333), ("ewq", 888)]
df = spark.createDataFrame(data, ["Column1", "Column2"])
result_df = (
    df.alias("a")
    # joining on "<" is more efficient than the "!=" approach used in the other answer,
    # because "<" produces fewer rows after the join (almost half)
    .join(df.alias("b"), F.col("a.Column1") < F.col("b.Column1"), "left")
    .join(df.alias("c"), (F.col("a.Column2") == F.col("c.Column2")) & (F.col("b.Column2") == F.col("c.Column2")), "left")
    .where(F.col("a.Column1").isNotNull() & F.col("b.Column1").isNotNull())
    .groupBy("a.Column1", "b.Column1")
    .agg(
        F.when(F.count("c.Column2") > 0, "CommonRow").otherwise("NoCommonRow").alias("CommonStatus")
    )
    .orderBy("a.Column1", "b.Column1")  # remove this if not required
)
result_df.show(truncate=False)
# Output:
# +-------+-------+------------+
# |Column1|Column1|CommonStatus|
# +-------+-------+------------+
# |abc |def |CommonRow |
# |abc |ewq |NoCommonRow |
# |abc |tyu |CommonRow |
# |def |ewq |NoCommonRow |
# |def |tyu |NoCommonRow |
# |ewq |tyu |NoCommonRow |
# +-------+-------+------------+
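If you need the same comma-separated print output that your loop produced, here is a small sketch (it assumes the grouped result is small enough to bring back to the driver, which it should be, since it has only one row per pair of Column1 values):
for row in result_df.collect():
    # both key columns are named Column1, so access the fields by position
    print(row[0], row[1], row[2], sep=",")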
Also, please avoid the orderBy when you run this on your million-row df - use it only if necessary.
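If the row-level self join is still too slow, another approach you could try (just a sketch, assuming Spark 2.4+ for arrays_overlap and that the distinct Column2 values per Column1 key fit comfortably in an array) is to collapse each Column1 value to its set of Column2 values first, so the pairwise join runs over one row per Column1 value instead of millions of rows:
from pyspark.sql import functions as F
# one row per Column1 value, holding the set of its distinct Column2 values
sets = df.groupBy("Column1").agg(F.collect_set("Column2").alias("vals"))
pairs = (
    sets.alias("a")
    .join(sets.alias("b"), F.col("a.Column1") < F.col("b.Column1"))
    .select(
        F.col("a.Column1"),
        F.col("b.Column1"),
        # arrays_overlap is true if the two sets share at least one value
        F.when(F.arrays_overlap(F.col("a.vals"), F.col("b.vals")), "CommonRow")
         .otherwise("NoCommonRow")
         .alias("CommonStatus"),
    )
)
pairs.show(truncate=False)
The trade-off is that collect_set materialises each key's distinct values into a single array, so this helps most when the number of distinct Column1 values is small relative to the total row count.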