Tags: apache-spark, pyspark, apache-spark-sql, outliers, spark-window-function

How to run a user-defined function over a window in a Spark dataframe?


I am trying to detect outliers in my Spark dataframe. Below is a sample of the data.

pressure Timestamp
358.64 2022-01-01 00:00:00
354.98 2022-01-01 00:10:00
350.34 2022-01-01 00:20:00
429.69 2022-01-01 00:30:00
420.41 2022-01-01 00:40:00
413.82 2022-01-01 00:50:00
409.42 2022-01-01 01:00:00
409.67 2022-01-01 01:10:00
413.33 2022-01-01 01:20:00
405.03 2022-01-01 01:30:00
1209.42 2022-01-01 01:40:00
405.03 2022-01-01 01:50:00
404.54 2022-01-01 02:00:00
405.27 2022-01-01 02:10:00
999.45 2022-01-01 02:20:00
362.79 2022-01-01 02:30:00
349.37 2022-01-01 02:40:00
356.2 2022-01-01 02:50:00
3200.23 2022-01-01 03:00:00
348.39 2022-01-01 03:10:00

Here is my function to find outliers over the entire dataset:

from pyspark.sql import functions as F

def outlierDetection(df):
    # 20th and 80th percentiles of the pressure column (exact, since relativeError=0)
    inter_quantile_range = df.approxQuantile("pressure", [0.20, 0.80], relativeError=0)

    Q1 = inter_quantile_range[0]
    Q3 = inter_quantile_range[1]

    inter_quantile_diff = Q3 - Q1

    # Tukey-style fences: values outside [Q1 - 1.5*IQR, Q3 + 1.5*IQR] are flagged
    minimum_Q1 = Q1 - 1.5 * inter_quantile_diff
    maximum_Q3 = Q3 + 1.5 * inter_quantile_diff

    df = df.withColumn("isOutlier", F.when((df["pressure"] > maximum_Q3) | (df["pressure"] < minimum_Q1), 1).otherwise(0))
    return df
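
For reference, I apply it to the whole dataframe at once, e.g.:

df = outlierDetection(df)
df.filter("isOutlier = 1").show()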

It works as expected, but it flags as an outlier every value that doesn't fit the range computed over the entire dataset.

Instead, I want to check for outliers within each hourly interval.

I have created another column that holds the hour, as follows:

pressure Timestamp date_hour
358.64 2022-01-01 00:00:00 2022-01-01 00
354.98 2022-01-01 00:10:00 2022-01-01 00
350.34 2022-01-01 00:20:00 2022-01-01 00
429.69 2022-01-01 00:30:00 2022-01-01 00
420.41 2022-01-01 00:40:00 2022-01-01 00
413.82 2022-01-01 00:50:00 2022-01-01 00
409.42 2022-01-01 01:00:00 2022-01-01 01
409.67 2022-01-01 01:10:00 2022-01-01 01
413.33 2022-01-01 01:20:00 2022-01-01 01
405.03 2022-01-01 01:30:00 2022-01-01 01

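For reference, this column can be derived with date_format, e.g.:

df = df.withColumn("date_hour", F.date_format("Timestamp", "yyyy-MM-dd HH"))
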
I am trying to create a window like the one below:

from pyspark.sql.window import Window

w1 = Window.partitionBy("date_hour").orderBy("Timestamp")

Is there any way to use my function over each window in the dataframe?


Solution

  • If you're using Spark 3.1+, you can use percentile_approx to calculate the quantiles and do the rest of the calculations in PySpark. In case your Spark version does not have that function, we can use a UDF that uses numpy.quantile for the quantile calculation. I've shown both in the code.

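    For completeness, a minimal setup sketch (my assumptions, not shown in the question): func is the pyspark functions module, and data_ls holds the question's sample as (pressure, timestamp) tuples — the first two hours are listed; the remaining rows can be appended the same way.

    import pyspark.sql.functions as func
    from pyspark.sql.types import ArrayType, FloatType

    # sample data from the question (hours 00 and 01)
    data_ls = [
        (358.64, '2022-01-01 00:00:00'),
        (354.98, '2022-01-01 00:10:00'),
        (350.34, '2022-01-01 00:20:00'),
        (429.69, '2022-01-01 00:30:00'),
        (420.41, '2022-01-01 00:40:00'),
        (413.82, '2022-01-01 00:50:00'),
        (409.42, '2022-01-01 01:00:00'),
        (409.67, '2022-01-01 01:10:00'),
        (413.33, '2022-01-01 01:20:00'),
        (405.03, '2022-01-01 01:30:00'),
        (1209.42, '2022-01-01 01:40:00'),
        (405.03, '2022-01-01 01:50:00'),
    ]
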
    data_sdf = spark.sparkContext.parallelize(data_ls).toDF(['pressure', 'ts']). \
        withColumn('ts', func.col('ts').cast('timestamp')). \
        withColumn('dt_hr', func.date_format('ts', 'yyyyMMddHH'))
    
    # +--------+-------------------+----------+
    # |pressure|                 ts|     dt_hr|
    # +--------+-------------------+----------+
    # |  358.64|2022-01-01 00:00:00|2022010100|
    # |  354.98|2022-01-01 00:10:00|2022010100|
    # |  350.34|2022-01-01 00:20:00|2022010100|
    # |  429.69|2022-01-01 00:30:00|2022010100|
    # |  420.41|2022-01-01 00:40:00|2022010100|
    # |  413.82|2022-01-01 00:50:00|2022010100|
    # |  409.42|2022-01-01 01:00:00|2022010101|
    # |  409.67|2022-01-01 01:10:00|2022010101|
    # |  413.33|2022-01-01 01:20:00|2022010101|
    # |  405.03|2022-01-01 01:30:00|2022010101|
    # | 1209.42|2022-01-01 01:40:00|2022010101|
    # |  405.03|2022-01-01 01:50:00|2022010101|
    # +--------+-------------------+----------+
    

    Getting the quantiles (showing both methods; use whichever is available):

    # spark 3.1+ has percentile_approx
    pressure_quantile_sdf = data_sdf. \
        groupBy('dt_hr'). \
        agg(func.percentile_approx('pressure', [0.2, 0.8]).alias('quantile_20_80'))
    
    # +----------+----------------+
    # |     dt_hr|  quantile_20_80|
    # +----------+----------------+
    # |2022010100|[354.98, 420.41]|
    # |2022010101|[405.03, 413.33]|
    # +----------+----------------+
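
    As a side note, percentile_approx also accepts an optional accuracy argument (default 10000); a higher value yields more precise quantiles at the cost of memory, e.g.:

    pressure_quantile_sdf = data_sdf. \
        groupBy('dt_hr'). \
        agg(func.percentile_approx('pressure', [0.2, 0.8], 100000).alias('quantile_20_80'))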
    
    # lower versions use UDF
    def numpy_quantile_20_80(list_col):
        import numpy as np
    
        q_20 = np.quantile(list_col, 0.2)
        q_80 = np.quantile(list_col, 0.8)
    
        return [float(q_20), float(q_80)]
    
    numpy_quantile_20_80_udf = func.udf(numpy_quantile_20_80, ArrayType(FloatType()))
    
    pressure_quantile_sdf = data_sdf. \
        groupBy('dt_hr'). \
        agg(func.collect_list('pressure').alias('pressure_list')). \
        withColumn('quantile_20_80', numpy_quantile_20_80_udf(func.col('pressure_list')))
    
    # +----------+--------------------+----------------+
    # |     dt_hr|       pressure_list|  quantile_20_80|
    # +----------+--------------------+----------------+
    # |2022010100|[358.64, 354.98, ...|[354.98, 420.41]|
    # |2022010101|[409.42, 409.67, ...|[405.03, 413.33]|
    # +----------+--------------------+----------------+
    

    Outlier calculation is then straightforward with the quantile info:

    pressure_quantile_sdf = pressure_quantile_sdf. \
        withColumn('quantile_20', func.col('quantile_20_80')[0]). \
        withColumn('quantile_80', func.col('quantile_20_80')[1]). \
        withColumn('min_q_20', func.col('quantile_20') - 1.5 * (func.col('quantile_80') - func.col('quantile_20'))). \
        withColumn('max_q_80', func.col('quantile_80') + 1.5 * (func.col('quantile_80') - func.col('quantile_20'))). \
        select('dt_hr', 'min_q_20', 'max_q_80')
    
    # +----------+------------------+------------------+
    # |     dt_hr|          min_q_20|          max_q_80|
    # +----------+------------------+------------------+
    # |2022010100|256.83502197265625| 518.5549926757812|
    # |2022010101|392.58001708984375|425.77996826171875|
    # +----------+------------------+------------------+
    
    # outlier calc -- select columns that are required
    data_sdf. \
        join(pressure_quantile_sdf, 'dt_hr', 'left'). \
        withColumn('is_outlier', ((func.col('pressure') > func.col('max_q_80')) | (func.col('pressure') < func.col('min_q_20'))).cast('int')). \
        show()
    
    # +----------+--------+-------------------+------------------+------------------+----------+
    # |     dt_hr|pressure|                 ts|          min_q_20|          max_q_80|is_outlier|
    # +----------+--------+-------------------+------------------+------------------+----------+
    # |2022010100|  358.64|2022-01-01 00:00:00|256.83502197265625| 518.5549926757812|         0|
    # |2022010100|  354.98|2022-01-01 00:10:00|256.83502197265625| 518.5549926757812|         0|
    # |2022010100|  350.34|2022-01-01 00:20:00|256.83502197265625| 518.5549926757812|         0|
    # |2022010100|  429.69|2022-01-01 00:30:00|256.83502197265625| 518.5549926757812|         0|
    # |2022010100|  420.41|2022-01-01 00:40:00|256.83502197265625| 518.5549926757812|         0|
    # |2022010100|  413.82|2022-01-01 00:50:00|256.83502197265625| 518.5549926757812|         0|
    # |2022010101|  409.42|2022-01-01 01:00:00|392.58001708984375|425.77996826171875|         0|
    # |2022010101|  409.67|2022-01-01 01:10:00|392.58001708984375|425.77996826171875|         0|
    # |2022010101|  413.33|2022-01-01 01:20:00|392.58001708984375|425.77996826171875|         0|
    # |2022010101|  405.03|2022-01-01 01:30:00|392.58001708984375|425.77996826171875|         0|
    # |2022010101| 1209.42|2022-01-01 01:40:00|392.58001708984375|425.77996826171875|         1|
    # |2022010101|  405.03|2022-01-01 01:50:00|392.58001708984375|425.77996826171875|         0|
    # +----------+--------+-------------------+------------------+------------------+----------+
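
    If only the flagged rows are needed, assign the joined result to a variable (say, flagged_sdf — a name introduced here) and filter on the new column:

    flagged_sdf = data_sdf. \
        join(pressure_quantile_sdf, 'dt_hr', 'left'). \
        withColumn('is_outlier', ((func.col('pressure') > func.col('max_q_80')) | (func.col('pressure') < func.col('min_q_20'))).cast('int'))

    flagged_sdf.filter(func.col('is_outlier') == 1).show()

    # +----------+--------+-------------------+------------------+------------------+----------+
    # |     dt_hr|pressure|                 ts|          min_q_20|          max_q_80|is_outlier|
    # +----------+--------+-------------------+------------------+------------------+----------+
    # |2022010101| 1209.42|2022-01-01 01:40:00|392.58001708984375|425.77996826171875|         1|
    # +----------+--------+-------------------+------------------+------------------+----------+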