Tags: apache-spark, pyspark, apache-spark-sql, metadata, apache-spark-ml

How to attach metadata to a double column in PySpark


I have a double-typed column in a dataframe that holds the class label for a Random Forest training set.
I would like to manually attach metadata to the column so that I don't have to pass the dataframe into a StringIndexer as suggested in another question.
The easiest way to do this seems to be via the as method of Column.
However, this method is not available in Python.

Is there an easy workaround?

If there is no easy workaround and the best approach is to port as to Python, why hasn't the method been ported?
Is there a genuine technical obstacle, or is it simply that as is a reserved keyword in Python and no one has volunteered to do the port?

I looked at the source code and found that the alias method in Python internally calls the as method in Scala.
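For example, both of the following appear to build the same column expression, one through alias and one by invoking the JVM method named "as" directly (a quick sketch, assuming an active SparkSession bound to spark):

    from pyspark.sql.column import Column

    # assumes an active SparkSession bound to `spark`
    df = spark.createDataFrame([(0.0,), (1.0,)], ["classification"])

    # alias() goes through the Python wrapper ...
    via_alias = df.classification.alias("label")
    # ... while this calls the underlying Scala/JVM "as" method directly
    via_jvm = Column(getattr(df.classification._jc, "as")("label"))

    print(via_alias)  # e.g. Column<classification AS label> (exact repr varies by version)
    print(via_jvm)    # same expression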


Solution

  • import json
    from pyspark.sql.column import Column
    
    # sc is the active SparkContext; metadata is a plain Python dict that
    # is serialized to JSON and turned into a Spark Metadata object.
    def add_meta(col, metadata):
        meta = sc._jvm.org.apache.spark.sql.types\
                 .Metadata.fromJson(json.dumps(metadata))
        # "as" is a reserved keyword in Python, so the JVM method has to be
        # looked up with getattr instead of being called as col._jc.as(...)
        return Column(getattr(col._jc, "as")('', meta))
    
    # sample invocation
    df.withColumn('label',
                  add_meta(df.classification,
                           {"ml_attr": {
                                "name": "label",
                                "type": "nominal",
                                "vals": ["0.0", "1.0"]
                           }}))\
      .show()
    

    This solution involves calling the Scala method as(alias: String, metadata: Metadata) from Python. The method can be retrieved with getattr(col._jc, "as"), where col is a dataframe column (a Column object).

    The returned function is then called with two arguments: an alias string and a Metadata object. Here the alias is just an empty string, since withColumn assigns the column name anyway. The Metadata object is created with Metadata.fromJson(), which expects a JSON string as its parameter; the Metadata class itself is reached through the _jvm attribute of the Spark context.
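
    To check that the metadata actually ends up on the column, it can be read back from the dataframe's schema. A minimal sketch, assuming the df and the add_meta function from the sample above:

    labeled = df.withColumn('label',
                            add_meta(df.classification,
                                     {"ml_attr": {"name": "label",
                                                  "type": "nominal",
                                                  "vals": ["0.0", "1.0"]}}))

    # StructField.metadata exposes the attached metadata as a plain Python dict
    meta = next(f.metadata for f in labeled.schema.fields if f.name == 'label')
    print(meta)
    # expected: {'ml_attr': {'name': 'label', 'type': 'nominal', 'vals': ['0.0', '1.0']}}

    As a side note, newer Spark releases (2.2 and later) let you pass the metadata directly as a keyword argument to alias, e.g. df.classification.alias('label', metadata={...}), which wraps the same JVM call shown above.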