scala, apache-spark, dataframe, apache-spark-sql, apache-spark-ml

Dropping a nested column from Spark DataFrame


I have a DataFrame with the following schema:

root
 |-- label: string (nullable = true)
 |-- features: struct (nullable = true)
 |    |-- feat1: string (nullable = true)
 |    |-- feat2: string (nullable = true)
 |    |-- feat3: string (nullable = true)

While I am able to filter the DataFrame using

  val data = rawData
     .filter( !(rawData("features.feat1") <=> "100") )

I am unable to drop the nested column using

  val data = rawData
     .drop("features.feat1")

Is there something I am doing wrong here? I also tried (unsuccessfully) calling drop(rawData("features.feat1")), though it does not make much sense to do so.

Thanks in advance,

Nikhil


Solution

  • It is just a programming exercise, but you can try something like this. DataFrame.drop only matches top-level column names, so a nested field has to be removed by rebuilding the enclosing struct:

    import org.apache.spark.sql.{DataFrame, Column}
    import org.apache.spark.sql.types.{StructType, StructField}
    import org.apache.spark.sql.{functions => f}
    import scala.util.Try
    
    case class DFWithDropFrom(df: DataFrame) {
      // Look up the top-level field with the given name, if it exists.
      def getSourceField(source: String): Try[StructField] = {
        Try(df.schema.fields.filter(_.name == source).head)
      }
    
      // Succeeds only if that field is actually a struct.
      def getType(sourceField: StructField): Try[StructType] = {
        Try(sourceField.dataType.asInstanceOf[StructType])
      }
    
      // Rebuild the struct column from the field names that survive the drop.
      def genOutputCol(names: Array[String], source: String): Column = {
        f.struct(names.map(x => f.col(source).getItem(x).alias(x)): _*)
      }
    
      // Replace `source` with a struct that omits `toDrop`. If any step fails
      // (no such column, or it is not a struct), return the DataFrame unchanged.
      def dropFrom(source: String, toDrop: Array[String]): DataFrame = {
        getSourceField(source)
          .flatMap(getType)
          .map(_.fieldNames.diff(toDrop))
          .map(genOutputCol(_, source))
          .map(df.withColumn(source, _))
          .getOrElse(df)
      }
    }
    

    Example usage:

    scala> case class features(feat1: String, feat2: String, feat3: String)
    defined class features
    
    scala> case class record(label: String, features: features)
    defined class record
    
    scala> val df = sc.parallelize(Seq(record("a_label",  features("f1", "f2", "f3")))).toDF
    df: org.apache.spark.sql.DataFrame = [label: string, features: struct<feat1:string,feat2:string,feat3:string>]
    
    scala> DFWithDropFrom(df).dropFrom("features", Array("feat1")).show
    +-------+--------+
    |  label|features|
    +-------+--------+
    |a_label| [f2,f3]|
    +-------+--------+
    
    
    scala> DFWithDropFrom(df).dropFrom("foobar", Array("feat1")).show
    +-------+----------+
    |  label|  features|
    +-------+----------+
    |a_label|[f1,f2,f3]|
    +-------+----------+
    
    
    scala> DFWithDropFrom(df).dropFrom("features", Array("foobar")).show
    +-------+----------+
    |  label|  features|
    +-------+----------+
    |a_label|[f1,f2,f3]|
    +-------+----------+
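
    Note that the last two calls return the DataFrame unchanged: dropFrom falls back to getOrElse(df) when the source column does not exist or the field to drop is not part of the struct, so both cases are silent no-ops rather than errors.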
    

    Add an implicit conversion and you're good to go.
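
    For example, a minimal sketch of that last step (the object name DropFromSyntax is just illustrative, and this assumes DFWithDropFrom from above is in scope):

    import scala.language.implicitConversions
    import org.apache.spark.sql.DataFrame
    
    object DropFromSyntax {
      // With this in scope, any DataFrame picks up `dropFrom`
      // through an implicit conversion to DFWithDropFrom.
      implicit def toDFWithDropFrom(df: DataFrame): DFWithDropFrom =
        DFWithDropFrom(df)
    }
    
    import DropFromSyntax._
    
    df.dropFrom("features", Array("feat1")).show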