I am trying to recommend the top "articles" to a user, given embeddings of the "interests" they have.
Each "user" will have 5-10 embeddings associated with their profile, represented as arrays of doubles.
Each "article" will also have 5-10 embeddings associated with it (each embedding represents a distinct topic).
I want to write a PostgreSQL query that returns the top 20 "articles" most aligned with a user's interests. Since each user has 5-10 embeddings representing their interests and each article has 5-10 embeddings representing the content it covers, I can't trivially apply an extension like pgvector to solve this.
I wrote an algorithm in SQL that computes pairwise similarities between user embeddings and article embeddings, then takes the max along each row and averages those values. It helps to imagine a UxT matrix (where U is the number of user embeddings and T is the number of article embeddings) whose entries are the cosine similarities between the corresponding user embedding and article embedding.
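For example (made-up numbers, leaving out the exponential scaling I apply in the code below), with 2 user embeddings and 3 article embeddings:

            art_1   art_2   art_3   | row max
  user_1     0.10    0.80    0.30   |  0.80
  user_2     0.60    0.20    0.50   |  0.60

  similarity = avg(0.80, 0.60) = 0.70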
I wrote helper functions: one to compute the product of an array's elements, another to compute cosine similarity, and a third, "vectors_similarity", which computes the similarity between a set of user vectors and a set of article vectors.
The query itself applies a few joins to get the required information, filters to articles published in the last five days that the user has not already "read", and returns the top 20 most similar articles using this methodology.
This takes over 30 seconds to search through ~1000 articles. I am NOT a SQL expert and am struggling to debug this. Below I have posted my SQL query and the results of EXPLAIN ANALYZE.
Is this just computationally intractable, or am I missing some obvious optimization opportunities?
CREATE OR REPLACE FUNCTION array_product(arr double precision[])
RETURNS double precision AS
$$
DECLARE
    result double precision := 1;
    i integer;
BEGIN
    FOR i IN array_lower(arr, 1) .. array_upper(arr, 1)
    LOOP
        result := result * arr[i];
    END LOOP;
    RETURN result;
END;
$$
LANGUAGE plpgsql;
CREATE OR REPLACE FUNCTION cosine_similarity(a double precision[], b double precision[])
RETURNS double precision AS $$
DECLARE
    dot_product double precision;
    norm_a double precision;
    norm_b double precision;
    a_length int;
    b_length int;
BEGIN
    a_length := array_length(a, 1);
    b_length := array_length(b, 1);
    dot_product := 0;
    norm_a := 0;
    norm_b := 0;
    FOR i IN 1..a_length LOOP
        dot_product := dot_product + a[i] * b[i];
        norm_a := norm_a + a[i] * a[i];
        norm_b := norm_b + b[i] * b[i];
    END LOOP;
    norm_a := sqrt(norm_a);
    norm_b := sqrt(norm_b);
    IF norm_a = 0 OR norm_b = 0 THEN
        RETURN 0;
    ELSE
        RETURN dot_product / (norm_a * norm_b);
    END IF;
END;
$$ LANGUAGE plpgsql;
CREATE OR REPLACE FUNCTION vectors_similarity(
    user_vectors FLOAT[][],
    article_vectors FLOAT[][]
) RETURNS FLOAT AS $$
DECLARE
    num_user_vectors INT;
    num_article_vectors INT;
    scores FLOAT[][];
    row_weights FLOAT[];
    row_values FLOAT[];
    col_weights FLOAT[];
    similarity FLOAT;
    article_vector FLOAT[][];
    user_vector FLOAT[][];
    i int;
    j int;
BEGIN
    num_user_vectors := array_length(user_vectors, 1);
    num_article_vectors := array_length(article_vectors, 1);
    -- initialize the U x T score matrix with zeros
    scores := ARRAY(SELECT ARRAY(SELECT 0.0 FROM generate_series(1, num_article_vectors)) FROM generate_series(1, num_user_vectors));
    i := 1;
    FOREACH user_vector SLICE 1 IN ARRAY user_vectors
    LOOP
        j := 1;
        FOREACH article_vector SLICE 1 IN ARRAY article_vectors
        LOOP
            -- pairwise cosine similarity, exponentially scaled to emphasize strong matches
            scores[i][j] := cosine_similarity(user_vector, article_vector);
            scores[i][j] := exp(scores[i][j] * 7);
            j := j + 1;
        END LOOP;
        i := i + 1;
    END LOOP;
    -- average the per-row maxima
    -- (scores[row_index:row_index][:] selects just that row; a bare scores[row_index][:] would slice rows 1..row_index)
    SELECT AVG(
        (SELECT MAX(row_val) FROM unnest(row_array) AS row_val)
    ) INTO similarity
    FROM (
        SELECT scores[row_index:row_index][:] AS row_array
        FROM generate_series(1, array_length(scores, 1)) AS row_index
    ) AS subquery;
    RETURN similarity;
END;
$$ LANGUAGE plpgsql;
EXPLAIN ANALYZE
SELECT
    ART.*,
    vectors_similarity(array_agg(TOPIC.vector), ARRAY[ARRAY[ -0.0026961329858750105, 0.004657252691686153, -0.011298391036689281, ...], ARRAY[...]]) AS similarity_score
FROM
    article ART
JOIN
    article_topic ART_TOP ON ART.id = ART_TOP.article_id
JOIN
    topic TOPIC ON ART_TOP.topic_id = TOPIC.id
WHERE
    ART.date_published > CURRENT_DATE - INTERVAL '5' DAY
    AND NOT EXISTS (
        SELECT 1
        FROM user_article_read USR_ART_READ
        WHERE USR_ART_READ.article_id = ART.id
          AND USR_ART_READ.profile_id = 1 -- placeholder for :user_id, supplied per user
    )
GROUP BY
    ART.id
ORDER BY
    similarity_score DESC, ART.date_published DESC, ART.id DESC
LIMIT 20;
And here is the output of EXPLAIN ANALYZE:
"Limit (cost=945.53..945.55 rows=5 width=518) (actual time=27873.197..27873.227 rows=5 loops=1)"
" Output: art.id, art.created_timestamp, art.modified_timestamp, art.title, art.url, art.date_published, art.summary, art.author, art.thumbnail_url, art.thumbnail_credit, art.source_name, art.feed_type, art.source_type, (scaled_geometric_similarity_vectors(array_agg(topic.vector), '{{-0.0026961329858750105,...}, {...}, ...}'::double precision[])"
" Group Key: art.id"
" Batches: 25 Memory Usage: 8524kB Disk Usage: 3400kB"
" Buffers: shared hit=14535 read=19, temp read=401 written=750"
" -> Hash Join (cost=395.19..687.79 rows=4491 width=528) (actual time=6.746..20.875 rows=4638 loops=1)"
" Output: art.id, art.created_timestamp, art.modified_timestamp, art.title, art.url, art.date_published, art.summary, art.author, art.thumbnail_url, art.thumbnail_credit, art.source_name, art.feed_type, art.source_type, topic.vector"
" Inner Unique: true"
" Hash Cond: (art_top.topic_id = topic.id)"
" Buffers: shared hit=289"
" -> Hash Anti Join (cost=202.53..483.33 rows=4491 width=518) (actual time=3.190..15.589 rows=4638 loops=1)"
" Output: art.id, art.created_timestamp, art.modified_timestamp, art.title, art.url, art.date_published, art.summary, art.author, art.thumbnail_url, art.thumbnail_credit, art.source_name, art.feed_type, art.source_type, art_top.topic_id"
" Hash Cond: (art.id = usr_art_read.article_id)"
" Buffers: shared hit=229"
" -> Hash Join (cost=188.09..412.13 rows=4506 width=518) (actual time=3.106..14.853 rows=4638 loops=1)"
" Output: art.id, art.created_timestamp, art.modified_timestamp, art.title, art.url, art.date_published, art.summary, art.author, art.thumbnail_url, art.thumbnail_credit, art.source_name, art.feed_type, art.source_type, art_top.topic_id"
" Inner Unique: true"
" Hash Cond: (art_top.article_id = art.id)"
" Buffers: shared hit=224"
" -> Seq Scan on public.article_topic art_top (cost=0.00..194.67 rows=11167 width=16) (actual time=0.018..7.589 rows=11178 loops=1)"
" Output: art_top.id, art_top.created_timestamp, art_top.article_id, art_top.topic_id"
" Buffers: shared hit=83"
" -> Hash (cost=177.56..177.56 rows=843 width=510) (actual time=3.005..3.011 rows=818 loops=1)"
" Output: art.id, art.created_timestamp, art.modified_timestamp, art.title, art.url, art.date_published, art.summary, art.author, art.thumbnail_url, art.thumbnail_credit, art.source_name, art.feed_type, art.source_type"
" Buckets: 1024 Batches: 1 Memory Usage: 433kB"
" Buffers: shared hit=141"
" -> Seq Scan on public.article art (cost=0.00..177.56 rows=843 width=510) (actual time=0.082..1.585 rows=818 loops=1)"
" Output: art.id, art.created_timestamp, art.modified_timestamp, art.title, art.url, art.date_published, art.summary, art.author, art.thumbnail_url, art.thumbnail_credit, art.source_name, art.feed_type, art.source_type"
" Filter: (art.date_published > (CURRENT_DATE - '5 days'::interval day))"
" Rows Removed by Filter: 1191"
" Buffers: shared hit=141"
" -> Hash (cost=14.35..14.35 rows=7 width=8) (actual time=0.052..0.052 rows=0 loops=1)"
" Output: usr_art_read.article_id"
" Buckets: 1024 Batches: 1 Memory Usage: 8kB"
" Buffers: shared hit=5"
" -> Bitmap Heap Scan on public.user_article_read usr_art_read (cost=4.21..14.35 rows=7 width=8) (actual time=0.051..0.052 rows=0 loops=1)"
" Output: usr_art_read.article_id"
" Recheck Cond: (usr_art_read.profile_id = 1)"
" Buffers: shared hit=5"
" -> Bitmap Index Scan on user_article_read_profile_id_d4edd4f6 (cost=0.00..4.21 rows=7 width=0) (actual time=0.050..0.050 rows=0 loops=1)"
" Index Cond: (usr_art_read.profile_id = 1)"
" Buffers: shared hit=5"
" -> Hash (cost=118.96..118.96 rows=5896 width=26) (actual time=3.436..3.440 rows=5918 loops=1)"
" Output: topic.vector, topic.id"
" Buckets: 8192 Batches: 1 Memory Usage: 434kB"
" Buffers: shared hit=60"
" -> Seq Scan on public.topic (cost=0.00..118.96 rows=5896 width=26) (actual time=0.009..2.100 rows=5918 loops=1)"
" Output: topic.vector, topic.id"
" Buffers: shared hit=60"
"Planning:"
" Buffers: shared hit=406 read=7"
"Planning Time: 52.507 ms"
"Execution Time: 27875.522 ms"
Currently, almost all of the cost is accrued in the outer SELECT, and that is not due to sorting. It's the hugely expensive function vectors_similarity(), which calls the nested function cosine_similarity() many times, and that nested function is just as inefficient as the outer one.
(You also show the function array_product(), but that's unused in the query, so it's just a distraction. Also inefficient, btw.)
This part of your query plan indicates you need more work_mem:

Memory Usage: 8524kB Disk Usage: 3400kB
Indeed, your server seems to be at default settings, or else EXPLAIN (ANALYZE, VERBOSE, BUFFERS, SETTINGS) (like you claim to have used) would report custom settings. That won't do for a non-trivial workload.
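For example, a session-level or persistent setting might look like this (64MB is only a placeholder; size it to your RAM and expected concurrency):

-- per session
SET work_mem = '64MB';
-- or persistently, then reload the configuration
ALTER SYSTEM SET work_mem = '64MB';
SELECT pg_reload_conf();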
I started out with "currently", because this is only getting worse. You filter on the past 5 days of your data, but don't have an index on article.date_published. Currently, almost half of your ~2000 articles qualify, but that ratio is bound to change dramatically. Then you need an index on article (date_published).
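A minimal sketch (the index name is arbitrary):

CREATE INDEX article_date_published_idx ON article (date_published);

Once only a small fraction of rows qualifies, the planner can use it instead of the sequential scan on article.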
Also, your LIMIT is combined with an ORDER BY on the computed similarity, so there is no way around computing the similarity score for all qualifying rows. A small limit barely helps.
(Aside: your query plan reports rows=5 though there are more than enough candidate rows, which disagrees with the LIMIT 20 in your query.)
Your course of action should be:
1.) Optimize both "similarity" functions, most importantly cosine_similarity(). (See the sketch after this list.)
2.) Reduce the number of rows for which to do the expensive calculation, maybe by pre-filtering rows with a much cheaper filter. (See the second sketch below.)
3.) Optimize server configuration.
4.) Optimize indexes.
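For 1.), a plain SQL version of cosine_similarity() avoids the per-element PL/pgSQL overhead and can be declared IMMUTABLE and PARALLEL SAFE. Here is a minimal sketch with the same signature; treat it as a starting point and verify it against your current function:

CREATE OR REPLACE FUNCTION cosine_similarity(a double precision[], b double precision[])
  RETURNS double precision
  LANGUAGE sql IMMUTABLE PARALLEL SAFE AS
$func$
SELECT CASE WHEN norm_a = 0 OR norm_b = 0 THEN 0
            ELSE dot / sqrt(norm_a * norm_b) END
FROM  (
   SELECT sum(x * y) AS dot        -- dot product
        , sum(x * x) AS norm_a     -- squared norm of a
        , sum(y * y) AS norm_b     -- squared norm of b
   FROM   unnest(a, b) AS t(x, y)  -- walks both arrays in parallel
   ) sub;
$func$;

If installing pgvector is an option, its <=> operator computes cosine distance in C (similarity = 1 - distance), which should be faster still, even if you keep your max-then-average aggregation on top.

For 2.), one possible cheap pre-filter, assuming (hypothetically) a precomputed column article.mean_vector holding the element-wise average of each article's topic embeddings, with the user's averaged interest vector passed in as a parameter:

SELECT id
FROM   article
WHERE  date_published > CURRENT_DATE - INTERVAL '5' DAY
ORDER  BY cosine_similarity(mean_vector, :user_mean_vector) DESC  -- :user_mean_vector supplied by the app
LIMIT  200;  -- generous candidate pool

Only these candidates would then get the expensive vectors_similarity() call in an outer query.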