
How to get the index of the highest value in a list per row in a Spark DataFrame? [PySpark]


I have trained an LDA topic model and stored it in lda_model.

After transforming my original input dataset I get a DataFrame. One of its columns, topicDistribution, holds the probability of the row belonging to each topic from the LDA model. I therefore want to get the index of the maximum value in that list for each row.

df -- | 'list_of_words' | 'index' | 'topicDistribution' |
       ['product','...']     0       [0.08,0.2,0.4,0.0001]
          .....             ...         ........

I want to add a column to df that holds the argmax of the topicDistribution list for each row.

df_transformed --  | 'list_of_words' | 'index' | 'topicDistribution' | 'topicID' |
                    ['product','...']     0     [0.08,0.2,0.4,0.0001]      2
                       ......            ....         .....              ....

How would I do this?


Solution

  • You can create a user-defined function (UDF) to get the index of the maximum:

    from pyspark.sql import functions as f
    from pyspark.sql.types import IntegerType
    
    max_index = f.udf(lambda x: x.index(max(x)), IntegerType())
    df = df.withColumn("topicID", max_index("topicDistribution"))
    

    Example

    >>> from pyspark.sql import functions as f
    >>> from pyspark.sql.types import IntegerType 
    >>> df = spark.createDataFrame([{"topicDistribution": [0.2, 0.3, 0.5]}])
    >>> df.show()
    +-----------------+
    |topicDistribution|
    +-----------------+
    |  [0.2, 0.3, 0.5]|
    +-----------------+
    
    >>> max_index = f.udf(lambda x: x.index(max(x)), IntegerType())
    >>> df.withColumn("topicID", max_index("topicDistribution")).show()
    +-----------------+-------+
    |topicDistribution|topicID|
    +-----------------+-------+
    |  [0.2, 0.3, 0.5]|      2|
    +-----------------+-------+
    

    Edit:

    Since you mentioned that the lists in topicDistribution are NumPy arrays, which have no index method, you can convert each one to a list first by updating the max_index UDF as follows:

    max_index = f.udf(lambda x: x.tolist().index(max(x)), IntegerType())
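
If the column can contain nulls, empty lists, or a mix of plain lists and NumPy arrays, the one-line lambdas above may raise inside the executor. Below is a minimal sketch of a more defensive variant. The names argmax_index and add_topic_id are made up for illustration and are not part of PySpark; the sketch also assumes each value can be iterated element by element.

```python
def argmax_index(values):
    """Return the position of the largest element as a plain int.

    Returns None for a null or empty input instead of raising.
    Ties resolve to the first maximum, like list.index(max(...)).
    """
    if values is None:
        return None
    # Convert to Python floats so lists, NumPy arrays and other
    # iterable sequences are all handled the same way.
    items = [float(v) for v in values]
    if not items:
        return None
    return max(range(len(items)), key=items.__getitem__)


def add_topic_id(df, source_col="topicDistribution", target_col="topicID"):
    """Hypothetical wrapper: append target_col holding the argmax of source_col."""
    # Imported here so argmax_index can be tested without a Spark installation.
    from pyspark.sql import functions as f
    from pyspark.sql.types import IntegerType

    max_index = f.udf(argmax_index, IntegerType())
    return df.withColumn(target_col, max_index(source_col))
```

Calling df = add_topic_id(df) should then add the topicID column in the same way as the examples above. Keeping the logic in a named function rather than a lambda also lets you check it on ordinary Python lists before running it on the cluster.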