Tags: python, dataframe, apache-spark, pyspark, closest

Find the closest value of each value in a column compared to another column in the same PySpark dataframe


We have a PySpark dataframe containing rate codes that we have to use to give discounted offers to our customers.

(screenshot of the sample dataframe)

- ratecode - Actual rate code
- weeklyrate - Weekly dollar amount that the customer will pay
- area - Area of residence
- frequency -
- offer1 - The first discounted offer to the customer
- offer2 - The second discounted offer to the customer

The problem is to find, for each row, the "ratecode" whose "weeklyrate" is closest to "offer1" (saved as "offer1Ratecode") and to "offer2" (saved as "offer2Ratecode").

Explanation:

  1. For "offer1" = 4.4, the "offer1Ratecode" is R1, because the closest "weeklyrate" to 4.4 is 5.5, and 5.5 corresponds to "ratecode" R1.
  2. For "offer1" = 6.0, the "offer1Ratecode" is R2, because the closest "weeklyrate" to 6.0 is 6.2, and 6.2 corresponds to "ratecode" R2 (a plain-Python sketch of this rule is shown below).
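
The sketch (my illustration; rates and closest_ratecode are made-up names, with the weeklyrate values taken from the sample data in the solution below):

    rates = {'R1': 5.5, 'R2': 6.2, 'R3': 7.5, 'R4': 5.6,
             'R5': 7.3, 'R6': 8.4, 'R7': 9.1, 'R8': 6.8}

    def closest_ratecode(offer):
        # ratecode whose weeklyrate minimizes the absolute difference to the offer
        return min(rates, key=lambda rc: abs(rates[rc] - offer))

    print(closest_ratecode(4.4))  # R1 (5.5 is the closest weeklyrate)
    print(closest_ratecode(6.0))  # R2 (6.2 is the closest weeklyrate)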

Solution

  • Input:

    df = spark.createDataFrame(
        [('R1', 5.5, 4.4, 3.85),
         ('R2', 6.2, 4.96, 4.34),
         ('R3', 7.5, 6.0, 5.25),
         ('R4', 5.6, 4.48, 3.92),
         ('R5', 7.3, 5.84, 5.11),
         ('R6', 8.4, 6.72, 5.88),
         ('R7', 9.1, 7.28, 6.37),
         ('R8', 6.8, 5.44, 4.76)],
        ['ratecode', 'weeklyrate', 'offer1', 'offer2'])
    

    One way would be using crossJoin and groupBy:

    from pyspark.sql import functions as F
    
    def closest(col):
        # Collect all 'a' rows seen by each 'b' row as structs whose first field
        # is the absolute difference to the offer; sorting the array then puts
        # the closest weeklyrate first, and [0]['ratecode'] picks its ratecode.
        return F.array_sort(F.collect_list(F.struct(
            F.abs(F.col(f'b.{col}') - F.col('a.weeklyrate')).alias('diff'),
            'a.weeklyrate',
            'a.ratecode',
        )))[0]['ratecode'].alias(f'{col}Ratecode')

    # Pair every row with every other row, then group back to the original rows.
    df = df.select('weeklyrate', 'ratecode').alias('a').crossJoin(df.alias('b'))
    df = df.groupBy(*[f'b.{c}' for c in df.select('b.*').columns]).agg(
        closest('offer1'),
        closest('offer2'),
    )
    df.show()
    # +--------+----------+------+------+--------------+--------------+
    # |ratecode|weeklyrate|offer1|offer2|offer1Ratecode|offer2Ratecode|
    # +--------+----------+------+------+--------------+--------------+
    # |      R3|       7.5|   6.0|  5.25|            R2|            R1|
    # |      R2|       6.2|  4.96|  4.34|            R1|            R1|
    # |      R1|       5.5|   4.4|  3.85|            R1|            R1|
    # |      R4|       5.6|  4.48|  3.92|            R1|            R1|
    # |      R7|       9.1|  7.28|  6.37|            R5|            R2|
    # |      R6|       8.4|  6.72|  5.88|            R8|            R4|
    # |      R8|       6.8|  5.44|  4.76|            R1|            R1|
    # |      R5|       7.3|  5.84|  5.11|            R4|            R1|
    # +--------+----------+------+------+--------------+--------------+
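
    The trick in closest() is that F.array_sort orders an array of structs by comparing the struct fields in order, so putting diff as the first field makes element [0] the entry with the smallest absolute difference. A standalone sketch of just that behaviour (my illustration; the demo dataframe and the literal diff values are made up):

    from pyspark.sql import functions as F

    # array_sort compares structs field by field, so the smallest 'diff' comes first
    demo = spark.createDataFrame([(1,)], ['id'])
    demo.select(
        F.array_sort(F.array(
            F.struct(F.lit(1.8).alias('diff'), F.lit('R2').alias('ratecode')),
            F.struct(F.lit(1.1).alias('diff'), F.lit('R1').alias('ratecode')),
        ))[0]['ratecode'].alias('closest')   # -> 'R1'
    ).show()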
    

    Another way could be to use window functions and transform:

    from pyspark.sql import functions as F, Window as W
    
    def closest(col):
        # Collect the whole (weeklyrate, ratecode) table once per row via an
        # unpartitioned window, rebuild each entry with the absolute difference
        # to the offer as the first field, sort, and keep the closest ratecode.
        return F.array_sort(F.transform(
            F.collect_list(F.struct('weeklyrate', 'ratecode')).over(W.orderBy()),
            lambda x: F.struct(
                F.abs(F.col(col) - x['weeklyrate']).alias('diff'),
                x['weeklyrate'].alias('weeklyrate'),
                x['ratecode'].alias('ratecode'),
            )
        ))[0]['ratecode'].alias(f'{col}Ratecode')

    df = df.select('*', closest('offer1'), closest('offer2'))
    df.show()
    # +--------+----------+------+------+--------------+--------------+
    # |ratecode|weeklyrate|offer1|offer2|offer1Ratecode|offer2Ratecode|
    # +--------+----------+------+------+--------------+--------------+
    # |      R1|       5.5|   4.4|  3.85|            R1|            R1|
    # |      R2|       6.2|  4.96|  4.34|            R1|            R1|
    # |      R3|       7.5|   6.0|  5.25|            R2|            R1|
    # |      R4|       5.6|  4.48|  3.92|            R1|            R1|
    # |      R5|       7.3|  5.84|  5.11|            R4|            R1|
    # |      R6|       8.4|  6.72|  5.88|            R8|            R4|
    # |      R7|       9.1|  7.28|  6.37|            R5|            R2|
    # |      R8|       6.8|  5.44|  4.76|            R1|            R1|
    # +--------+----------+------+------+--------------+--------------+
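
    A note on the window approach: W.orderBy() has no partitioning or ordering columns, so collect_list gathers the entire table into one array for every row (Spark warns about moving all data to a single partition, which is acceptable for a small rate table), and F.transform with a Python lambda requires Spark 3.1+ as far as I know. A tiny illustration of the empty-window behaviour (my sketch, reusing df from above):

    from pyspark.sql import functions as F, Window as W

    # With no partitionBy/orderBy columns, the window frame spans the whole
    # dataframe, so every row sees the complete list of ratecodes.
    df.select(
        'ratecode',
        F.collect_list('ratecode').over(W.orderBy()).alias('all_ratecodes'),
    ).show(truncate=False)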