
Find the closest value of each value in a column compared to another column in the same PySpark dataframe

We have a PySpark dataframe containing rate codes that we have to use to give discounted offers to our customers.


- ratecode - actual rate code
- weeklyrate - weekly dollar amount that the customer will pay
- area - area of residence
- frequency -
- offer1 - the first discounted offer to the customer
- offer2 - the second discounted offer to the customer

The problem is to find, for each "offer1" value, the "ratecode" whose "weeklyrate" is closest to it (saved as "offer1Ratecode"), and to do the same for "offer2" (saved as "offer2Ratecode").


  1. For "offer1" = 4.4, "offer1Ratecode" is R1, because the "weeklyrate" closest to 4.4 is 5.5, and 5.5 corresponds to "ratecode" R1.
  2. For "offer1" = 6, "offer1Ratecode" is R2, because the "weeklyrate" closest to 6 is 6.2, and 6.2 corresponds to "ratecode" R2.
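
The matching rule in these two examples can be checked without Spark. The following is only a plain-Python reference sketch over the sample rows from the question; the helper name `closest_ratecode` is made up for illustration and is not part of any library.

```python
# Sample (ratecode, weeklyrate) pairs taken from the question's input data.
rates = [('R1', 5.5), ('R2', 6.2), ('R3', 7.5), ('R4', 5.6),
         ('R5', 7.3), ('R6', 8.4), ('R7', 9.1), ('R8', 6.8)]

def closest_ratecode(offer, rates):
    """Return the rate code whose weekly rate is nearest to `offer`.

    Hypothetical helper for illustration only.
    """
    code, _ = min(rates, key=lambda pair: abs(pair[1] - offer))
    return code

print(closest_ratecode(4.4, rates))  # R1
print(closest_ratecode(6.0, rates))  # R2
```

This is quadratic if run for every row, which is fine for a handful of rate codes and useful as a small oracle when testing a Spark version.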


  • Input:

    df = spark.createDataFrame(
        [('R1', 5.5, 4.4, 3.85),
         ('R2', 6.2, 4.96, 4.34),
         ('R3', 7.5, 6.0, 5.25),
         ('R4', 5.6, 4.48, 3.92),
         ('R5', 7.3, 5.84, 5.11),
         ('R6', 8.4, 6.72, 5.88),
         ('R7', 9.1, 7.28, 6.37),
         ('R8', 6.8, 5.44, 4.76)],
        ['ratecode', 'weeklyrate', 'offer1', 'offer2'])

    One way would be using crossJoin and groupBy:

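    from pyspark.sql import functions as F
    def closest(col):
        # Pair each rate code with its distance to the offer, sort, keep the nearest.
        return F.array_sort(F.collect_list(F.struct(
            F.abs(F.col(f'b.{col}') - F.col('a.weeklyrate')).alias('diff'),
            F.col('a.ratecode').alias('ratecode'),
        )))[0]['ratecode'].alias(f'{col}Ratecode')

    df = df.select('weeklyrate', 'ratecode').alias('a').crossJoin(df.alias('b'))
    df = df.groupBy(*[f'b.{c}' for c in df.select('b.*').columns]).agg(
        closest('offer1'), closest('offer2'))
    df.show()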
    # +--------+----------+------+------+--------------+--------------+
    # |ratecode|weeklyrate|offer1|offer2|offer1Ratecode|offer2Ratecode|
    # +--------+----------+------+------+--------------+--------------+
    # |      R3|       7.5|   6.0|  5.25|            R2|            R1|
    # |      R2|       6.2|  4.96|  4.34|            R1|            R1|
    # |      R1|       5.5|   4.4|  3.85|            R1|            R1|
    # |      R4|       5.6|  4.48|  3.92|            R1|            R1|
    # |      R7|       9.1|  7.28|  6.37|            R5|            R2|
    # |      R6|       8.4|  6.72|  5.88|            R8|            R4|
    # |      R8|       6.8|  5.44|  4.76|            R1|            R1|
    # |      R5|       7.3|  5.84|  5.11|            R4|            R1|
    # +--------+----------+------+------+--------------+--------------+

    Another could be using window functions and transform (the Python F.transform API requires Spark 3.1+):

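    from pyspark.sql import functions as F, Window as W
    def closest(col):
        # Collect every (weeklyrate, ratecode) pair over the whole frame, then
        # compute each pair's distance to this row's offer and keep the nearest.
        return F.array_sort(F.transform(
            F.collect_list(F.struct('weeklyrate', 'ratecode')).over(W.orderBy()),
            lambda x: F.struct(
                F.abs(F.col(col) - x['weeklyrate']).alias('diff'),
                x['ratecode'].alias('ratecode'),
            )
        ))[0]['ratecode'].alias(f'{col}Ratecode')

    df = df.select('*', closest('offer1'), closest('offer2'))
    df.show()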
    # +--------+----------+------+------+--------------+--------------+
    # |ratecode|weeklyrate|offer1|offer2|offer1Ratecode|offer2Ratecode|
    # +--------+----------+------+------+--------------+--------------+
    # |      R1|       5.5|   4.4|  3.85|            R1|            R1|
    # |      R2|       6.2|  4.96|  4.34|            R1|            R1|
    # |      R3|       7.5|   6.0|  5.25|            R2|            R1|
    # |      R4|       5.6|  4.48|  3.92|            R1|            R1|
    # |      R5|       7.3|  5.84|  5.11|            R4|            R1|
    # |      R6|       8.4|  6.72|  5.88|            R8|            R4|
    # |      R7|       9.1|  7.28|  6.37|            R5|            R2|
    # |      R8|       6.8|  5.44|  4.76|            R1|            R1|
    # +--------+----------+------+------+--------------+--------------+
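
Both approaches rely on the same trick: build (diff, ratecode) structs, sort the array, and take the first element. As far as I know, Spark compares structs field by field from left to right, so the smallest diff wins and a tie falls back to the smaller ratecode. The sketch below mimics that ordering with Python tuples. It is only an illustration: `nearest_by_struct_sort` is a made-up name, and the tie data at the end is hypothetical.

```python
# Sample (ratecode, weeklyrate) pairs taken from the question's input data.
rates = [('R1', 5.5), ('R2', 6.2), ('R3', 7.5), ('R4', 5.6),
         ('R5', 7.3), ('R6', 8.4), ('R7', 9.1), ('R8', 6.8)]

def nearest_by_struct_sort(offer, rates):
    """Mimic array_sort over (diff, ratecode) structs using tuple ordering.

    Hypothetical helper for illustration only.
    """
    # Tuples sort by their first element (diff), then by the second (ratecode).
    pairs = sorted((abs(rate - offer), code) for code, rate in rates)
    return pairs[0][1]

print(nearest_by_struct_sort(5.84, rates))  # R4

# Hypothetical tie: both rates are exactly 1.0 away from the offer,
# so the lexicographically smaller code is returned.
print(nearest_by_struct_sort(5.0, [('B', 6.0), ('A', 4.0)]))  # A
```

If a different tie-break is needed (for example, preferring the lower weekly rate), adding that value as a second struct field before ratecode would be one way to express it.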