pyspark  palantir-foundry  pyspark-transformer

How to replace output's column values with input's values in Foundry?


I have input and output dataframes, both containing a column named "Wellname". In Foundry, how can I replace the rows of output_df with the corresponding rows of input_df wherever the "Wellname" values match? As an illustration, consider the following output_df before the incremental transform:

   Wellname WellType Platform
0       E17 Producer      DWG
1       E17 Producer      DWG
2       E17 Producer      DWG
3      E20Y Producer      DWG
4      E20Y Producer      DWG
5      E20Y Producer      DWG
6      E20Y Producer      DWG
7      E20Y Producer      DWG

And this is the input_df:

   Wellname WellType Platform
0       E17 Producer      CH
1       E17 Producer      CH
2       E17 Producer      CH
3       E21 Producer      DWG
4       E21 Producer      DWG
5       E21 Producer      DWG

As you can see, 'E17' is in BOTH the input and output dataframes, but its 'Platform' values changed in input_df. I want to replace the output's 'E17' rows with the input's 'E17' rows, plus the newly added rows of input_df.

After incremental transform, the output_df should be like:

       Wellname WellType Platform
    0       E17 Producer      CH
    1       E17 Producer      CH
    2       E17 Producer      CH
    3      E20Y Producer      DWG
    4      E20Y Producer      DWG
    5      E20Y Producer      DWG
    6      E20Y Producer      DWG
    7      E20Y Producer      DWG
    8       E21 Producer      DWG
    9       E21 Producer      DWG
   10       E21 Producer      DWG

Here is my sample code, but it doesn't achieve what I want:

from transforms.api import transform, incremental, Input, Output, configure
from pyspark.sql import types as T

schema = T.StructType([
    T.StructField('Wellname', T.StringType()),
    T.StructField('WellType', T.StringType()),
    T.StructField('Platform', T.StringType()),
])

@configure(profile=['KUBERNETES_NO_EXECUTORS'])
@incremental(require_incremental=True, snapshot_inputs=["input_df"])
@transform(
    input_df=Input('ri.foundry.lava-catalog.dataset.7ef47ff2-7015-4dad-aec9-aa3075a63a96'),
    output_df=Output('ri.foundry.lava-catalog.dataset.524cee15-7ac5-410c-96e8-b205bac1cee8')
)
def incremental_filter(input_df, output_df):
    new_df = input_df.dataframe()

    new_df = new_df.unionByName(output_df.dataframe('previous', schema))
    mode = 'modify'

    new_df = new_df.select('Wellname', 'WellType', 'Platform')
    output_df.set_mode(mode)
    output_df.write_dataframe(new_df)

Solution

  • One way to filter out outdated rows is to use a join:

    # Latest rows from the input
    new_df = input_df.dataframe()

    # Rows already present in the output from previous builds
    prev_df = output_df.dataframe('previous', schema)

    # left_anti join: keep only the previous rows whose Wellname does
    # NOT appear in the new data (this drops the outdated 'E17' rows)
    prev_df = prev_df.join(new_df, on=['Wellname'], how='left_anti')

    # New rows plus the untouched previous rows
    new_df = new_df.unionByName(prev_df)
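The effect of the left-anti join plus union can be illustrated without Spark, using plain Python lists of dicts (a toy sketch of the same keep-newest-by-Wellname logic; the names `new_rows` and `prev_rows` are illustrative):

```python
# Toy illustration of left_anti join + unionByName, using plain Python
# instead of Spark. Each row is a dict keyed by column name.

new_rows = [  # input_df: the latest data
    {"Wellname": "E17", "WellType": "Producer", "Platform": "CH"},
    {"Wellname": "E21", "WellType": "Producer", "Platform": "DWG"},
]
prev_rows = [  # previous output
    {"Wellname": "E17", "WellType": "Producer", "Platform": "DWG"},
    {"Wellname": "E20Y", "WellType": "Producer", "Platform": "DWG"},
]

# left_anti on 'Wellname': keep previous rows whose key is
# absent from the new data (drops the outdated E17 row)
new_keys = {r["Wellname"] for r in new_rows}
kept_prev = [r for r in prev_rows if r["Wellname"] not in new_keys]

# unionByName: new rows win, untouched previous rows carry over
result = new_rows + kept_prev

# result: E17 with Platform 'CH', E21 added, E20Y carried over
```

Note that because the unioned dataframe already contains the surviving previous rows, it should most likely be written back with the output mode set to 'replace' rather than 'modify'; with 'modify', the previous rows would be appended on top of what is already in the output and end up duplicated.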