Tags: scala, apache-spark, apache-spark-sql

How can I create a Spark DataFrame from a nested array of struct elements?


I have read a JSON file into Spark. This file has the following structure:

 root
  |-- a: struct (nullable = true)
  |    |-- rt: array (nullable = true)
  |    |    |-- element: struct (containsNull = true)
  |    |    |    |-- rb: struct (nullable = true)
  |    |    |    |    |-- a: struct (nullable = true)
  |    |    |    |    |    |-- b: string (nullable = true)
  |    |    |    |    |    |-- c: boolean (nullable = true)
  |    |    |    |    |    |-- d: long (nullable = true)
  |    |    |    |    |    |-- e: string (nullable = true)

I created a recursive function that flattens the schema by expanding columns of nested StructType:

import org.apache.spark.sql.Column
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.StructType

def flattenSchema(schema: StructType, prefix: String = null): Array[Column] = {
  schema.fields.flatMap { f =>
    val colName = if (prefix == null) f.name else prefix + "." + f.name

    f.dataType match {
      case st: StructType => flattenSchema(st, colName)
      case _ => Array(col(colName).alias(colName))
    }
  }
}
 
val newDF = df.select(flattenSchema(df.schema): _*)

val secondDF = newDF.toDF(newDF.columns.map(_.replace(".", "_")): _*)

How can I flatten an ArrayType that contains a nested StructType, for example engagementItems: array (nullable = true)?

Any help is appreciated.


Solution

  • The problem here is that you need to handle the case for ArrayType and then convert the array's element type into a StructType. You can use Scala's runtime cast (asInstanceOf) for that.

    First I generated the scenario as follows (by the way, it would be very helpful to include something like this in your question, since it makes reproducing the problem much easier):

      case class DimapraUnit(code: String, constrained: Boolean, id: Long, label: String, ranking: Long, _type: String, version: Long, visible: Boolean)
      case class AvailabilityEngagement(dimapraUnit: DimapraUnit)
      case class Element(availabilityEngagement: AvailabilityEngagement)
      case class Engagement(engagementItems: Array[Element])
      case class root(engagement: Engagement)

      def getSchema(): StructType = {
        import org.apache.spark.sql.types._
        import org.apache.spark.sql.catalyst.ScalaReflection

        // Derive the schema directly from the case classes via reflection.
        val schema = ScalaReflection.schemaFor[root].dataType.asInstanceOf[StructType]
        schema.printTreeString()
        schema
      }
    

    This will print out:

    root
     |-- engagement: struct (nullable = true)
     |    |-- engagementItems: array (nullable = true)
     |    |    |-- element: struct (containsNull = true)
     |    |    |    |-- availabilityEngagement: struct (nullable = true)
     |    |    |    |    |-- dimapraUnit: struct (nullable = true)
     |    |    |    |    |    |-- code: string (nullable = true)
     |    |    |    |    |    |-- constrained: boolean (nullable = false)
     |    |    |    |    |    |-- id: long (nullable = false)
     |    |    |    |    |    |-- label: string (nullable = true)
     |    |    |    |    |    |-- ranking: long (nullable = false)
     |    |    |    |    |    |-- _type: string (nullable = true)
     |    |    |    |    |    |-- version: long (nullable = false)
     |    |    |    |    |    |-- visible: boolean (nullable = false)
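
    If you also want real rows to test against, a small DataFrame can be built straight from the same case classes. This is a minimal sketch; the local SparkSession named spark and the sample values are my own assumptions, not part of the original question:

      import org.apache.spark.sql.SparkSession

      val spark = SparkSession.builder().master("local[*]").appName("flatten-demo").getOrCreate()

      // One made-up row matching the case classes above.
      val sample = Seq(
        root(Engagement(Array(Element(AvailabilityEngagement(
          DimapraUnit("c1", constrained = true, 1L, "label1", 2L, "t1", 3L, visible = true))))))
      )
      val df = spark.createDataFrame(sample)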
    

    Then I modified your function by adding an extra check for ArrayType, converting the array's element type to StructType using asInstanceOf:

      import org.apache.spark.sql.Column
      import org.apache.spark.sql.functions.col
      import org.apache.spark.sql.types._

      def flattenSchema(schema: StructType, prefix: String = null): Array[Column] = {
        schema.fields.flatMap { f =>
          val colName = if (prefix == null) f.name else prefix + "." + f.name

          f.dataType match {
            case st: StructType => flattenSchema(st, colName)
            case at: ArrayType =>
              // Assumes the array elements are structs; cast and recurse into them.
              val st = at.elementType.asInstanceOf[StructType]
              flattenSchema(st, colName)
            case _ => Array(col(colName).alias(colName))
          }
        }
      }
    

    And finally the results:

    val s = getSchema()
    val res = flattenSchema(s)
    
    res.foreach(println(_))
    

    Output:

    engagement.engagementItems.availabilityEngagement.dimapraUnit.code AS `engagement.engagementItems.availabilityEngagement.dimapraUnit.code`
    engagement.engagementItems.availabilityEngagement.dimapraUnit.constrained AS `engagement.engagementItems.availabilityEngagement.dimapraUnit.constrained`
    engagement.engagementItems.availabilityEngagement.dimapraUnit.id AS `engagement.engagementItems.availabilityEngagement.dimapraUnit.id`
    engagement.engagementItems.availabilityEngagement.dimapraUnit.label AS `engagement.engagementItems.availabilityEngagement.dimapraUnit.label`
    engagement.engagementItems.availabilityEngagement.dimapraUnit.ranking AS `engagement.engagementItems.availabilityEngagement.dimapraUnit.ranking`
    engagement.engagementItems.availabilityEngagement.dimapraUnit._type AS `engagement.engagementItems.availabilityEngagement.dimapraUnit._type`
    engagement.engagementItems.availabilityEngagement.dimapraUnit.version AS `engagement.engagementItems.availabilityEngagement.dimapraUnit.version`
    engagement.engagementItems.availabilityEngagement.dimapraUnit.visible AS `engagement.engagementItems.availabilityEngagement.dimapraUnit.visible`
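
    One caveat worth adding: selecting a field through an array, as these columns do, yields an array of values per row rather than one row per array element. If fully flat rows are needed, explode the array first and then flatten the exploded result. A rough sketch, reusing the df built above (the variable names are my own):

      import org.apache.spark.sql.functions.explode

      // Fields reached through engagementItems come back as arrays, one entry per element.
      val flattened = df.select(flattenSchema(df.schema): _*)
      flattened.printSchema()

      // For one row per array element, explode first, then flatten and rename.
      val exploded = df.select(explode(df("engagement.engagementItems")).as("item"))
      val fullyFlat = exploded.select(flattenSchema(exploded.schema): _*)
      fullyFlat.toDF(fullyFlat.columns.map(_.replace(".", "_")): _*).show(truncate = false)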