Tags: python, user-defined-functions, python-polars

Optimize computation of similarity scores by using native Polars expressions instead of UDF functions


Disclaimer (1): This question supplements this SO question; it was posted after two users asked me to elaborate on my case.

Disclaimer (2), added 29/11: I have seen two solutions so far (proposed in this SO question and the supporting one) that use the explode() functionality. Based on benchmarks I ran on the whole dataset (~3m rows), RAM usage literally explodes, so I will test the functions on a sample of the dataset and, if they work, I will accept the explode()-based solutions for those who may experiment with smaller tables.

The input dataset (~3m rows) is ratings.csv from the ml-latest dataset, covering 80_000 IMDb movies and the respective ratings from 330_000 users (you may download the CSV file from here, 891 MB).

I load the dataset with Polars like movie_ratings = pl.read_csv(os.path.join(application_path + data_directory, "ratings.csv")), where application_path and data_directory together form a parent path on my local server.
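
A slimmer variant of the same call (an optional tweak, not the call I originally ran) reads only the three columns used later and skips the timestamp column to save memory:

import os
import polars as pl

# Optional: read only the columns used downstream; timestamp is never needed.
# application_path and data_directory as defined elsewhere in my setup.
movie_ratings = pl.read_csv(
    os.path.join(application_path + data_directory, "ratings.csv"),
    columns=["userId", "movieId", "rating"],
)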

Having read the dataset, my goal is to compute the cosine similarity of a given user against all other users. To do so, I first have to transform the ratings table (~3m rows) into a table with one row per user. Thus, I run the following query:

## 1st computation bottleneck using UDF functions (2.5 minutes for 250_000 rows)
# user_rated_movies / user_ratings hold the movieIds and ratings of the target user.
users_metadata = (
    movie_ratings
    .filter(
        pl.col("userId") != input_id  # input_id is a chosen userId; I test with userId 1, so input_id=1 here
    )
    .group_by("userId")
    .agg(
        pl.col("movieId").unique().alias("user_movies"),
        pl.col("rating").alias("user_ratings"),
    )
    .with_columns(
        pl.col("user_movies").map_elements(
            lambda row: sorted(set(row).intersection(set(user_rated_movies))),
            return_dtype=pl.List(pl.Int64),
        ).alias("common_movies")
    )
    .with_columns(
        pl.col("common_movies").map_elements(
            lambda row: len(row), return_dtype=pl.Int64
        ).alias("common_movies_frequency")
    )
)

similar_users = (
    users_metadata
    .filter(
        (pl.col("common_movies_frequency").le(len(user_rated_movies))) &
        (pl.col("common_movies_frequency").gt(0))  # drop users who have not seen any of the movies rated by the target user
    )
    .sort("common_movies_frequency", descending=True)
)

## 2nd computation bottleneck using UDF functions
similar_users = (
    similar_users.with_columns(
        pl.struct(pl.all()).map_elements(
            get_common_movie_ratings, #asked on StackOverflow
            return_dtype=pl.List(pl.Float64),
            strategy="threading"
        ).alias("common_movie_ratings")
    ).with_columns(
        pl.struct(["common_movies"]).map_elements(
            lambda row: get_target_movie_ratings(row, user_rated_movies, user_ratings),
            return_dtype=pl.List(pl.Float64),
            strategy="threading"
        ).alias("target_user_common_movie_ratings")
    ).with_columns(
        pl.struct(["common_movie_ratings","target_user_common_movie_ratings"]).map_elements(
             lambda row: compute_cosine(row),
             return_dtype=pl.Float64,
             strategy="threading"
        ).alias("similarity_score")
    )
)

The code snippet above groups the table by userId and computes some important metadata about each user, shown in the screenshot below.

Screenshot of the table (ignore the potential recommendations column):

Finally, I filter the users_metadata table to the users whose common_movies_frequency is less than or equal to 62 (len(user_rated_movies)), the number of movies seen by user 1. Those are a total of 250_000 users.

This table is the input dataframe for the UDF function I asked about in this question. Using this dataframe (~250_000 users) I want to calculate the cosine similarity of each user with user 1. To do so, I compare their ratings: for the movies commonly rated by each pair of users, I compute the cosine similarity between the two arrays of ratings.

Below are the three UDF functions I use to implement this.

import numpy as np
from numpy.linalg import norm

def get_common_movie_ratings(row) -> pl.List(pl.Float64):
    # Ratings the other user gave to the movies shared with the target user.
    common_movies = row['common_movies']
    user_ratings = row['user_ratings']
    ratings_for_common_movies = [user_ratings[list(row['user_movies']).index(movie)] for movie in common_movies]
    return ratings_for_common_movies

def get_target_movie_ratings(row, target_user_movies: np.ndarray, target_user_ratings: np.ndarray) -> pl.List(pl.Float64):
    # Ratings the target user gave to the same common movies.
    common_movies = row['common_movies']
    target_user_common_ratings = [target_user_ratings[list(target_user_movies).index(movie)] for movie in common_movies]
    return target_user_common_ratings

def compute_cosine(row) -> pl.Float64:
    array1 = row["common_movie_ratings"]
    array2 = row["target_user_common_movie_ratings"]
    magnitude1 = norm(array1)
    magnitude2 = norm(array2)
    if magnitude1 != 0 and magnitude2 != 0:  # avoid dividing by zero norms/magnitudes
        score: float = np.dot(array1, array2) / (magnitude1 * magnitude2)
    else:
        score: float = 0.0
    return score
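
For reference, a tiny worked check of compute_cosine on example rating vectors (illustrative values only):

row = {
    "common_movie_ratings": [5.0, 5.0, 3.0, 5.0],
    "target_user_common_movie_ratings": [4.0, 4.0, 4.0, 5.0],
}
print(compute_cosine(row))  # 77 / (sqrt(84) * sqrt(73)) ≈ 0.9833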

Benchmarks

The main question is: how can I transform those 3 UDF functions into native Polars expressions?

Logs from a custom logger I made:

2023-11-29 13:40:24 - INFO - Computed potential similar user metadata for 254188 users in: 0:02:15.586497

2023-11-29 13:40:51 - INFO - Computed similarity scores for 194943 users in: 0:00:27.472388

We can conclude that the main bottleneck of the code is creating the users_metadata table.


Solution

  • CSV / Parquet: the sample data below is embedded inline as CSV; the Lazy API and RAM comparisons further down read the full dataset from a Parquet copy (imdb.parquet).

    Example:

    In the hope of making the results easier to replicate, I've filtered the dataset with pl.col("userId").is_between(1, 3) and removed the timestamp column:

    movie_ratings = pl.read_csv(
        b'userId,movieId,rating\n1,1,4.0\n1,110,4.0\n1,158,4.0\n1,260,4.5\n1,356,5.0\n1,381,3.5\n1,596,4.0\n1,1036,5.0\n1,1049,'
        b'3.0\n1,1066,4.0\n1,1196,3.5\n1,1200,3.5\n1,1210,4.5\n1,1214,4.0\n1,1291,5.0\n1,1293,2.0\n1,1376,3.0\n1,1396,3.0\n1,153'
        b'7,4.0\n1,1909,3.0\n1,1959,4.0\n1,1960,4.0\n1,2028,5.0\n1,2085,3.5\n1,2116,4.0\n1,2336,3.5\n1,2571,2.5\n1,2671,4.0\n1,2'
        b'762,5.0\n1,2804,3.0\n1,2908,4.0\n1,3363,3.0\n1,3578,5.0\n1,4246,4.0\n1,4306,4.0\n1,4699,3.5\n1,4886,5.0\n1,4896,4.0\n1'
        b',4993,4.0\n1,4995,5.0\n1,5952,4.5\n1,6539,4.0\n1,7064,3.5\n1,7122,4.0\n1,7139,3.0\n1,7153,5.0\n1,7162,4.0\n1,7366,3.5'
        b'\n1,7706,3.5\n1,8132,5.0\n1,8533,5.0\n1,8644,3.5\n1,8961,4.5\n1,8969,4.0\n1,8981,3.5\n1,33166,5.0\n1,33794,3.0\n1,40629'
        b',4.5\n1,49647,5.0\n1,52458,5.0\n1,53996,5.0\n1,54259,4.0\n2,1,5.0\n2,2,3.0\n2,6,4.0\n2,10,3.0\n2,11,3.0\n2,17,5.0\n2,1'
        b'9,3.0\n2,21,5.0\n2,25,3.0\n2,31,3.0\n2,34,5.0\n2,36,5.0\n2,39,3.0\n2,47,5.0\n2,48,2.0\n2,50,4.0\n2,52,3.0\n2,58,3.0\n2'
        b',95,2.0\n2,110,5.0\n2,111,3.0\n2,141,5.0\n2,150,5.0\n2,151,5.0\n2,153,3.0\n2,158,3.0\n2,160,1.0\n2,161,3.0\n2,165,4.0'
        b'\n2,168,3.0\n2,172,2.0\n2,173,2.0\n2,185,3.0\n2,186,3.0\n2,204,3.0\n2,208,3.0\n2,224,3.0\n2,225,3.0\n2,231,4.0\n2,235,3'
        b'.0\n2,236,2.0\n2,252,3.0\n2,253,2.0\n2,256,3.0\n2,261,4.0\n2,265,2.0\n2,266,4.0\n2,282,1.0\n2,288,1.0\n2,292,3.0\n2,29'
        b'3,3.0\n2,296,5.0\n2,300,4.0\n2,315,3.0\n2,317,3.0\n2,318,5.0\n2,333,3.0\n2,337,3.0\n2,339,5.0\n2,344,3.0\n2,349,4.0\n2'
        b',350,3.0\n2,356,5.0\n2,357,5.0\n2,364,4.0\n2,367,4.0\n2,377,4.0\n2,380,4.0\n2,420,2.0\n2,432,3.0\n2,434,4.0\n2,440,3.0'
        b'\n2,442,3.0\n2,454,3.0\n2,457,5.0\n2,480,3.0\n2,500,4.0\n2,509,3.0\n2,527,5.0\n2,539,5.0\n2,553,3.0\n2,586,4.0\n2,587,'
        b'4.0\n2,588,4.0\n2,589,4.0\n2,590,5.0\n2,592,3.0\n2,593,5.0\n2,595,4.0\n2,597,5.0\n2,786,4.0\n3,296,5.0\n3,318,5.0\n3,8'
        b'58,5.0\n3,2959,5.0\n3,3114,5.0\n3,3751,5.0\n3,4886,5.0\n3,6377,5.0\n3,8961,5.0\n3,60069,5.0\n3,68954,5.0\n3,69844,5.0'
        b'\n3,74458,5.0\n3,76093,5.0\n3,79132,5.0\n3,81834,5.0\n3,88125,5.0\n3,99114,5.0\n3,109487,5.0\n3,112556,5.0\n3,115617,5.'
        b'0\n3,115713,4.0\n3,116797,5.0\n3,119145,5.0\n3,134853,5.0\n3,152081,5.0\n3,176101,5.0\n3,177765,5.0\n3,185029,5.0\n3,1'
        b'87593,3.0\n'
    )
    

    We will assume input_id == 1

    One possible approach for gathering all the needed information:

    # Finding the intersection first seems to use ~35% less RAM
    # than the previous join / anti-join approach
    intersection = (
        movie_ratings
        .filter(
            (pl.col("userId") == 1)
            | (
                (pl.col("userId") != 1)
                & (pl.col("movieId").is_in(pl.col("movieId").filter(pl.col("userId") == 1)))
            )
        )
    )
    
    (intersection.filter(pl.col("userId") == 1)
      .join(
         intersection.filter(pl.col("userId") != 1),
         on = "movieId"
      )
      .group_by(pl.col("userId_right").alias("other_user"))
      .agg(
         target_user = pl.first("userId"),
         common_movies = "movieId",
         common_movies_frequency = pl.count(),
         target_user_ratings = "rating",
         other_user_ratings = "rating_right",
      )
    )
    
    shape: (2, 6)
    ┌────────────┬─────────────┬────────────────────┬─────────────────────────┬──────────────────────┬──────────────────────┐
    │ other_user ┆ target_user ┆ common_movies      ┆ common_movies_frequency ┆ target_user_ratings  ┆ other_user_ratings   │
    │ ---        ┆ ---         ┆ ---                ┆ ---                     ┆ ---                  ┆ ---                  │
    │ i64        ┆ i64         ┆ list[i64]          ┆ u32                     ┆ list[f64]            ┆ list[f64]            │
    ╞════════════╪═════════════╪════════════════════╪═════════════════════════╪══════════════════════╪══════════════════════╡
    │ 3          ┆ 1           ┆ [4886, 8961]       ┆ 2                       ┆ [5.0, 4.5]           ┆ [5.0, 5.0]           │
    │ 2          ┆ 1           ┆ [1, 110, 158, 356] ┆ 4                       ┆ [4.0, 4.0, 4.0, 5.0] ┆ [5.0, 5.0, 3.0, 5.0] │
    └────────────┴─────────────┴────────────────────┴─────────────────────────┴──────────────────────┴──────────────────────┘
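
    Assuming the aggregation above is assigned to a variable, say result, a possible way to finish the similarity score with native expressions (a sketch of one option, not part of the original answer) is to explode the two aligned rating lists and aggregate the dot product and norms per user. On this sample the scores come out at roughly 0.999 for user 3 and 0.983 for user 2:

    similarity = (
        result
        .explode("target_user_ratings", "other_user_ratings")
        .group_by("other_user", maintain_order=True)
        .agg(
            dot=(pl.col("target_user_ratings") * pl.col("other_user_ratings")).sum(),
            norm_target=(pl.col("target_user_ratings") ** 2).sum().sqrt(),
            norm_other=(pl.col("other_user_ratings") ** 2).sum().sqrt(),
        )
        .with_columns(
            similarity_score=pl.col("dot") / (pl.col("norm_target") * pl.col("norm_other"))
        )
    )

    Note this relies on explode(), which the question reports as RAM-hungry on the full ~3m-row dataset, so it may only be practical on filtered subsets like this one.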
    

    Lazy API

    There may be a better strategy to parallelize the work, but a baseline attempt could simply loop through each userId:

    movie_ratings = pl.scan_parquet("imdb.parquet")
    
    user_ids = movie_ratings.select(pl.col("userId").unique()).collect().to_series()
    
    for user_id in user_ids:
        result = (
            movie_ratings
                .filter(pl.col("userId") == user_id)
                ...
        )
        print(result.collect())
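
    If the per-user LazyFrames are built up front, the collects can also be batched so Polars can parallelize across queries. A rough sketch continuing the loop above (the batch size is an arbitrary choice):

    queries = [
        movie_ratings.filter(pl.col("userId") == user_id)  # same per-user query as above
        for user_id in user_ids
    ]

    batch_size = 1_000  # arbitrary; tune to the available RAM
    for start in range(0, len(queries), batch_size):
        for result in pl.collect_all(queries[start : start + batch_size]):
            ...  # process each per-user result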
    

    DuckDB

    I was curious, so I decided to check DuckDB for a comparison.

    import duckdb
    
    duckdb.sql("""
    with 
       db as (from movie_ratings)
    from 
       db target, db other
    select 
       target.userId        target_user,
       other.userId         other_user,
       list(other.movieId)  common_movies,
       count(other.movieId) common_movies_frequency,
       list(target.rating)  target_user_ratings,
       list(other.rating)   other_user_ratings,
    where 
       target_user = 1 and other_user != 1 and target.movieId = other.movieId
    group by 
       target_user, other_user
    """).pl()
    
    shape: (2, 6)
    ┌─────────────┬────────────┬────────────────────┬─────────────────────────┬──────────────────────┬──────────────────────┐
    │ target_user ┆ other_user ┆ common_movies      ┆ common_movies_frequency ┆ target_user_ratings  ┆ other_user_ratings   │
    │ ---         ┆ ---        ┆ ---                ┆ ---                     ┆ ---                  ┆ ---                  │
    │ i64         ┆ i64        ┆ list[i64]          ┆ i64                     ┆ list[f64]            ┆ list[f64]            │
    ╞═════════════╪════════════╪════════════════════╪═════════════════════════╪══════════════════════╪══════════════════════╡
    │ 1           ┆ 3          ┆ [4886, 8961]       ┆ 2                       ┆ [5.0, 4.5]           ┆ [5.0, 5.0]           │
    │ 1           ┆ 2          ┆ [1, 110, 356, 158] ┆ 4                       ┆ [4.0, 4.0, 5.0, 4.0] ┆ [5.0, 5.0, 5.0, 3.0] │
    └─────────────┴────────────┴────────────────────┴─────────────────────────┴──────────────────────┴──────────────────────┘
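
    Note that duckdb.sql() picked up movie_ratings directly from the surrounding Python scope (DuckDB's replacement scans also work with Polars DataFrames), and .pl() converts the result back to Polars. A minimal round-trip illustration:

    import duckdb
    import polars as pl

    df = pl.DataFrame({"userId": [1, 2, 3], "rating": [4.0, 5.0, 3.0]})

    # `df` in the SQL text resolves to the Polars DataFrame defined above
    out = duckdb.sql("select userId, rating from df where rating >= 4").pl()
    print(out)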
    

    RAM Usage

    Running both examples against the full dataset (runtime is basically the same for both), I get:

    import rich.filesize
    
    print("duckdb:", rich.filesize.decimal(223232000))
    print("polars:", rich.filesize.decimal(1772072960))
    
    duckdb: 223.2 MB
    polars: 1.8 GB
    

    So it seems there is potential room for improvement on the Polars side.
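
    One avenue that might shrink the Polars footprint (untested, so treat it purely as a sketch rather than a benchmark) is to run essentially the same join + aggregation through the lazy API and request the streaming engine, which processes data in chunks where it can:

    lazy = pl.scan_parquet("imdb.parquet")

    result = (
        lazy.filter(pl.col("userId") == 1)
        .join(lazy.filter(pl.col("userId") != 1), on="movieId")
        .group_by(pl.col("userId_right").alias("other_user"))
        .agg(
            target_user=pl.first("userId"),
            common_movies="movieId",
            common_movies_frequency=pl.count(),
            target_user_ratings="rating",
            other_user_ratings="rating_right",
        )
        .collect(streaming=True)  # may fall back to the in-memory engine for unsupported steps
    )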