I'd like to use Persistent/Esqueleto to implement count estimates.

One approach, recommended in this article, is to define a function like this:
```sql
CREATE FUNCTION count_estimate(query text) RETURNS integer AS $$
DECLARE
  rec  record;
  rows integer;
BEGIN
  FOR rec IN EXECUTE 'EXPLAIN ' || query LOOP
    rows := substring(rec."QUERY PLAN" FROM ' rows=([[:digit:]]+)');
    EXIT WHEN rows IS NOT NULL;
  END LOOP;
  RETURN rows;
END;
$$ LANGUAGE plpgsql VOLATILE STRICT;
```
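For context, `count_estimate` runs `EXPLAIN` on the query and returns the `rows=` figure from the first plan line that has one, i.e. the planner's estimate rather than an exact count. The same parsing logic, mirrored client-side in plain Haskell purely as an illustration (the names `rowsInLine` and `estimateFromPlan` are hypothetical, and the sketch takes already-fetched plan lines as input rather than querying the database):

```haskell
import Data.Char (isDigit)
import Data.List (isPrefixOf, tails)
import Data.Maybe (listToMaybe, mapMaybe)

-- The number following the first " rows=" in one EXPLAIN output line, if any.
-- Mirrors substring(... FROM ' rows=([[:digit:]]+)') in the PL/pgSQL version.
rowsInLine :: String -> Maybe Int
rowsInLine line =
  listToMaybe
    [ read digits
    | rest <- tails line
    , " rows=" `isPrefixOf` rest
    , let digits = takeWhile isDigit (drop (length " rows=") rest)
    , not (null digits)
    ]

-- Like the FOR ... LOOP in count_estimate: scan the plan lines top to bottom
-- and stop at the first line that carries a row estimate.
estimateFromPlan :: [String] -> Maybe Int
estimateFromPlan = listToMaybe . mapMaybe rowsInLine
```

Because the top plan node comes first in `EXPLAIN` output, the first match is the estimate for the whole query.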
And then use it like this:

```sql
SELECT count_estimate('SELECT * FROM companies WHERE status = ''Active''');
```
In order to use the `count_estimate` function, I think I'll need to render the query that Persistent/Esqueleto generates. However, when I try rendering the query with `renderQuerySelect`, I get something like this:
```
SELECT "companies"."id", "companies"."name", "companies"."status"
FROM "companies"
WHERE "companies"."status" IN (?)
; [PersistText "Active"]
```
This of course can't be passed to `count_estimate` as-is, because it produces a syntax error on the `?` placeholder. I also can't naïvely replace the `?` with `"Active"`, because that produces a syntax error at the first double quote.

How do I render the query in a way that my `count_estimate` function will accept?
I tried something like this, but it fails at runtime:

```haskell
getEstimate :: (Text, [PersistValue]) -> DB [Single Int]
getEstimate (query, params) = rawSql [st|
  SELECT count_estimate('#{query}');
|] params
```
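Presumably the attempt above fails because the rendered query is spliced into a `'...'` literal without escaping, so any quote that ends up inside it (for example around a substituted parameter) closes the outer literal early. A minimal string-level sketch of the two levels of quoting involved, assuming only text parameters and `standard_conforming_strings = on`; the helpers `quoteLiteral` and `inlineParams` are hypothetical and not part of Persistent or Esqueleto:

```haskell
-- Render a value as a PostgreSQL string literal by doubling embedded single
-- quotes (assumes standard_conforming_strings = on, so backslashes are literal).
quoteLiteral :: String -> String
quoteLiteral s = "'" ++ concatMap esc s ++ "'"
  where
    esc '\'' = "''"
    esc c = [c]

-- Replace each ? placeholder, left to right, with the next quoted parameter.
-- Naive: it does not skip ? characters inside identifiers or literals.
inlineParams :: String -> [String] -> String
inlineParams ('?' : rest) (p : ps) = quoteLiteral p ++ inlineParams rest ps
inlineParams (c : rest) ps = c : inlineParams rest ps
inlineParams [] _ = []
```

Quoting happens twice: `inlineParams` quotes each parameter for the inner query, and `quoteLiteral (inlineParams q ps)` then quotes the whole inner query as the argument of `count_estimate`, which doubles the inner quotes again (`'Active'` becomes `''Active''`, matching the hand-written example earlier).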
I managed to figure it out (mostly).

It's a matter of escaping the single quotes in both the query and the `PersistValue` parameters. I'm doing it like this at the moment, but because the parameter values are inserted as `Unescaped` literals, escaping of the values themselves still needs to be added back in; otherwise I think this creates a SQL injection vulnerability. I may also need to handle the other `PersistValue` constructors in some specific way, but I haven't run into problems there yet.
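```haskell
{-# LANGUAGE LambdaCase, OverloadedStrings, QuasiQuotes #-}
import Data.Maybe (listToMaybe)
import Data.Text (Text, pack)
import qualified Data.Text as T
import Data.Text.Encoding (encodeUtf8)
import Data.Time.Calendar (showGregorian)
import qualified Database.Persist as P
import Database.Persist.Sql (PersistValue (..), Single (..), rawSql)
import Text.Shakespeare.Text (st)

-- DB is the application's own type synonym for SQL actions (e.g. SqlPersistT m).
getEstimate :: (Text, [PersistValue]) -> DB (Maybe Int)
getEstimate (query, params) = fmap unSingle . listToMaybe <$> rawSql [st|
  SELECT count_estimate('#{T.replace "'" "''" query}');
|] (map replace' params)
  where
    -- Wrap in '' ... '' for the outer '...'; quotes inside the value are NOT escaped yet.
    literal a = PersistLiteral_ P.Unescaped ("''" <> a <> "''")
    replace' = \case
      PersistText t -> literal $ encodeUtf8 t
      PersistDay d -> literal $ encodeUtf8 $ pack $ showGregorian d
      a -> a
```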
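On the SQL injection concern: with `P.Unescaped`, a text value such as `O'Brien` would end the inner literal early. One way to close that gap, sketched here under the assumption that `standard_conforming_strings` is on (the PostgreSQL default since 9.1), is to double each embedded quote twice, once per level of nesting. `nestedLiteral` is a hypothetical helper working on `String` for brevity:

```haskell
-- Hypothetical helper: render a text value so it can sit inside the inner
-- query, which is itself inside count_estimate('...').
-- Each embedded ' is doubled for the inner literal and doubled again for the
-- outer one, so ' becomes '''' while the delimiters become ''.
nestedLiteral :: String -> String
nestedLiteral s = "''" ++ concatMap esc s ++ "''"
  where
    esc '\'' = "''''"
    esc c = [c]
```

In `getEstimate`, the `PersistText` case could then become something like `PersistText t -> PersistLiteral_ P.Unescaped (encodeUtf8 (T.pack (nestedLiteral (T.unpack t))))`. An alternative design, not tested here, is to inline the parameters into the inner query and pass that string as the single ordinary parameter of `SELECT count_estimate(?)`, so the driver escapes the outer level and only one level has to be handled by hand.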