
How to join two Polars DataFrames on multiple types of conditions (equalities and comparisons)?


I'm trying to translate an SQL query into Polars code, and I'm stuck on one part of the query, which joins two tables on several conditions. Here is the sample SQL query I have been working on:

SELECT * 
FROM table_1 as tab1
LEFT JOIN table_2 as tab2
ON tab1.article_no = tab2.article_no -- condition 1
   AND DIV(tab1.variant_no, 1000) = tab2.variant_no -- condition 2
   AND tab1.date_of_day BETWEEN tab2.date_from AND tab2.date_to -- condition 3

The LEFT JOIN has three conditions: two are equality checks, and one checks whether date_of_day lies between two date columns from table_2.

In Polars, to join two DataFrames/LazyFrames you use either tab_1.join(tab_2, on=[certain_column]) or tab_1.join(tab_2, left_on=[certain_column], right_on=[certain_column]).

on= cannot be combined with left_on=/right_on=, and neither form lets me express all three conditions, so I'm not sure how to achieve this.

I tried applying the date comparison separately, after joining on the first two conditions:

import polars as pl

table_1 = pl.read_csv('table_1.csv')
table_2 = pl.read_csv('table_2.csv')

table_joined = (
    table_1
    .join(
        table_2, 
        how='left', 
        left_on=[pl.col('article_no'), pl.col('variant_no') // 1000],  # first two conditions go in the join
        right_on=[pl.col('article_no'), pl.col('variant_no')], 
        suffix="_tab2"
    )
)

columns_to_check = ['price_1', 'price_2', 'price_3']

final_df = (
    table_joined
    .with_columns(
        [
            pl.when(
                (pl.col('date_of_day').ge(pl.col('date_from'))) &   # giving the third condition after joining, and converting the other columns to None since it's a left join
                (pl.col('date_of_day').le(pl.col('date_to')))
            ).then(pl.col(col)).otherwise(None).alias(col)
            for col in columns_to_check   # using for loop to convert other columns to None (Null)
        ]
    )
)

This technique works, but it uses a lot of memory with millions of rows, and it is not a consistent way of joining (sometimes the final dataframe has more rows than expected). I looked online for ways to specify multiple types of join conditions, but haven't found any examples.

I need to apply all three conditions during the left join itself, to avoid this workaround and its memory consumption.

Can somebody help me out here? Thanks in advance.

Edit -

Below is sample data and code:

import polars as pl
from datetime import date

table_1_data = {
    "article_no": [1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008],
    "variant_no": [5000, 6000, 7000, 8000, 9000, 6000, 4000, 6000],
    "date_of_day": [date(2023, 1, 2), date(2023, 1, 12), date(2023, 1, 15), date(2023, 2, 5), date(2023, 3, 15), date(2023, 1, 2), date(2023, 1, 12), date(2023, 1, 15)],
    "quantity": [10, 15, 20, 25, 30, 45, 50, 60]
}

table_1 = pl.DataFrame(table_1_data)

table_2_data = {
    "article_no": [1001, 1002, 1003, 1004, 1006],
    "variant_no": [5, 6, 7, 8, 9],
    "date_from": [date(2023, 1, 1), date(2023, 1, 10), date(2023, 1, 1), date(2023, 1, 1), date(2023, 3, 1)],
    "date_to": [date(2023, 2, 1), date(2023, 1, 15), date(2023, 1, 10), date(2023, 1, 15), date(2023, 3, 31)],
    "price": [100, 110, 120, 130, 140]
}

table_2 = pl.DataFrame(table_2_data)

# Print sample data
print("Table 1:")
print(table_1)
print("\nTable 2:")
print(table_2)

joined_table = (
    table_1
    .join(table_2, 
          how='left', 
          left_on=[pl.col('article_no'), pl.col('variant_no')//1000],
          right_on=[pl.col('article_no'), pl.col('variant_no')],
          suffix='_tab2'
    )
    .select('article_no', 'variant_no', 'date_of_day', 'quantity', 'price')
)

print('\nJoined Table:')
print(joined_table)

expected = {
    "article_no": [1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008],
    "variant_no": [5000, 6000, 7000, 8000, 9000, 6000, 4000, 6000],
    "date_of_day": [date(2023, 1, 2), date(2023, 1, 12), date(2023, 1, 15), date(2023, 2, 5), date(2023, 3, 15), date(2023, 1, 2), date(2023, 1, 12), date(2023, 1, 15)],
    "quantity": [10, 15, 20, 25, 30, 45, 50, 60],
    "price": [100, 110, None, None, None, None, None, None]
}

expected = pl.DataFrame(expected)

print('\nExpected Output:')
print(expected)

The output for this code is:

Table 1:
shape: (8, 4)
┌────────────┬────────────┬─────────────┬──────────┐
│ article_no ┆ variant_no ┆ date_of_day ┆ quantity │
│ ---        ┆ ---        ┆ ---         ┆ ---      │
│ i64        ┆ i64        ┆ date        ┆ i64      │
╞════════════╪════════════╪═════════════╪══════════╡
│ 1001       ┆ 5000       ┆ 2023-01-02  ┆ 10       │
│ 1002       ┆ 6000       ┆ 2023-01-12  ┆ 15       │
│ 1003       ┆ 7000       ┆ 2023-01-15  ┆ 20       │
│ 1004       ┆ 8000       ┆ 2023-02-05  ┆ 25       │
│ 1005       ┆ 9000       ┆ 2023-03-15  ┆ 30       │
│ 1006       ┆ 6000       ┆ 2023-01-02  ┆ 45       │
│ 1007       ┆ 4000       ┆ 2023-01-12  ┆ 50       │
│ 1008       ┆ 6000       ┆ 2023-01-15  ┆ 60       │
└────────────┴────────────┴─────────────┴──────────┘

Table 2:
shape: (5, 5)
┌────────────┬────────────┬────────────┬────────────┬───────┐
│ article_no ┆ variant_no ┆ date_from  ┆ date_to    ┆ price │
│ ---        ┆ ---        ┆ ---        ┆ ---        ┆ ---   │
│ i64        ┆ i64        ┆ date       ┆ date       ┆ i64   │
╞════════════╪════════════╪════════════╪════════════╪═══════╡
│ 1001       ┆ 5          ┆ 2023-01-01 ┆ 2023-02-01 ┆ 100   │
│ 1002       ┆ 6          ┆ 2023-01-10 ┆ 2023-01-15 ┆ 110   │
│ 1003       ┆ 7          ┆ 2023-01-01 ┆ 2023-01-10 ┆ 120   │
│ 1004       ┆ 8          ┆ 2023-01-01 ┆ 2023-01-15 ┆ 130   │
│ 1006       ┆ 9          ┆ 2023-03-01 ┆ 2023-03-31 ┆ 140   │
└────────────┴────────────┴────────────┴────────────┴───────┘

Joined Table:
shape: (8, 5)
┌────────────┬────────────┬─────────────┬──────────┬───────┐
│ article_no ┆ variant_no ┆ date_of_day ┆ quantity ┆ price │
│ ---        ┆ ---        ┆ ---         ┆ ---      ┆ ---   │
│ i64        ┆ i64        ┆ date        ┆ i64      ┆ i64   │
╞════════════╪════════════╪═════════════╪══════════╪═══════╡
│ 1001       ┆ 5000       ┆ 2023-01-02  ┆ 10       ┆ 100   │
│ 1002       ┆ 6000       ┆ 2023-01-12  ┆ 15       ┆ 110   │
│ 1003       ┆ 7000       ┆ 2023-01-15  ┆ 20       ┆ 120   │
│ 1004       ┆ 8000       ┆ 2023-02-05  ┆ 25       ┆ 130   │
│ 1005       ┆ 9000       ┆ 2023-03-15  ┆ 30       ┆ null  │
│ 1006       ┆ 6000       ┆ 2023-01-02  ┆ 45       ┆ null  │
│ 1007       ┆ 4000       ┆ 2023-01-12  ┆ 50       ┆ null  │
│ 1008       ┆ 6000       ┆ 2023-01-15  ┆ 60       ┆ null  │
└────────────┴────────────┴─────────────┴──────────┴───────┘

Expected Output:
shape: (8, 5)
┌────────────┬────────────┬─────────────┬──────────┬───────┐
│ article_no ┆ variant_no ┆ date_of_day ┆ quantity ┆ price │
│ ---        ┆ ---        ┆ ---         ┆ ---      ┆ ---   │
│ i64        ┆ i64        ┆ date        ┆ i64      ┆ i64   │
╞════════════╪════════════╪═════════════╪══════════╪═══════╡
│ 1001       ┆ 5000       ┆ 2023-01-02  ┆ 10       ┆ 100   │
│ 1002       ┆ 6000       ┆ 2023-01-12  ┆ 15       ┆ 110   │
│ 1003       ┆ 7000       ┆ 2023-01-15  ┆ 20       ┆ null  │
│ 1004       ┆ 8000       ┆ 2023-02-05  ┆ 25       ┆ null  │
│ 1005       ┆ 9000       ┆ 2023-03-15  ┆ 30       ┆ null  │
│ 1006       ┆ 6000       ┆ 2023-01-02  ┆ 45       ┆ null  │
│ 1007       ┆ 4000       ┆ 2023-01-12  ┆ 50       ┆ null  │
│ 1008       ┆ 6000       ┆ 2023-01-15  ┆ 60       ┆ null  │
└────────────┴────────────┴─────────────┴──────────┴───────┘


In the joined output, rows 3 and 4 do not satisfy the third condition (date_from <= date_of_day <= date_to), so the columns from table_2 should be null for those rows.


Solution

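  • Polars has join_asof, which is similar to a left join except that it matches on the nearest key, but the code below does not use it. Instead, it does a regular left join on article_no and then filters on the remaining two conditions. Note that the filter also removes table_1 rows that have no match, so the result behaves like an inner join: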

    result = (
        table_1.lazy()
        .join(
            table_2.lazy(),
            how="left",
            left_on="article_no",
            right_on="article_no"
        )
        .filter(
            (pl.col("variant_no") // 1000 == pl.col("variant_no_right")) &
            (pl.col("date_of_day") >= pl.col("date_from")) &
            (pl.col("date_of_day") <= pl.col("date_to"))
        )
        .with_columns([
            pl.when(pl.col("date_to").is_null())
            .then(None)
            .otherwise(pl.col("price"))
            .alias("price")
        ])
        .collect()
    )
    

    Please refer to the sample code below, with sample data:

    import polars as pl
    from datetime import date
    
    
    table_1_data = {
        "article_no": [1001, 1002, 1003, 1004, 1005],
        "variant_no": [5000, 6000, 7000, 8000, 9000],
        "date_of_day": [date(2023, 1, 2), date(2023, 1, 12), date(2023, 1, 15), date(2023, 2, 5), date(2023, 3, 15)],
        "quantity": [10, 15, 20, 25, 30]
    }
    
    table_1 = pl.DataFrame(table_1_data)
    
    table_2_data = {
        "article_no": [1001, 1002, 1003, 1004, 1006],
        "variant_no": [5, 6, 7, 8, 9],
        "date_from": [date(2023, 1, 1), date(2023, 1, 10), date(2023, 1, 1), date(2023, 2, 1), date(2023, 3, 1)],
        "date_to": [date(2023, 2, 1), date(2023, 1, 15), date(2023, 2, 1), date(2023, 2, 15), date(2023, 3, 31)],
        "price": [100, 110, 120, 130, 140]
    }
    
    table_2 = pl.DataFrame(table_2_data)
    
    # Print sample data
    print("Table 1:")
    print(table_1)
    print("\nTable 2:")
    print(table_2)
    
    result = (
        table_1.lazy()
        .join(
            table_2.lazy(),
            how="left",
            left_on="article_no",
            right_on="article_no"
        )
        .filter(
            (pl.col("variant_no") // 1000 == pl.col("variant_no_right")) &
            (pl.col("date_of_day") >= pl.col("date_from")) &
            (pl.col("date_of_day") <= pl.col("date_to"))
        )
        .with_columns([
            pl.when(pl.col("date_to").is_null())
            .then(None)
            .otherwise(pl.col("price"))
            .alias("price")
        ])
        .collect()
    )
    print("\nJoin Result:")
    print(result)
    

    Updated code, following the updated expected output in the question:

    joined_table = (
        table_1
        .join(
            table_2,
            how='left',
            left_on=[pl.col('article_no'), pl.col('variant_no')//1000],
            right_on=[pl.col('article_no'), pl.col('variant_no')]
        )
        .with_columns([
            pl.when(
                (pl.col('date_of_day') >= pl.col('date_from')) &
                (pl.col('date_of_day') <= pl.col('date_to'))
            ).then(pl.col('price')).otherwise(None).alias('price')
        ])
        .select('article_no', 'variant_no', 'date_of_day', 'quantity', 'price')
    )