palantir-foundry | foundry-code-repositories | foundry-python-transform

How do I union many distinct schemas into a single output I can dynamically pivot later?


I want to take an arbitrary set of schemas and combine them into a single dataset that can be unpivoted later. What is the most stable way to do this?

Let's suppose I have tens of inputs, each with a different schema. They contain columns that actually mean the same thing; they just don't have the same names.

After I fix the column names, I want to create outputs that are collections of these columns, i.e. dynamic pivoting. Once the inputs are combined, they should share a single schema that I can do this dynamic pivoting on top of later.

I really don't want to go through and make clones of the inputs or create hundreds of intermediate datasets, so what is the best way to do this?


Solution

  • One strategy you can use to harmonize arbitrary dataset schemas is the cell-level technique, where you build one long dataset recording each cell's contents.

    Our strategy here will boil down to the following:

    1. For each input, construct an index (row number) indicating each unique row
    2. Melt the dataset, noting the original column name and cell value for each row
    3. Stack all melted datasets into one long dataset
    4. Change the column names (now, these are rows) in the long dataset to the desired aliases
    5. Pivot the long dataset to a final, harmonized dataset

    NOTE: The reason we pursue this strategy instead of a series of .withColumnRenamed calls that directly rename the columns of our DataFrames is to allow more dynamic un-pivoting downstream in our pipeline.

    If we imagine downstream consumers may want to describe their own alias rules or their own combinations of columns, we need one long dataset of all cells, their column names, and their originating row numbers to pull from.

    If we don't need dynamic pivoting, a simpler strategy of renaming the columns of each input dataset directly would be fine.
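    For reference, a minimal sketch of that simpler strategy, assuming an alias_dictionary that maps each final column name to its known aliases (the same shape used later in this answer):

    # alias_dictionary maps each final name to its known aliases, e.g.:
    alias_dictionary = {
        "col_1": ["col_1_alias_1", "col_1_alias_2"],
    }

    def rename_aliases(df):
        # Rename any known alias to its final name; leave other columns alone
        out = df
        for final_name, aliases in alias_dictionary.items():
            for alias in aliases:
                if alias in out.columns:
                    out = out.withColumnRenamed(alias, final_name)
        return out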

    Our setup creating the example DataFrames looks like this:

    from pyspark.sql import types as T, functions as F, SparkSession
    
    spark = SparkSession.builder.getOrCreate()
    
    
    # Synthesize DataFrames
    schema1 = T.StructType([
      T.StructField("col_1_alias_1", T.StringType(), False),
      T.StructField("col_2", T.IntegerType(), False),
      T.StructField("col_3", T.StringType(), False),
      T.StructField("col_4", T.IntegerType(), False),
    ])
    data1 = [
      {"col_1_alias_1": "key_1", "col_2": 1, "col_3": "CREATE", "col_4": 0},
      {"col_1_alias_1": "key_2", "col_2": 2, "col_3": "CREATE", "col_4": 0},
      {"col_1_alias_1": "key_3", "col_2": 3, "col_3": "CREATE", "col_4": 0},
      {"col_1_alias_1": "key_1", "col_2": 1, "col_3": "UPDATE", "col_4": 1},
      {"col_1_alias_1": "key_2", "col_2": 2, "col_3": "UPDATE", "col_4": 1},
      {"col_1_alias_1": "key_1", "col_2": 1, "col_3": "DELETE", "col_4": 2},
    ]
    
    df1 = spark.createDataFrame(data1, schema1)
    
    df1 = df1.withColumn("origin", F.lit("df1"))
    
    
    
    schema2 = T.StructType([
      T.StructField("col_1_alias_2", T.StringType(), False),
      T.StructField("col_2", T.IntegerType(), False),
      T.StructField("col_3", T.StringType(), False),
      T.StructField("col_4", T.IntegerType(), False),
    ])
    data2 = [
      {"col_1_alias_2": "key_1", "col_2": 1, "col_3": "CREATE", "col_4": 0},
      {"col_1_alias_2": "key_2", "col_2": 2, "col_3": "CREATE", "col_4": 0},
      {"col_1_alias_2": "key_3", "col_2": 3, "col_3": "CREATE", "col_4": 0},
      {"col_1_alias_2": "key_1", "col_2": 1, "col_3": "UPDATE", "col_4": 1},
      {"col_1_alias_2": "key_2", "col_2": 2, "col_3": "UPDATE", "col_4": 1},
      {"col_1_alias_2": "key_1", "col_2": 1, "col_3": "DELETE", "col_4": 2},
    ]
    
    df2 = spark.createDataFrame(data2, schema2)
    
    df2 = df2.withColumn("origin", F.lit("df2"))
    

    Note: we keep track of each DataFrame's origin so that when we pivot back, we don't accidentally merge cells from two different origins into the same output row; i.e. row 0 of df1 and row 0 of df2 should show up as two separate rows in the output.

    Construct the Index

    We first want to keep track of the row number in each dataset so that corresponding cells stay together.

    Note: do NOT use a 'naked window' (a window with no partition column, i.e. a bare Window.partitionBy() call) here. Doing so forces a single task to execute the row_number calculation for the ENTIRE DataFrame; if your DataFrame is large, this will cause OOMs and your build will not scale. We instead use the lower-level RDD function zipWithIndex, which, while less elegant, scales much better.
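    For contrast, a sketch of the discouraged pattern, for illustration only (the variable names here are hypothetical):

    # ANTI-PATTERN: a window with no partition column funnels every row
    #   through a single task; do not do this on large DataFrames
    from pyspark.sql import Window

    naked = Window.partitionBy().orderBy(F.monotonically_increasing_id())
    indexed_badly = df1.withColumn("index", F.row_number().over(naked) - 1)

    The zipWithIndex approach below avoids that bottleneck: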

    def create_index(df):
        # Attach a contiguous row number to each row via the underlying RDD.
        # zipWithIndex yields (Row, index) pairs, which toDF() exposes as
        #   columns _1 (the original Row struct) and _2 (the index)
        schema_in = df.columns
        indexed = df.rdd.zipWithIndex().toDF().select(
            F.col("_2").alias("index"),
            *[
                F.col("_1").getItem(x).alias(x) for x in schema_in
            ]
        )
        return indexed
    
    
    indexed1 = create_index(df1)
    indexed1.show()
    """
    +-----+-------------+-----+------+-----+------+
    |index|col_1_alias_1|col_2| col_3|col_4|origin|
    +-----+-------------+-----+------+-----+------+
    |    0|        key_1|    1|CREATE|    0|   df1|
    |    1|        key_2|    2|CREATE|    0|   df1|
    |    2|        key_3|    3|CREATE|    0|   df1|
    |    3|        key_1|    1|UPDATE|    1|   df1|
    |    4|        key_2|    2|UPDATE|    1|   df1|
    |    5|        key_1|    1|DELETE|    2|   df1|
    +-----+-------------+-----+------+-----+------+
    """
    
    indexed2 = create_index(df2)
    indexed2.show()
    """
    +-----+-------------+-----+------+-----+------+
    |index|col_1_alias_2|col_2| col_3|col_4|origin|
    +-----+-------------+-----+------+-----+------+
    |    0|        key_1|    1|CREATE|    0|   df2|
    |    1|        key_2|    2|CREATE|    0|   df2|
    |    2|        key_3|    3|CREATE|    0|   df2|
    |    3|        key_1|    1|UPDATE|    1|   df2|
    |    4|        key_2|    2|UPDATE|    1|   df2|
    |    5|        key_1|    1|DELETE|    2|   df2|
    +-----+-------------+-----+------+-----+------+
    """
    

    Melt DataFrames

    We now create the long dataset using the melt technique. Conveniently, the transforms.verbs bundle (imported by default into your repositories) ships a utility method, transforms.verbs.unpivot, with an equivalent implementation you can use.

    from transforms.verbs import unpivot
    
    unpivoted1 = unpivot(indexed1, id_vars=["index", "origin"], value_vars=indexed1.columns)
    unpivoted1.show()
    """
    +-----+------+-------------+------+
    |index|origin|     variable| value|
    +-----+------+-------------+------+
    |    0|   df1|        index|     0|
    |    0|   df1|col_1_alias_1| key_1|
    |    0|   df1|        col_2|     1|
    |    0|   df1|        col_3|CREATE|
    |    0|   df1|        col_4|     0|
    |    0|   df1|       origin|   df1|
    |    1|   df1|        index|     1|
    |    1|   df1|col_1_alias_1| key_2|
    |    1|   df1|        col_2|     2|
    |    1|   df1|        col_3|CREATE|
    |    1|   df1|        col_4|     0|
    |    1|   df1|       origin|   df1|
    |    2|   df1|        index|     2|
    |    2|   df1|col_1_alias_1| key_3|
    |    2|   df1|        col_2|     3|
    |    2|   df1|        col_3|CREATE|
    |    2|   df1|        col_4|     0|
    |    2|   df1|       origin|   df1|
    |    3|   df1|        index|     3|
    |    3|   df1|col_1_alias_1| key_1|
    +-----+------+-------------+------+
    ...
    """
    unpivoted2 = unpivot(indexed2, id_vars=["index", "origin"], value_vars=indexed2.columns)
    unpivoted2.show()
    """
    +-----+------+-------------+------+
    |index|origin|     variable| value|
    +-----+------+-------------+------+
    |    0|   df2|        index|     0|
    |    0|   df2|col_1_alias_2| key_1|
    |    0|   df2|        col_2|     1|
    |    0|   df2|        col_3|CREATE|
    |    0|   df2|        col_4|     0|
    |    0|   df2|       origin|   df2|
    |    1|   df2|        index|     1|
    |    1|   df2|col_1_alias_2| key_2|
    |    1|   df2|        col_2|     2|
    |    1|   df2|        col_3|CREATE|
    |    1|   df2|        col_4|     0|
    |    1|   df2|       origin|   df2|
    |    2|   df2|        index|     2|
    |    2|   df2|col_1_alias_2| key_3|
    |    2|   df2|        col_2|     3|
    |    2|   df2|        col_3|CREATE|
    |    2|   df2|        col_4|     0|
    |    2|   df2|       origin|   df2|
    |    3|   df2|        index|     3|
    |    3|   df2|col_1_alias_2| key_1|
    +-----+------+-------------+------+
    ...
    """
    

    Union DataFrames

    This part is simple: stack your DataFrames on top of each other. If you have many of them, use transforms.verbs.dataframes.union_many.

    all_dfs = unpivoted1.unionByName(unpivoted2)
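    With many melted inputs, a plain-PySpark sketch using functools.reduce avoids chaining unionByName calls by hand (this is the convenience union_many provides):

    from functools import reduce

    # melted_dfs is a hypothetical list holding every melted input
    melted_dfs = [unpivoted1, unpivoted2]
    all_dfs = reduce(lambda left, right: left.unionByName(right), melted_dfs)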
    

    Alias Columns

    This next section is a bit dense.

    We want to rewrite the rows whose 'variable' values name columns that mean the same thing, so we build up a case statement that substitutes each alias with its final name.

    To do so, we want one big case statement, built by stacking .when() calls on top of each other.

    We could initialize a variable to None outside the for loop and special-case the first iteration, but here we can simply start from the pyspark.sql.functions import itself (F): the first .when() resolves to the module-level F.when(), which returns a Column whose own .when() can be chained. We stack one when call per alias list, then finish with an .otherwise() that keeps the column name as-is. This efficiently checks, for every row, whether the column name it carries is an alias that needs renaming.

    alias_dictionary = {
        "col_1": ["col_1_alias_1", "col_1_alias_2"]
    }
    
    when_statement = F  # the first .when() below resolves to F.when()
    
    for alias_key, alias_list in alias_dictionary.items():
        # For each final name, if the row's 'variable' is one of its known
        #   aliases, i.e. it isin() the list, substitute the final name.
        #   Otherwise leave it alone; the .otherwise() outside the loop applies
        when_statement = when_statement.when(
            F.col("variable").isin(alias_list), alias_key
        )
    when_statement = when_statement.otherwise(F.col("variable"))
    
    # Substitute each alias with its final column name
    all_dfs = all_dfs.withColumn("variable", when_statement)
    all_dfs.show()
    
    """
    +-----+------+--------+------+
    |index|origin|variable| value|
    +-----+------+--------+------+
    |    0|   df1|   index|     0|
    |    0|   df1|   col_1| key_1|
    |    0|   df1|   col_2|     1|
    |    0|   df1|   col_3|CREATE|
    |    0|   df1|   col_4|     0|
    |    0|   df1|  origin|   df1|
    |    1|   df1|   index|     1|
    |    1|   df1|   col_1| key_2|
    |    1|   df1|   col_2|     2|
    |    1|   df1|   col_3|CREATE|
    |    1|   df1|   col_4|     0|
    |    1|   df1|  origin|   df1|
    |    2|   df1|   index|     2|
    |    2|   df1|   col_1| key_3|
    |    2|   df1|   col_2|     3|
    |    2|   df1|   col_3|CREATE|
    |    2|   df1|   col_4|     0|
    |    2|   df1|  origin|   df1|
    |    3|   df1|   index|     3|
    |    3|   df1|   col_1| key_1|
    +-----+------+--------+------+
    ...
    """
    

    Pivot to Output

    Finally, we drop the bookkeeping rows (the melted 'index' and 'origin' entries, which would otherwise collide with the grouping columns of the same names), group by each original row, pivot the rows back to columns, take the found value, and voila!

    pivoted = all_dfs.filter(       # drop the melted index/origin rows, which
        ~F.col("variable").isin(    #   would otherwise collide with the grouping
            "index", "origin"       #   columns of the same names...
        )
    ).groupBy(                      # for each original row...
        "index", "origin"
    ).pivot(                        # pivot the rows back to columns...
        "variable"
    ).agg(                          # don't perform any calculations, just take the cell...
        F.first("value")
    ).orderBy(                      # for printing reasons, order by the original row number...
        "index", "origin"
    ).drop(                         # and remove the bookkeeping columns
        "index", "origin"
    )
    
    pivoted.show()
    """
    +-----+-----+------+-----+
    |col_1|col_2| col_3|col_4|
    +-----+-----+------+-----+
    |key_1|    1|CREATE|    0|
    |key_1|    1|CREATE|    0|
    |key_2|    2|CREATE|    0|
    |key_2|    2|CREATE|    0|
    |key_3|    3|CREATE|    0|
    |key_3|    3|CREATE|    0|
    |key_1|    1|UPDATE|    1|
    |key_1|    1|UPDATE|    1|
    |key_2|    2|UPDATE|    1|
    |key_2|    2|UPDATE|    1|
    |key_1|    1|DELETE|    2|
    |key_1|    1|DELETE|    2|
    +-----+-----+------+-----+
    """