arrays, apache-spark, pyspark, apache-spark-sql, position

Get index of item in array column in a Spark dataframe


I am able to filter a Spark dataframe (in PySpark) based on whether a particular value exists within an array column by doing the following:

from pyspark.sql.functions import array_contains
spark_df.filter(array_contains(spark_df.array_column_name, "value that I want")) 

But is there a way to get the index of where in the array the item was found?


Solution

  • I am using Spark 2.3, so I tried this with a udf.

    from pyspark.sql.functions import udf, lit

    df = spark.createDataFrame([(["c", "b", "a", "e", "f"],)], ['arraydata'])

    >>> df.show()
    +---------------+
    |      arraydata|
    +---------------+
    |[c, b, a, e, f]|
    +---------------+
    
    # collect every index at which the item appears; with no returnType given,
    # udf defaults to StringType, so the result is the string form of the list
    user_func = udf(lambda x, y: [i for i, e in enumerate(x) if e == y])
    

    Checking the index position for item 'b':

    newdf = df.withColumn('item_position',user_func(df.arraydata,lit('b')))
    
    >>> newdf.show()
    +---------------+-------------+
    |      arraydata|item_position|
    +---------------+-------------+
    |[c, b, a, e, f]|          [1]|
    +---------------+-------------+
    

    Checking the index position for item 'e':

    newdf = df.withColumn('item_position',user_func(df.arraydata,lit('e')))
    
    >>> newdf.show()
    +---------------+-------------+
    |      arraydata|item_position|
    +---------------+-------------+
    |[c, b, a, e, f]|          [3]|
    +---------------+-------------+
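
    Two follow-up notes. Since no return type is declared, the udf above falls back to StringType, so item_position is really the string form of the list rather than an array of integers; declaring the return type fixes that. Also, on Spark 2.4 or later the built-in array_position function returns the 1-based position of the first occurrence (0 when the item is absent), with no udf required. A minimal sketch of both, assuming the same df as above (find_positions is just an illustrative name):

    from pyspark.sql.functions import udf, lit, array_position
    from pyspark.sql.types import ArrayType, IntegerType

    # Typed udf: item_position becomes a real array<int> instead of a string
    find_positions = udf(lambda arr, item: [i for i, e in enumerate(arr) if e == item],
                         ArrayType(IntegerType()))
    newdf = df.withColumn('item_position', find_positions(df.arraydata, lit('b')))

    # Spark 2.4+ built-in: 1-based index of the first occurrence, 0 if not found
    newdf = newdf.withColumn('first_position', array_position(df.arraydata, 'b'))

    Keep in mind the udf returns 0-based indexes while array_position is 1-based, so adjust for the off-by-one if you combine them.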