python, dataframe, pyspark, databricks

How can I apply a JSON->PySpark nested dataframe as a mapping to another dataframe?


I have a JSON like this:

{"main":{"honda":1,"toyota":2,"BMW":5,"Fiat":4}}

I import into PySpark like this:

car_map = spark.read.json('s3_path/car_map.json')

Now I have a dataframe:

[screenshot of the resulting car_map dataframe]
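For reference, the inferred schema is a single struct column, something like:

car_map.printSchema()
# root
#  |-- main: struct (nullable = true)
#  |    |-- BMW: long (nullable = true)
#  |    |-- Fiat: long (nullable = true)
#  |    |-- honda: long (nullable = true)
#  |    |-- toyota: long (nullable = true)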

Given an existing dataframe:

data = [(1, 'BMW'),
  (2, 'Ford'),
  (3, 'honda'),
  (4, 'Cadillac'),
  (5, 'Fiat')]

df = spark.createDataFrame(data, ["ID", "car"])

+---+--------+
| ID|     car|
+---+--------+
|  1|     BMW|
|  2|    Ford|
|  3|   honda|
|  4|Cadillac|
|  5|    Fiat|
+---+--------+

How can I apply the mapping in car_map to df to create a new column "x"? That is, if df.car is a key in car_map.main, set x to the mapped number; otherwise, set x to 99.

The result should be like so:

+---+--------+---+
| ID|     car|  x|
+---+--------+---+
|  1|     BMW|  5|
|  2|    Ford| 99|
|  3|   honda|  1|
|  4|Cadillac| 99|
|  5|    Fiat|  4|
+---+--------+---+

If other transformations would make this easier (a UDF, a dictionary, an array, explode, etc.), I'm open to them.


Solution

  • You can do this by creating a mapping table and doing a left join against your base table.

    Depending on the size of the mapping table, you can broadcast it to improve the join's efficiency.

    import json
    from pyspark.sql import functions as f
    from pyspark.sql.types import MapType, StringType, IntegerType, StructType, StructField
    
    # Recreate the JSON input and read it with an explicit schema.
    data = {"main": {"honda": 1, "toyota": 2, "BMW": 5, "Fiat": 4}}
    schema = StructType().add("main", MapType(StringType(), IntegerType()))
    
    json_data = json.dumps([data])
    rdd = spark.sparkContext.parallelize([json_data])
    
    # Explode the map into one row per (car, ID) pair.
    df = (
        spark.read.json(rdd, schema=schema)
        .select(
            f.explode(
                f.col("main")
            ).alias("car", "ID")
        )
    )
    
    # The mapping table: x -> car.
    data_map = [
        (5, 'BMW'),
        (1, 'honda'),
        (4, 'Fiat')
    ]
    schema_data_map = StructType(
        [
            StructField("x", IntegerType()),
            StructField("car", StringType())
        ]
    )
    df_map = spark.createDataFrame(data_map, schema=schema_data_map)
    
    # Left join on car; unmatched rows get x = 99.
    df_result = (
        df
        .join(
            df_map.hint("broadcast"),
            on=["car"],
            how="left"
        )
        .fillna(99, subset=["x"])
        .select("ID", "car", "x")  # restore column order
    )
    
    df_result:
    +---+------+---+
    | ID|   car|  x|
    +---+------+---+
    |  1| honda|  1|
    |  2|toyota| 99|
    |  5|   BMW|  5|
    |  4|  Fiat|  4|
    +---+------+---+
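
    If the mapping is small enough to collect to the driver, another option (a sketch, assuming car_map and df are loaded exactly as in the question) is to turn the map into a literal MapType column and look each car up directly, with no join:

    from itertools import chain
    from pyspark.sql import functions as f
    
    # Pull the small mapping down as a plain Python dict,
    # e.g. {'BMW': 5, 'Fiat': 4, 'honda': 1, 'toyota': 2}.
    mapping = car_map.select("main.*").first().asDict()
    
    # Build a literal map column; getItem returns NULL for cars
    # missing from the map, so coalesce falls back to 99.
    mapping_expr = f.create_map(*[f.lit(x) for x in chain(*mapping.items())])
    df_result = df.withColumn("x", f.coalesce(mapping_expr.getItem(f.col("car")), f.lit(99)))

    This avoids the shuffle of a join entirely, at the cost of baking the whole map into the query plan, so it only suits small mappings.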