apache-spark, pyspark

PySpark: Joins in a for loop


I have two dataframes, products_df and categories_df:

products_df:
+------+------+------+-----------+
|region|laptop|mobile|conditioner|
+------+------+------+-----------+
| North|  L123|  M456|       C789|
|  West|  NULL|  M789|       C123|
|  NULL|  L456|  M123|       C456|
+------+------+------+-----------+

categories_df:
+------+--------------------+------------+-----------+
|region|        product_name|product_code|      class|
+------+--------------------+------------+-----------+
| North|      Laptop Model X|        L123|electronics|
| North|      Mobile Model Y|        M456|electronics|
|  NULL|      Mobile Model Z|        M789|electronics|
|  NULL|      Laptop Model Z|        L456|electronics|
|  NULL|      Mobile Model A|        M123|electronics|
|  West|  Conditioner Deluxe|        C123| appliances|
| North|     Conditioner Pro|        C789| appliances|
|  NULL|Conditioner Standard|        C456| appliances|
+------+--------------------+------------+-----------+

products_df has one column per product type (laptop, mobile, conditioner), and the values in those columns are product codes. The requirement is to replace each product code with its product name.

For each product code, we need to get its mapped product name from categories_df based on the product class (electronics, appliances) and region.

Besides products_df, there can be other dataframes with different product columns, but all of them map to the same categories_df. So I've written the function below, which takes any such product dataframe along with the constant categories_df. To map product columns to category classes, I pass a dictionary from product column name to class:

from pyspark.sql import functions as F

product_class_mapping = {
    "laptop": "electronics",
    "mobile": "electronics",
    "conditioner": "appliances",
}

def set_product_name(products_df, categories_df, product_class_mapping):
    for product_col, product_class in product_class_mapping.items():
        products_df = (
            products_df.alias("p")
            .join(
                categories_df.alias("c"),
                (F.col("c.class") == product_class)
                & (F.col(f"p.{product_col}") == F.col("c.product_code"))
                & (F.col("c.region").isNull() | (F.col("p.region") == F.col("c.region"))),
                how="left",
            )
            .select(F.col("p.*"), F.col("c.product_name"))
            .withColumn(
                product_col,
                F.when(F.col(product_col).isNotNull(), F.col("product_name")).otherwise(
                    F.col(product_col)
                ),
            ).drop("product_name")
        )
        
    print("Final mapped products_df:")
    products_df.show()
    return products_df

It generates the expected output, with each product code replaced by its product name:

Final mapped products_df:
+------+--------------+--------------+--------------------+
|region|        laptop|        mobile|         conditioner|
+------+--------------+--------------+--------------------+
| North|Laptop Model X|Mobile Model Y|     Conditioner Pro|
|  West|          NULL|Mobile Model Z|  Conditioner Deluxe|
|  NULL|Laptop Model Z|Mobile Model A|Conditioner Standard|
+------+--------------+--------------+--------------------+

However, I doubt whether this is an optimal solution, since we are looping and performing one join per product column. Is there an alternative we could consider?


Solution

  • Melt the products dataframe (DataFrame.melt is available from Spark 3.4)

    df = products_df.melt(
        ids=['region'],
        values=list(product_class_mapping),
        variableColumnName='product_name',
        valueColumnName='product_code'
    )
    
    # df.show()
    # +------+------------+------------+
    # |region|product_name|product_code|
    # +------+------------+------------+
    # | North|      laptop|        L123|
    # | North|      mobile|        M456|
    # | North| conditioner|        C789|
    # |  West|      laptop|        null|
    # |  West|      mobile|        M789|
    # |  West| conditioner|        C123|
    # |  null|      laptop|        L456|
    # |  null|      mobile|        M123|
    # |  null| conditioner|        C456|
    # +------+------------+------------+
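A compatibility note: on Spark versions before 3.4, where `DataFrame.melt` is not available, the same unpivot can be expressed with the SQL `stack()` function. A minimal sketch of building that expression from the mapping keys (names taken from the question):

```python
# Sketch: build a stack() expression equivalent to the melt() call above,
# for Spark < 3.4 where DataFrame.melt does not exist.
product_class_mapping = {
    "laptop": "electronics",
    "mobile": "electronics",
    "conditioner": "appliances",
}

cols = list(product_class_mapping)
stack_expr = "stack({}, {}) as (product_name, product_code)".format(
    len(cols),
    ", ".join("'{0}', {0}".format(c) for c in cols),
)
# df = products_df.selectExpr("region", stack_expr)
print(stack_expr)
# stack(3, 'laptop', laptop, 'mobile', mobile, 'conditioner', conditioner)
#   as (product_name, product_code)
```

`selectExpr` then keeps `region` and unpivots all three product columns in a single pass.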
    

  • Join with the product-class mapping to pull in the class

    mapping = spark.createDataFrame(product_class_mapping.items(), ['product_name', 'class'])
    df = df.join(F.broadcast(mapping), on='product_name', how='left')
    
    # df.show()
    # +------------+------+------------+-----------+
    # |product_name|region|product_code|      class|
    # +------------+------+------------+-----------+
    # |      laptop| North|        L123|electronics|
    # |      mobile| North|        M456|electronics|
    # | conditioner| North|        C789| appliances|
    # |      laptop|  West|        null|electronics|
    # |      mobile|  West|        M789|electronics|
    # | conditioner|  West|        C123| appliances|
    # |      laptop|  null|        L456|electronics|
    # |      mobile|  null|        M123|electronics|
    # | conditioner|  null|        C456| appliances|
    # +------------+------+------------+-----------+
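For clarity, the rows handed to `spark.createDataFrame` above come straight from the dict; in plain Python:

```python
# Each (key, value) pair of product_class_mapping becomes one
# (product_name, class) row of the small mapping dataframe.
product_class_mapping = {
    "laptop": "electronics",
    "mobile": "electronics",
    "conditioner": "appliances",
}
rows = list(product_class_mapping.items())
# mapping = spark.createDataFrame(rows, ['product_name', 'class'])
```

Since this dataframe has only a handful of rows, broadcasting it keeps the join shuffle-free on the mapping side.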
    

  • Create a join condition based on the mapping logic and join with categories_df

    join_cond = (
        (
            (df['region'] == categories_df['region'])
            | df['region'].isNull() | categories_df['region'].isNull()
        )
        & (df['class'] == categories_df['class'])
        & (df['product_code'] == categories_df['product_code'])
    )
    df = df.alias('l').join(categories_df.alias('r'), on=join_cond)
    df = df.select('l.region', 'l.product_name', 'r.product_name')
    
    # df.show()
    # +------+------------+--------------------+
    # |region|product_name|        product_name|
    # +------+------------+--------------------+
    # |  West| conditioner|  Conditioner Deluxe|
    # |  null| conditioner|Conditioner Standard|
    # | North| conditioner|     Conditioner Pro|
    # | North|      laptop|      Laptop Model X|
    # |  null|      laptop|      Laptop Model Z|
    # |  null|      mobile|      Mobile Model A|
    # | North|      mobile|      Mobile Model Y|
    # |  West|      mobile|      Mobile Model Z|
    # +------+------------+--------------------+
    

  • Reshape the dataframe using pivot

    df = df.groupBy('region').pivot('l.product_name').agg(F.first('r.product_name'))
    
    # df.show()
    # +------+--------------------+--------------+--------------+
    # |region|         conditioner|        laptop|        mobile|
    # +------+--------------------+--------------+--------------+
    # |  null|Conditioner Standard|Laptop Model Z|Mobile Model A|
    # | North|     Conditioner Pro|Laptop Model X|Mobile Model Y|
    # |  West|  Conditioner Deluxe|          null|Mobile Model Z|
    # +------+--------------------+--------------+--------------+
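As a sanity check, the whole melt → join → pivot pipeline can be simulated in plain Python on the sample data from the question (no Spark session needed); the result agrees with the final show() above:

```python
# Plain-Python simulation of the melt -> join -> pivot steps on the sample
# data from the question, to sanity-check the wildcard-region join logic.
products = [
    {"region": "North", "laptop": "L123", "mobile": "M456", "conditioner": "C789"},
    {"region": "West", "laptop": None, "mobile": "M789", "conditioner": "C123"},
    {"region": None, "laptop": "L456", "mobile": "M123", "conditioner": "C456"},
]
categories = [  # (region, product_name, product_code, class)
    ("North", "Laptop Model X", "L123", "electronics"),
    ("North", "Mobile Model Y", "M456", "electronics"),
    (None, "Mobile Model Z", "M789", "electronics"),
    (None, "Laptop Model Z", "L456", "electronics"),
    (None, "Mobile Model A", "M123", "electronics"),
    ("West", "Conditioner Deluxe", "C123", "appliances"),
    ("North", "Conditioner Pro", "C789", "appliances"),
    (None, "Conditioner Standard", "C456", "appliances"),
]
product_class_mapping = {
    "laptop": "electronics",
    "mobile": "electronics",
    "conditioner": "appliances",
}

# melt: one (region, product_column, product_code) row per product column
melted = [
    (row["region"], col, row[col])
    for row in products
    for col in product_class_mapping
]

# inner join on class + code; a NULL region on either side acts as a
# wildcard. Then "pivot" back to one column per product.
result = {}
for region, col, code in melted:
    for c_region, c_name, c_code, c_class in categories:
        region_ok = region == c_region or region is None or c_region is None
        if c_class == product_class_mapping[col] and code == c_code and region_ok:
            result.setdefault(region, {})[col] = c_name
```

The one missing entry (West's laptop, whose code is NULL) is dropped by the inner join, and the pivot fills that cell with null, exactly as in the final table.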