
How to reduce multiple joins in Spark


I am using Spark 2.4.1 to compute some ratios on my DataFrame.

I need to compute ratio factors for several different ratio columns of a given DataFrame (df_data) by joining it to a metadata DataFrame (resDs).

Currently I get these ratio factors (ratio_1_factor, ratio_2_factor and ratio_3_factor) with three separate joins, each using a different join condition (joinedDs, joinedDs2 and joinedDs3).

Is there an alternative that reduces the number of joins and makes this perform better?

The full sample data is available at the public URL below:

https://databricks-prod-cloudfront.cloud.databricks.com/public/4027ec902e239c93eaaa8714f173bcfc/1165111237342523/3521103084252405/7035720262824085/latest.html

How can I handle multiple steps, rather than a single step, inside the when clause? For example:

    .withColumn("step_1_ratio_1", (col("ratio_1").minus(lit(0.00000123))).cast(DataTypes.DoubleType)) // step-2
    .withColumn("step_2_ratio_1", (col("step_1_ratio_1").multiply(lit(0.02))).cast(DataTypes.DoubleType)) // step-3
    .withColumn("step_3_ratio_1", (col("step_2_ratio_1").divide(col("step_1_ratio_1"))).cast(DataTypes.DoubleType)) // step-4
    .withColumn("ratio_1_factor", (col("ratio_1_factor")).cast(DataTypes.DoubleType)) // step-5

That is, ratio_1_factor is calculated from several other columns in df_data.

Steps 2, 3 and 4 are also used to calculate the other ratio factors, i.e. ratio_2_factor and ratio_3_factor.

How should this be handled?


Solution

  • You can join once and compute the ratio_1_factor, ratio_2_factor and ratio_3_factor columns in a single aggregation, using max together with when:

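    import org.apache.spark.sql.functions._
    import org.apache.spark.sql.types.DoubleType

    // One range join against the small metadata table, broadcast to every executor
    val joinedDs = df_data.as("aa")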
      .join(
        broadcast(resDs.as("bb")),
        col("aa.g_date").between(col("bb.start_date"), col("bb.end_date"))
      )
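      // Collapse the joined rows back to one row per input row; inside each max(when(...)),
      // non-matching metadata rows yield null, which max ignores
      .groupBy("item_id", "g_date", "ratio_1", "ratio_2", "ratio_3")
      .agg(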
        max(when(
            col("aa.ratio_1").between(col("bb.A"), col("bb.A_lead")),
            col("ratio_1").multiply(lit(0.1))
          )
        ).cast(DoubleType).as("ratio_1_factor"),
        max(when(
            col("aa.ratio_2").between(col("bb.A"), col("bb.A_lead")),
            col("ratio_2").multiply(lit(0.2))
          )
        ).cast(DoubleType).as("ratio_2_factor"),
        max(when(
            col("aa.ratio_3").between(col("bb.A"), col("bb.A_lead")),
            col("ratio_3").multiply(lit(0.3))
          )
        ).cast(DoubleType).as("ratio_3_factor")
      )
    
    
    joinedDs.show(false)
    
    //+-------+----------+---------+-----------+-----------+---------------------+--------------+--------------+
    //|item_id|g_date    |ratio_1  |ratio_2    |ratio_3    |ratio_1_factor       |ratio_2_factor|ratio_3_factor|
    //+-------+----------+---------+-----------+-----------+---------------------+--------------+--------------+
    //|50312  |2016-01-04|0.0456646|0.046899415|0.046000415|0.0045664600000000005|0.009379883   |0.0138001245  |
    //+-------+----------+---------+-----------+-----------+---------------------+--------------+--------------+
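For the follow-up about multi-step expressions, here is a minimal sketch. It assumes the steps are exactly the ones in the question's snippet and that the final factor is step_3_ratio_1 (the question's step-5 line is ambiguous). Because a Column is only an expression, the intermediate steps can be composed as local vals inside an ordinary Scala function and reused for every ratio, with no intermediate withColumn calls and no extra joins. The names stepChain, factor and demo, and the sample resDs row (dates and the A/A_lead bounds), are hypothetical; the df_data row is taken from the output above. Note that, as written, step 4 divides (step_1 * 0.02) by step_1, so every factor comes out as 0.02 (or null when step_1 is zero); replace the body of stepChain with the real formula.

```scala
import org.apache.spark.sql.{Column, DataFrame, SparkSession}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.DoubleType

object MultiStepFactors {
  // Hypothetical helper: steps 2-4 from the question as one reusable Column expression.
  // The intermediate values are Scala vals, not DataFrame columns.
  def stepChain(ratio: Column): Column = {
    val step1 = ratio.minus(lit(0.00000123)) // step 2
    val step2 = step1.multiply(lit(0.02))    // step 3
    step2.divide(step1).cast(DoubleType)     // step 4 (assumed to be the final factor)
  }

  // Hypothetical helper: one conditional factor per ratio column, evaluated
  // inside the single aggregation (null when the ratio is outside [A, A_lead]).
  def factor(ratioCol: String): Column =
    max(
      when(
        col(s"aa.$ratioCol").between(col("bb.A"), col("bb.A_lead")),
        stepChain(col(s"aa.$ratioCol"))
      )
    ).cast(DoubleType).as(s"${ratioCol}_factor")

  // Same single broadcast join + groupBy as in the answer above.
  def compute(dfData: DataFrame, resDs: DataFrame): DataFrame =
    dfData.as("aa")
      .join(
        broadcast(resDs.as("bb")),
        col("aa.g_date").between(col("bb.start_date"), col("bb.end_date"))
      )
      .groupBy("item_id", "g_date", "ratio_1", "ratio_2", "ratio_3")
      .agg(factor("ratio_1"), factor("ratio_2"), factor("ratio_3"))

  // Builds small sample inputs (the resDs row is made up) and runs compute.
  def demo(spark: SparkSession): DataFrame = {
    import spark.implicits._
    val dfData = Seq((50312, "2016-01-04", 0.0456646, 0.046899415, 0.046000415))
      .toDF("item_id", "g_date", "ratio_1", "ratio_2", "ratio_3")
      .withColumn("g_date", to_date(col("g_date")))
    val resDs = Seq(("2016-01-01", "2016-01-31", 0.04, 0.05))
      .toDF("start_date", "end_date", "A", "A_lead")
      .withColumn("start_date", to_date(col("start_date")))
      .withColumn("end_date", to_date(col("end_date")))
    compute(dfData, resDs)
  }
}
```

Usage: with a local session, `MultiStepFactors.demo(spark).show(false)` prints one row whose three factor columns are approximately 0.02. Keeping the shared steps in a function means each new ratio only needs another `factor("...")` entry in the same `agg` call.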