apache-spark pyspark apache-spark-sql probability markov-chains

Calculate a sequence of Markov chain values


I have a Spark question. For each entity k, the input is a sequence of probabilities p_i, each with an associated value v_i. For example, the data can look like this:

entity | Probability | value
A      | 0.8         | 10
A      | 0.6         | 15
A      | 0.3         | 20
B      | 0.8         | 10

Then, for entity A, I expect the average (expected) value to be 0.8*10 + (1-0.8)*0.6*15 + (1-0.8)*(1-0.6)*0.3*20 + (1-0.8)*(1-0.6)*(1-0.3)*MAX_VALUE_DEFINED.

How can I achieve this in Spark with a DataFrame aggregate function? I'm finding it challenging to group by entity and then compute the result over the ordered sequence of rows.


Solution

  • You can use a UDF for this kind of custom calculation. The idea is to use collect_list to gather all the probabilities and values of each entity into one array so you can loop through them. However, collect_list does not guarantee the order of the records, which can lead to a wrong result. One way to fix this is to generate an ID for each row with monotonically_increasing_id, put it first in each collected item, and sort the list with array_sort:

    import pyspark.sql.functions as F
    
    @F.pandas_udf('double')
    def markov_udf(values):
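        def markov(lst):
            # Placeholder: this only sums the probabilities.
            # Replace the loop body with your own Markov logic.
            # Each item is [id, probability, value], sorted by id.
            s = 0
            for i, prob, val in lst:
                s += prob
            return s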
        return values.apply(markov)
        
    (df
        .withColumn('id', F.monotonically_increasing_id())
        .groupBy('entity')
        .agg(F.array_sort(F.collect_list(F.array('id', 'probability', 'value'))).alias('values'))
        .withColumn('markov', markov_udf('values'))
        .show(10, False)
    )
    
    +------+------------------------------------------------------+------+
    |entity|values                                                |markov|
    +------+------------------------------------------------------+------+
    |B     |[[3.0, 0.8, 10.0]]                                    |0.8   |
    |A     |[[0.0, 0.8, 10.0], [1.0, 0.6, 15.0], [2.0, 0.3, 20.0]]|1.7   |
    +------+------------------------------------------------------+------+
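
The markov column above is just the sum of the probabilities, because the loop in markov is a placeholder. As a minimal sketch of the calculation the question describes, the logic could look like the plain-Python function below. The name expected_value, the parameter max_value (standing in for MAX_VALUE_DEFINED), and the value 100 are assumptions for illustration only, not part of the original answer.

```python
def expected_value(rows, max_value):
    """Expected value for one entity.

    rows: iterable of (probability, value) pairs, already in sequence order.
    max_value: hypothetical stand-in for MAX_VALUE_DEFINED from the question.
    """
    total = 0.0
    remaining = 1.0  # probability that none of the earlier rows applied
    for prob, val in rows:
        total += remaining * prob * val
        remaining *= 1.0 - prob
    # Whatever probability mass is left falls through to max_value.
    return total + remaining * max_value


# Entity A from the question, with an assumed max_value of 100.
entity_a = [(0.8, 10), (0.6, 15), (0.3, 20)]
print(expected_value(entity_a, max_value=100))  # about 15.88
```

Inside the UDF, each collected item also carries the id as its first element, so the loop would unpack three values and ignore the first one.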