I am starting to look into pyo3, and as a test I am trying to wrap a Rust library using it; however, I have been running into performance issues when lambda functions are passed as arguments.
Assume I have a Rust library with a function that takes a callback as an argument. The function evaluates the callback a number of times and then returns a result, e.g.:
pub fn test_function<F: Fn(f64) -> f64>(cb: F) -> f64 {
    // The real implementation is not this trivial; the point is just
    // that the callback gets executed a large number of times.
    (0..225_000).map(|i| cb(i as f64)).sum::<f64>()
}
I measure how long this function takes by executing it 1000 times with a simple callback and taking the average time.
use std::time::SystemTime;

use pyo3test::test_function;

pub fn main() {
    let mut sum = 0.0f64;
    let reps = 1_000;
    let start = SystemTime::now();
    for _ in 0..reps {
        sum += test_function(|x| x);
    }
    let end = start.elapsed().unwrap();
    println!("Result: {sum}");
    println!("Duration: {:?}", end.checked_div(reps).unwrap());
}
On my machine, when run in release mode, this takes roughly 250 microseconds per execution of test_function.
I then tried to wrap this function using pyo3 in the following way:
use pyo3::{PyAny, pyfunction, pymodule, PyResult, Python, wrap_pyfunction};
use pyo3::prelude::PyModule;

use crate::test_function;

#[pymodule]
fn pyo3test(_py: Python, m: &PyModule) -> PyResult<()> {
    m.add_function(wrap_pyfunction!(py_test_function, m)?)?;
    Ok(())
}

#[pyfunction]
pub fn py_test_function(function: &PyAny) -> f64 {
    assert!(function.is_callable());
    let cb = move |x| function.call1((x,)).unwrap().extract::<f64>().unwrap();
    test_function(cb)
}
I compile everything in release mode, import the module in Python, and then measure again how long it takes, by executing the function 1000 times and taking the average time:
import pyo3test
from time import time
reps = 1000
cb = lambda x: x
summation = 0
start = time()
for _ in range(reps):
    summation += pyo3test.py_test_function(cb)
end = time()
duration = end-start
avg = duration/reps
print(avg)
In this case, the average execution time is roughly 20 milliseconds, almost two orders of magnitude more than the pure-Rust case. I didn't expect to match the pure-Rust time because of the GIL, but I would have guessed something closer to a millisecond.
Is this actually expected or am I missing something? Is it possible to improve on this, possibly without changing the pure-rust implementation?
I looked at the documentation, and while it suggests that extract is slow, I don't think I can do anything differently here, since downcast cannot be used in this case.
Is there anything else that can be done?
UPDATE
Trying to follow @Ahmed AEK's advice, I wrote another rust function:
pub fn test_function_alt(values: &[f64]) -> f64 {
    values.iter().sum::<f64>()
}
I wrapped it using pyo3:
#[pyfunction]
pub fn py_test_function_alt(values: Vec<f64>) -> f64 {
    test_function_alt(&values)
}
I then wrote the following Python function:
import numpy as np

def foobar(cb):
    vals = cb(np.arange(225000))
    return pyo3test.py_test_function_alt(vals)
This function still executes in roughly 20ms.
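One possible explanation, stated as an assumption rather than something I have profiled: extracting a Vec<f64> from a generic Python sequence (including a numpy array passed through the sequence protocol) still converts the values one element at a time, so the call still pays 225,000 per-element conversions even though the callback itself now runs only once. The shape of that cost can be approximated in pure Python:

```python
# Approximate the cost of converting a 225_000-element sequence one
# value at a time, which is roughly what a generic Vec<f64> extraction
# from a Python sequence has to do.
from timeit import timeit

values = list(range(225_000))

def convert():
    # One Python-level conversion per element.
    return [float(v) for v in values]

per_run_ms = timeit(convert, number=10) / 10 * 1e3
print(f"{per_run_ms:.1f} ms per conversion")
```

If that is indeed the bottleneck, the usual zero-copy route is the numpy crate for pyo3, which lets the wrapper accept a PyReadonlyArray1<f64> and borrow the array's buffer directly; that is a suggestion about a possible fix, not something measured here.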
The GIL is not your problem here; the Python interpreter is simply slow. It takes those 20 milliseconds just to invoke the callback in pure Python, without any FFI involved.
Redesign your API so it does not make 225_000 calls into Python. Instead, take a Python list or a numpy array, which are specifically designed for passing bulk data to a native API.
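The interpreter-overhead claim is easy to sanity-check: a minimal sketch timing the same 225,000 calls of the same trivial lambda in pure Python, with no FFI at all (the repetition count of 10 is an arbitrary choice):

```python
# Time 225_000 calls of a trivial lambda in pure Python, mirroring
# what one call of py_test_function asks the interpreter to do.
from timeit import timeit

cb = lambda x: x

def run():
    total = 0.0
    for i in range(225_000):
        total += cb(float(i))
    return total

per_run_ms = timeit(run, number=10) / 10 * 1e3
print(f"{per_run_ms:.1f} ms per run")
```

On a typical machine this lands in the same millisecond-to-tens-of-milliseconds ballpark as the pyo3 call, which supports the point that the per-call interpreter overhead, not the FFI boundary, is the dominant cost.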