Tags: sql, postgresql, search, sql-execution-plan, word-embedding

Efficient many-to-many embedding comparisons


I am trying to recommend the top "articles" to a user, given embeddings of the "interests" they have.

Each "user" will have 5-10 embeddings associated with their profile, represented as arrays of doubles.

Each "article" will also have 5-10 embeddings associated with it (each embedding represents a distinct topic).

I want to write a PostgreSQL query that returns the top 20 "articles" that are most aligned with a user's interests. Since each user has 5-10 embeddings representing their interests and each article has 5-10 embeddings representing the content it covers, I can't trivially apply an extension like pgvector to solve this.

I wrote an algorithm in SQL where I compute pairwise similarities between user embeddings and article embeddings, take the max along each row, and then average those values. It helps to imagine a UxT matrix (where U is the number of user embeddings and T is the number of article embeddings) and fill each entry with the cosine similarity between the corresponding user embedding and article embedding.
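To make the scoring rule concrete, here is a toy illustration with invented similarity values (two user embeddings against three article embeddings), ignoring the exp() scaling my function below applies before averaging:

-- Hypothetical pairwise similarities; the row maxes are 0.9 and 0.8,
-- so this article scores (0.9 + 0.8) / 2 = 0.85.
SELECT avg(row_max) AS article_score
FROM (
    SELECT max(sim) AS row_max              -- best match per user embedding
    FROM (VALUES (1, 0.9), (1, 0.1), (1, 0.3),
                 (2, 0.2), (2, 0.8), (2, 0.4)) AS s(user_row, sim)
    GROUP BY user_row
) per_row;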

I wrote helper functions: one to compute the product of an array's elements, another to compute cosine similarity, and a third, "vectors_similarity", which computes the similarity between a set of user vectors and a set of article vectors.

The query itself applies a few joins to get the required information, filters for articles published in the last five days, excludes articles the user has already "read", and returns the top 20 most similar articles using this methodology.

This takes over 30 seconds to search through 1000 articles. I am NOT a SQL expert and am struggling to debug this. Below, I have posted my SQL query and the results of EXPLAIN ANALYZE.

Is this just computationally intractable, or am I missing some obvious optimization opportunities?

CREATE OR REPLACE FUNCTION array_product(arr double precision[])
RETURNS double precision AS
$$
DECLARE
    result double precision := 1;
    i integer;
BEGIN
    FOR i IN array_lower(arr, 1) .. array_upper(arr, 1)
    LOOP
        result := result * arr[i];
    END LOOP;
    RETURN result;
END;
$$
LANGUAGE plpgsql;

CREATE OR REPLACE FUNCTION cosine_similarity(a double precision[], b double precision[])
RETURNS double precision AS $$
DECLARE
    dot_product double precision;
    norm_a double precision;
    norm_b double precision;
    a_length int;
    b_length int;
BEGIN
    a_length := array_length(a, 1);
    b_length := array_length(b, 1);

    dot_product := 0;
    norm_a := 0;
    norm_b := 0;
    
    FOR i IN 1..a_length LOOP
        dot_product := dot_product + a[i] * b[i];
        norm_a := norm_a + a[i] * a[i];
        norm_b := norm_b + b[i] * b[i];
    END LOOP;

    norm_a := sqrt(norm_a);
    norm_b := sqrt(norm_b);

    IF norm_a = 0 OR norm_b = 0 THEN
        RETURN 0;
    ELSE
        RETURN dot_product / (norm_a * norm_b);
    END IF;
END;
$$ LANGUAGE plpgsql;

CREATE OR REPLACE FUNCTION vectors_similarity(
    user_vectors FLOAT[][],
    article_vectors FLOAT[][]
) RETURNS FLOAT AS $$
DECLARE
    num_user_vectors INT;
    num_article_vectors INT;
    scores FLOAT[][];
    row_weights FLOAT[];
    row_values FLOAT[];
    col_weights FLOAT[];
    similarity FLOAT;
    article_vector FLOAT[][];
    user_vector FLOAT[][];
    i int;
    j int;
BEGIN
    num_user_vectors := array_length(user_vectors, 1);
    num_article_vectors := array_length(article_vectors, 1);

    scores := ARRAY(SELECT ARRAY(SELECT 0.0 FROM generate_series(1, num_article_vectors)) FROM generate_series(1, num_user_vectors));
    
    i := 1;
    FOREACH user_vector SLICE 1 IN ARRAY user_vectors
    LOOP
        j := 1;
        FOREACH article_vector SLICE 1 IN ARRAY article_vectors
        LOOP
            scores[i][j] := cosine_similarity(user_vector, article_vector);
            scores[i][j] := exp(scores[i][j] * 7);
        j := j+1;
        END LOOP;
    i := i + 1;
    END LOOP;
        
    SELECT 
      AVG(
        (SELECT MAX(row_val) FROM unnest(row_array) AS row_val)
      ) INTO similarity
    FROM 
      (
        SELECT scores[row_index][:] AS row_array
        FROM generate_series(1, array_length(scores, 1)) AS row_index
      ) AS subquery;
    
    RETURN similarity;
END;
$$ LANGUAGE plpgsql;

EXPLAIN ANALYZE
SELECT
        ART.*,
        vectors_similarity(array_agg(TOPIC.vector), ARRAY[ARRAY[ -0.0026961329858750105,0.004657252691686153, -0.011298391036689281, ...], ARRAY[...]]) AS similarity_score
    FROM
        article ART     
    JOIN
        article_topic ART_TOP ON ART.id = ART_TOP.article_id
    JOIN
        topic TOPIC ON ART_TOP.topic_id = TOPIC.id
    WHERE
        ART.date_published > CURRENT_DATE - INTERVAL '5' DAY
        AND NOT EXISTS (
        SELECT 1
        FROM user_article_read USR_ART_READ
        WHERE USR_ART_READ.article_id = ART.id
        AND USR_ART_READ.profile_id = 1 -- :user_id to be inputted by the user with the actual user_id
        )
    GROUP BY
        ART.id
    ORDER BY
        similarity_score DESC, ART.date_published DESC, ART.id DESC
    LIMIT 20;

And here is the analysis:

"Limit  (cost=945.53..945.55 rows=5 width=518) (actual time=27873.197..27873.227 rows=5 loops=1)"
"  Output: art.id, art.created_timestamp, art.modified_timestamp, art.title, art.url, art.date_published, art.summary, art.author, art.thumbnail_url, art.thumbnail_credit, art.source_name, art.feed_type, art.source_type, (scaled_geometric_similarity_vectors(array_agg(topic.vector), '{{-0.0026961329858750105,...}, {...}, ...}'::double precision[])"
"              Group Key: art.id"
"              Batches: 25  Memory Usage: 8524kB  Disk Usage: 3400kB"
"              Buffers: shared hit=14535 read=19, temp read=401 written=750"
"              ->  Hash Join  (cost=395.19..687.79 rows=4491 width=528) (actual time=6.746..20.875 rows=4638 loops=1)"
"                    Output: art.id, art.created_timestamp, art.modified_timestamp, art.title, art.url, art.date_published, art.summary, art.author, art.thumbnail_url, art.thumbnail_credit, art.source_name, art.feed_type, art.source_type, topic.vector"
"                    Inner Unique: true"
"                    Hash Cond: (art_top.topic_id = topic.id)"
"                    Buffers: shared hit=289"
"                    ->  Hash Anti Join  (cost=202.53..483.33 rows=4491 width=518) (actual time=3.190..15.589 rows=4638 loops=1)"
"                          Output: art.id, art.created_timestamp, art.modified_timestamp, art.title, art.url, art.date_published, art.summary, art.author, art.thumbnail_url, art.thumbnail_credit, art.source_name, art.feed_type, art.source_type, art_top.topic_id"
"                          Hash Cond: (art.id = usr_art_read.article_id)"
"                          Buffers: shared hit=229"
"                          ->  Hash Join  (cost=188.09..412.13 rows=4506 width=518) (actual time=3.106..14.853 rows=4638 loops=1)"
"                                Output: art.id, art.created_timestamp, art.modified_timestamp, art.title, art.url, art.date_published, art.summary, art.author, art.thumbnail_url, art.thumbnail_credit, art.source_name, art.feed_type, art.source_type, art_top.topic_id"
"                                Inner Unique: true"
"                                Hash Cond: (art_top.article_id = art.id)"
"                                Buffers: shared hit=224"
"                                ->  Seq Scan on public.article_topic art_top  (cost=0.00..194.67 rows=11167 width=16) (actual time=0.018..7.589 rows=11178 loops=1)"
"                                      Output: art_top.id, art_top.created_timestamp, art_top.article_id, art_top.topic_id"
"                                      Buffers: shared hit=83"
"                                ->  Hash  (cost=177.56..177.56 rows=843 width=510) (actual time=3.005..3.011 rows=818 loops=1)"
"                                      Output: art.id, art.created_timestamp, art.modified_timestamp, art.title, art.url, art.date_published, art.summary, art.author, art.thumbnail_url, art.thumbnail_credit, art.source_name, art.feed_type, art.source_type"
"                                      Buckets: 1024  Batches: 1  Memory Usage: 433kB"
"                                      Buffers: shared hit=141"
"                                      ->  Seq Scan on public.article art  (cost=0.00..177.56 rows=843 width=510) (actual time=0.082..1.585 rows=818 loops=1)"
"                                            Output: art.id, art.created_timestamp, art.modified_timestamp, art.title, art.url, art.date_published, art.summary, art.author, art.thumbnail_url, art.thumbnail_credit, art.source_name, art.feed_type, art.source_type"
"                                            Filter: (art.date_published > (CURRENT_DATE - '5 days'::interval day))"
"                                            Rows Removed by Filter: 1191"
"                                            Buffers: shared hit=141"
"                          ->  Hash  (cost=14.35..14.35 rows=7 width=8) (actual time=0.052..0.052 rows=0 loops=1)"
"                                Output: usr_art_read.article_id"
"                                Buckets: 1024  Batches: 1  Memory Usage: 8kB"
"                                Buffers: shared hit=5"
"                                ->  Bitmap Heap Scan on public.user_article_read usr_art_read  (cost=4.21..14.35 rows=7 width=8) (actual time=0.051..0.052 rows=0 loops=1)"
"                                      Output: usr_art_read.article_id"
"                                      Recheck Cond: (usr_art_read.profile_id = 1)"
"                                      Buffers: shared hit=5"
"                                      ->  Bitmap Index Scan on user_article_read_profile_id_d4edd4f6  (cost=0.00..4.21 rows=7 width=0) (actual time=0.050..0.050 rows=0 loops=1)"
"                                            Index Cond: (usr_art_read.profile_id = 1)"
"                                            Buffers: shared hit=5"
"                    ->  Hash  (cost=118.96..118.96 rows=5896 width=26) (actual time=3.436..3.440 rows=5918 loops=1)"
"                          Output: topic.vector, topic.id"
"                          Buckets: 8192  Batches: 1  Memory Usage: 434kB"
"                          Buffers: shared hit=60"
"                          ->  Seq Scan on public.topic  (cost=0.00..118.96 rows=5896 width=26) (actual time=0.009..2.100 rows=5918 loops=1)"
"                                Output: topic.vector, topic.id"
"                                Buffers: shared hit=60"
"Planning:"
"  Buffers: shared hit=406 read=7"
"Planning Time: 52.507 ms"
"Execution Time: 27875.522 ms"

Solution

  • Currently, almost all of the cost is accrued in the outer SELECT, and that is not due to sorting. It's the hugely expensive function vectors_similarity(), which calls the nested function cosine_similarity() many times; the nested function is just as inefficient as the outer one.

    (You also show the function array_product(), but that's unused in the query, so just a distraction. Also inefficient, btw.)
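    If you want to confirm where the time goes, you can enable function-level statistics and check pg_stat_user_functions afterwards (purely a diagnostic aid, not part of the fix; changing track_functions requires the appropriate SET privilege):

    SET track_functions = 'pl';          -- collect stats for PL/pgSQL functions
    -- ... run your query ...
    SELECT funcname, calls, total_time, self_time
    FROM   pg_stat_user_functions
    ORDER  BY total_time DESC;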

    This part in your query plan indicates you need more work_mem:

    Memory Usage: 8524kB  Disk Usage: 3400kB
    

    Indeed, your server seems to be running at default settings; otherwise an EXPLAIN (ANALYZE, VERBOSE, BUFFERS, SETTINGS) run would report the custom settings. Default settings won't do for a non-trivial workload.
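    For example, you could raise work_mem for the session that runs this query. (64MB is just a placeholder; the right value depends on available RAM and the number of concurrent sessions.)

    SET work_mem = '64MB';   -- example value only, tune to your workload
    SHOW work_mem;           -- verify the effective setting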

    I started out with "currently" because this is only going to get worse. You filter for the past 5 days of your data, but don't have an index on article.date_published. Currently, almost half of your 2000 articles qualify, but that ratio is bound to change dramatically. At that point you will need an index on article (date_published).
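    A minimal example (the index name is arbitrary):

    CREATE INDEX article_date_published_idx ON article (date_published);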

    Also, your LIMIT is combined with an ORDER BY on the computed similarity. So there is no way around computing the similarity score for all qualifying rows. A small limit barely helps.
    (Aside: your query plan reports rows=5 though there are more than enough candidate rows, which disagrees with LIMIT 20 in your query.)

    Your course of action should be:

    1.) Optimize both "similarity" functions, most importantly cosine_similarity(). (A sketch follows this list.)

    2.) Reduce the number of rows for which to do the expensive calculation, maybe by pre-filtering rows with a much cheaper filter.

    3.) Optimize server configuration.

    4.) Optimize indexes.
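
    As a sketch for step 1.): cosine_similarity() can be rewritten as a plain SQL function over unnest(a, b), which replaces the interpreted per-element PL/pgSQL loop with a single set-based expression and can be declared IMMUTABLE and PARALLEL SAFE. This is typically cheaper, but it is not a drop-in replacement: it returns NULL instead of 0 when either vector has zero norm, and it assumes both arrays have the same length.

    CREATE OR REPLACE FUNCTION cosine_similarity(a double precision[], b double precision[])
      RETURNS double precision
      LANGUAGE sql IMMUTABLE STRICT PARALLEL SAFE AS
    $func$
    SELECT sum(x * y) / nullif(sqrt(sum(x * x)) * sqrt(sum(y * y)), 0)
    FROM   unnest(a, b) AS t(x, y)  -- unnests both arrays in lockstep, element by element
    $func$;

    The same set-based approach could be taken further: vectors_similarity() could aggregate over unnest() directly, or the whole score could be computed in the query itself. Rewriting the innermost function is a reasonable first step, since it is called U x T times per article.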