Tags: python, apache-spark, struct, pyspark

How to sort columns of nested structs alphabetically in pyspark?


I have data with the schema below. I want all the columns, including nested ones, sorted alphabetically in a PySpark DataFrame.

root
 |-- _id: string (nullable = true)
 |-- first_name: string (nullable = true)
 |-- last_name: string (nullable = true)
 |-- address: struct (nullable = true)
 |    |-- pin: integer (nullable = true)
 |    |-- city: string (nullable = true)
 |    |-- street: string (nullable = true)

The code below sorts only the outer columns, not the nested ones.

>>> cols = df.columns
>>> df2=df[sorted(cols)]
>>> df2.printSchema()

The schema after running this code looks like this:

root
 |-- _id: string (nullable = true)
 |-- address: struct (nullable = true)
 |    |-- pin: integer (nullable = true)
 |    |-- city: string (nullable = true)
 |    |-- street: string (nullable = true)
 |-- first_name: string (nullable = true)
 |-- last_name: string (nullable = true)

(Since _id starts with an underscore, it sorts first.)

The schema I want is below (even the columns inside address should be sorted):

root
 |-- _id: string (nullable = true)
 |-- address: struct (nullable = true)
 |    |-- city: string (nullable = true)
 |    |-- pin: integer (nullable = true)
 |    |-- street: string (nullable = true)
 |-- first_name: string (nullable = true)
 |-- last_name: string (nullable = true)

Thanks in advance.


Solution

  • Here is a solution that works for arbitrarily deeply nested StructTypes and doesn't rely on hard-coding any column names.

    To demonstrate, I've created a slightly more complex schema with a second level of nesting inside the address column. Suppose your DataFrame schema is the following:

    df.printSchema()
    #root
    # |-- _id: string (nullable = true)
    # |-- first_name: string (nullable = true)
    # |-- last_name: string (nullable = true)
    # |-- address: struct (nullable = true)
    # |    |-- pin: integer (nullable = true)
    # |    |-- city: string (nullable = true)
    # |    |-- zip: struct (nullable = true)
    # |    |    |-- last4: integer (nullable = true)
    # |    |    |-- first5: integer (nullable = true)
    # |    |-- street: string (nullable = true)
    

    Notice the address.zip field, which contains two out-of-order subfields.

    You can define a function that will recursively step through your schema and sort the fields to build a Spark-SQL select expression:

    from pyspark.sql.types import StructType, StructField
    
    def schemaToSelectExpr(schema, baseField=""):
        """Recursively walk `schema` and return a list of Spark-SQL
        expressions that select every field in alphabetical order."""
        select_cols = []
        for structField in sorted(schema, key=lambda x: x.name):
            if structField.dataType.typeName() == 'struct':
                # Rebuild the struct from its sub-fields, sorted by name.
                subFields = []
                for fld in sorted(structField.jsonValue()['type']['fields'],
                                  key=lambda x: x['name']):
                    # Wrap each sub-field in a one-field StructType so we
                    # can recurse, qualifying its name with the parent path.
                    newStruct = StructType([StructField.fromJson(fld)])
                    newBaseField = structField.name
                    if baseField:
                        newBaseField = baseField + "." + newBaseField
                    subFields.extend(schemaToSelectExpr(newStruct, baseField=newBaseField))
    
                select_cols.append(
                    "struct(" + ",".join(subFields) + ") AS {}".format(structField.name)
                )
            else:
                # Leaf field: qualify it with the parent path, if any.
                if baseField:
                    select_cols.append(baseField + "." + structField.name)
                else:
                    select_cols.append(structField.name)
        return select_cols
    

    Running this on the example schema yields the following (I've broken the long 'address' string across two lines for readability):

    print(schemaToSelectExpr(df.schema))
    #['_id',
    #'struct(address.city,address.pin,address.street,
    #        struct(address.zip.first5,address.zip.last4) AS zip) AS address',
    # 'first_name',
    # 'last_name']
    

    Now use selectExpr to sort the columns:

    df = df.selectExpr(schemaToSelectExpr(df.schema))
    df.printSchema()
    #root
    # |-- _id: string (nullable = true)
    # |-- address: struct (nullable = false)
    # |    |-- city: string (nullable = true)
    # |    |-- pin: integer (nullable = true)
    # |    |-- street: string (nullable = true)
    # |    |-- zip: struct (nullable = false)
    # |    |    |-- first5: integer (nullable = true)
    # |    |    |-- last4: integer (nullable = true)
    # |-- first_name: string (nullable = true)
    # |-- last_name: string (nullable = true)