
Read/Write Parquet with Struct column type


I am trying to write a DataFrame like this to Parquet:

| foo | bar               |
|-----|-------------------|
|  1  | {"a": 1, "b": 10} |
|  2  | {"a": 2, "b": 20} |
|  3  | {"a": 3, "b": 30} |

I am doing it with Pandas and Fastparquet:

import pandas as pd
import fastparquet

df = pd.DataFrame({
    "foo": [1, 2, 3],
    "bar": [{"a": 1, "b": 10}, {"a": 2, "b": 20}, {"a": 3, "b": 30}]
})

# Write the DataFrame to a single Parquet file with fastparquet
fastparquet.write('/my/parquet/location/toy-fastparquet.parq', df)

I would like to load the Parquet data in (Py)Spark and query it with Spark SQL, like:

df = spark.read.parquet("/my/parquet/location/")
df.createOrReplaceTempView('my_toy_table')
result = spark.sql("SELECT * FROM my_toy_table WHERE bar.b > 15")

My issue is that, even though fastparquet reads its own Parquet file correctly (the bar field is deserialized as a struct), Spark reads bar as a column of type String that just contains a JSON representation of the original structure:

In [2]: df.head()
Out[2]: Row(foo=1, bar='{"a": 1, "b": 10}')
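The string Spark returns is exactly the JSON text of each original dict. A minimal check with pandas, assuming (as the `object_encoding` option suggests) that fastparquet encoded the dict column as JSON; the `as_stored` name is purely illustrative:

```python
import json

import pandas as pd

df = pd.DataFrame({
    "foo": [1, 2, 3],
    "bar": [{"a": 1, "b": 10}, {"a": 2, "b": 20}, {"a": 3, "b": 30}],
})

# Assumption: fastparquet stored each dict as its JSON text, so serializing
# the column with json.dumps should reproduce what Spark reads back for `bar`
as_stored = df["bar"].map(json.dumps)
print(as_stored.iloc[0])  # {"a": 1, "b": 10}

# Parsing the stored text recovers the original dicts
assert as_stored.map(json.loads).tolist() == df["bar"].tolist()
```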

I tried writing Parquet from PyArrow, but no luck there: ArrowNotImplementedError: Level generation for Struct not supported yet. I have also tried passing file_scheme='hive' to Fastparquet, but I got the same results. Changing Fastparquet serialization to BSON (object_encoding='bson') produced an unreadable binary field.

[EDIT] I see the following approaches:


Solution

  • You have at least 3 options here:

    Option 1:

    You don't need to use any extra libraries like fastparquet since Spark provides that functionality already:

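    import pandas as pd
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()

    pdf = pd.DataFrame({
        "foo": [1, 2, 3],
        "bar": [{"a": 1, "b": 10}, {"a": 2, "b": 20}, {"a": 3, "b": 30}]
    })
    
    # Convert the pandas DataFrame to a Spark DataFrame and write it with Spark's own Parquet writer
    df = spark.createDataFrame(pdf)
    df.write.mode("overwrite").parquet("/tmp/parquet1")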
    

    If you then load the data with df = spark.read.parquet("/tmp/parquet1"), the schema will be:

    StructType([
        StructField("foo", LongType(), True),
        StructField("bar", MapType(StringType(), LongType(), True), True)])
    

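    As you can see, in this case Spark keeps bar as a structured column instead of a string. Note that Spark infers it as a MapType (string keys, long values), not a StructType; for the query in the question this makes no difference, since both bar.b and bar['b'] work on it.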

    Option 2:

    If for any reason you still need to use fastparquet, Spark will read bar as a string. In that case you can load bar as a string and then parse it with the from_json function. Here we parse the JSON as a Map(string, long), i.e. a dictionary. This is for our own convenience, since the data is a set of key/value pairs, which a map represents perfectly:

    from pyspark.sql.types import StringType, MapType, LongType
    from pyspark.sql.functions import from_json
    
    # Read the file written by fastparquet, where bar is a JSON string
    df = spark.read.parquet("/my/parquet/location/")
    
    # Parse the string into a Map(string, long)
    df.withColumn("bar", from_json("bar", MapType(StringType(), LongType()))).show()
    
    # +---+-----------------+
    # |foo|              bar|
    # +---+-----------------+
    # |  1|[a -> 1, b -> 10]|
    # |  2|[a -> 2, b -> 20]|
    # |  3|[a -> 3, b -> 30]|
    # +---+-----------------+
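Conceptually, `from_json` with a `MapType` schema turns each JSON object into a key/value mapping, so different rows may carry different keys. The plain-Python sketch below mimics that per-row parsing; `parse_bar_as_map` is a hypothetical helper for illustration, not Spark's implementation, and it does not reproduce Spark's handling of malformed input:

```python
import json
from typing import Dict, Optional


def parse_bar_as_map(text: Optional[str]) -> Optional[Dict[str, int]]:
    """Hypothetical analog of from_json(col, MapType(StringType(), LongType()))."""
    if text is None:  # null input stays null
        return None
    obj = json.loads(text)
    # Every key present in the JSON object is kept, whatever it is called
    return {str(key): int(value) for key, value in obj.items()}


row = parse_bar_as_map('{"a": 1, "b": 10, "c": 7}')
print(row.get("b"))  # 10
print(row.get("z"))  # None: this row has no key "z"
```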
    

    Option 3:

    If your schema does not change and you know that each value of bar always has the same combination of fields (a, b), you can also convert bar into a struct:

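    from pyspark.sql.types import StructType, StructField, LongType
    from pyspark.sql.functions import from_json
    
    # Fixed schema for bar: exactly the fields a and b, both long
    schema = StructType([
        StructField("a", LongType(), True),
        StructField("b", LongType(), True)
    ])
    
    # Parse the JSON string column into a struct
    df = df.withColumn("bar", from_json("bar", schema))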
    
    df.printSchema()
    
    # root
    #  |-- foo: long (nullable = true)
    #  |-- bar: struct (nullable = true)
    #  |    |-- a: long (nullable = true)
    #  |    |-- b: long (nullable = true)
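The difference from the map approach is that a struct schema fixes the set of fields. As far as I know, `from_json` with a `StructType` sets declared fields that are missing from the JSON to null and ignores keys the schema does not declare. A plain-Python analog of that behaviour, where the `Bar` record and `parse_bar_as_struct` helper are hypothetical and only for illustration:

```python
import json
from typing import NamedTuple, Optional


class Bar(NamedTuple):
    """Hypothetical record mirroring StructType([a: long, b: long])."""
    a: Optional[int]
    b: Optional[int]


def parse_bar_as_struct(text: Optional[str]) -> Optional[Bar]:
    if text is None:  # null input stays null
        return None
    obj = json.loads(text)
    # Keep only the declared fields: absent ones become None, extra keys are dropped
    return Bar(**{name: obj.get(name) for name in Bar._fields})


print(parse_bar_as_struct('{"a": 1, "b": 10}'))  # Bar(a=1, b=10)
print(parse_bar_as_struct('{"a": 4, "c": 99}'))  # Bar(a=4, b=None)
```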
    

    Example:

    Then you can run your code with:

    df.createOrReplaceTempView('my_toy_table')
    
    spark.sql("SELECT * FROM my_toy_table WHERE bar.b > 20").show()
    # or spark.sql("SELECT * FROM my_toy_table WHERE bar['b'] > 20")
    
    # +---+-----------------+
    # |foo|              bar|
    # +---+-----------------+
    # |  3|[a -> 3, b -> 30]|
    # +---+-----------------+