I have done LDA topic modelling and have it stored in lda_model.
After transforming my original input dataset I retrieve a DataFrame. One of the columns is topicDistribution, which contains the probability of each row belonging to each topic of the LDA model. I therefore want to get the index of the maximum value in the list per row.
df -- | 'list_of_words'   | 'index' | 'topicDistribution'      |
      | ['product','...'] | 0       | [0.08, 0.2, 0.4, 0.0001] |
      | ...               | ...     | ...                      |
I want to transform df such that an additional column is added which is the argmax of the topicDistribution list per row.
df_transformed -- | 'list_of_words'   | 'index' | 'topicDistribution'      | 'topicID' |
                  | ['product','...'] | 0       | [0.08, 0.2, 0.4, 0.0001] | 2         |
                  | ...               | ...     | ...                      | ...       |
How would I do this?
You can create a user-defined function (UDF) to get the index of the maximum:
from pyspark.sql import functions as f
from pyspark.sql.types import IntegerType

# list.index(max(x)) returns the position of the first occurrence of the maximum
max_index = f.udf(lambda x: x.index(max(x)), IntegerType())
df = df.withColumn("topicID", max_index("topicDistribution"))
Example
>>> from pyspark.sql import functions as f
>>> from pyspark.sql.types import IntegerType
>>> df = spark.createDataFrame([{"topicDistribution": [0.2, 0.3, 0.5]}])
>>> df.show()
+-----------------+
|topicDistribution|
+-----------------+
| [0.2, 0.3, 0.5]|
+-----------------+
>>> max_index = f.udf(lambda x: x.index(max(x)), IntegerType())
>>> df.withColumn("topicID", max_index("topicDistribution")).show()
+-----------------+-------+
|topicDistribution|topicID|
+-----------------+-------+
| [0.2, 0.3, 0.5]| 2|
+-----------------+-------+
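If you want to avoid a Python UDF entirely, here is a sketch of a pure-Spark alternative, assuming Spark 2.4+ (for array_max/array_position) and an array-typed column as in the toy example above; if your topicDistribution is an ML vector, pyspark.ml.functions.vector_to_array (Spark 3.0+) can convert it first. Note that array_position is 1-based and, like list.index, returns the first occurrence on ties:

from pyspark.sql import functions as f

# Locate the maximum with built-in array functions instead of a UDF.
# array_position is 1-based, so subtract 1 to match list.index semantics.
df = df.withColumn(
    "topicID",
    f.expr("array_position(topicDistribution, array_max(topicDistribution)) - 1").cast("int"),
)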
Edit: Since you mentioned that the lists in topicDistribution are numpy arrays, you can update the max_index udf as follows:
# numpy arrays have no .index method, so convert to a Python list first
max_index = f.udf(lambda x: x.tolist().index(max(x)), IntegerType())
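Alternatively, since numpy is already in the picture, np.argmax handles both lists and arrays directly. A minimal sketch; the int(...) cast is needed because IntegerType cannot serialize numpy's int64:

import numpy as np
from pyspark.sql import functions as f
from pyspark.sql.types import IntegerType

# np.argmax returns the index of the first maximum as a numpy.int64;
# wrap it in int() so Spark can serialize it as an IntegerType
max_index = f.udf(lambda x: int(np.argmax(x)), IntegerType())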