apache-spark, dataframe, pyspark, apache-spark-sql, apache-spark-ml

How to access element of a VectorUDT column in a Spark DataFrame?


I have a DataFrame df with a VectorUDT column named features. How do I get an element of that column, say the first element?

I've tried doing the following

from pyspark.sql.functions import udf
first_elem_udf = udf(lambda row: row.values[0])
df.select(first_elem_udf(df.features)).show()

but I get a net.razorvine.pickle.PickleException: expected zero arguments for construction of ClassDict (for numpy.dtype) error. I get the same error if I use first_elem_udf = udf(lambda row: row.toArray()[0]) instead.

I also tried explode(), but that fails because it requires an array or map type.

This should be a common operation, I think.


Solution

  • Convert output to float:

    from pyspark.sql.types import DoubleType
    from pyspark.sql.functions import lit, udf
    
    def ith_(v, i):
        # v is a pyspark.ml.linalg vector (dense or sparse), i a plain Python int.
        try:
            # Cast to a built-in float: returning a numpy.float64 from a UDF is
            # exactly what triggers the PickleException from the question.
            return float(v[i])
        except (ValueError, IndexError):
            # Invalid or out-of-range index -> null in the resulting column.
            return None
    
    ith = udf(ith_, DoubleType())
    

    Example usage:

    from pyspark.ml.linalg import Vectors
    
    df = sc.parallelize([
        (1, Vectors.dense([1, 2, 3])),
        (2, Vectors.sparse(3, [1], [9]))
    ]).toDF(["id", "features"])
    
    df.select(ith("features", lit(1))).show()
    
    ## +-----------------+
    ## |ith_(features, 1)|
    ## +-----------------+
    ## |              2.0|
    ## |              9.0|
    ## +-----------------+
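    
    With the IndexError guard in ith_, an out-of-range index degrades to null instead of failing the job. A quick sanity check (not part of the original output):
    
    df.select(ith("features", lit(3))).show()
    
    ## +-----------------+
    ## |ith_(features, 3)|
    ## +-----------------+
    ## |             null|
    ## |             null|
    ## +-----------------+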
    

    Explanation:

    Output values have to be reserialized to equivalent Java objects; a numpy.float64 cannot be pickled back to the JVM, which is precisely what the PickleException in the question is complaining about, so the UDF must return a built-in float. If you want to access values directly (beware of SparseVectors, whose values array holds only the non-zero entries), use the item method:

    v.values.item(0)
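    
    For illustration, here is a minimal driver-side check (the vector v below is a made-up example, not from the original post):
    
    from pyspark.ml.linalg import Vectors
    
    v = Vectors.dense([1.0, 2.0, 3.0])
    
    type(v.values[0])       # numpy.float64, not picklable by the UDF machinery
    type(v.values.item(0))  # float, a plain Python scalar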
    

    Indexing with item() returns a standard Python scalar. Similarly, if you want to access all values as a dense structure:

    v.toArray().tolist()
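    
    If you need the same thing as a column operation, here is a minimal sketch of a UDF built on toArray().tolist() (the to_array name and the ArrayType return type are my own choices, not from the original answer):
    
    from pyspark.sql.types import ArrayType, DoubleType
    from pyspark.sql.functions import udf
    
    # Expose the whole vector as array<double>; tolist() yields plain Python
    # floats, so the result pickles back to the JVM without issue.
    to_array = udf(lambda v: v.toArray().tolist(), ArrayType(DoubleType()))
    
    df.select(to_array("features")[0]).show()
    
    On Spark 3.0+ the built-in pyspark.ml.functions.vector_to_array performs the same conversion without a Python UDF.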