rust pyo3

How to use a trait as a parameter of a pymethod in PyO3


I am a beginner with PyO3, and I want to know how to use Box<dyn MyTrait> as a parameter of a function exposed through PyO3. I have already read Appendix C: Trait bounds in the official guide, but my situation is a bit different: I do not want to call a Python class from within Rust. Can someone help? Much appreciated.

Here is a concrete example. I have defined a trait SDF:

pub trait SDF: Send + Sync {
    fn distance(&self, p: Vec3f) -> f32;
}

Now I want to implement a Union that can be constructed from Python:

use pyo3::prelude::*;

#[pyclass]
pub struct Union {
    a: Box<dyn SDF>,
    b: Box<dyn SDF>,
}

#[pymethods]
impl Union {
    #[new]
    pub fn new(a: Box<dyn SDF>, b: Box<dyn SDF>) -> Union {
        // ...
    }
}

// Another struct that implements SDF, for reference
pub struct Sphere {
    center: Vec3f,
    radius: f32,
}

impl SDF for Sphere {
    fn distance(&self, p: Vec3f) -> f32 {
        (p - self.center).norm() - self.radius
    }
}

My question is, how do I implement the new method?


Solution

  • You cannot expose dyn Trait directly to Python: PyO3 has no way to convert an arbitrary Python object into a Box<dyn SDF>. You will have to create a wrapper.

    Basically, create a wrapper #[pyclass] pub struct DynSDF(Box<dyn SDF>); which is a normal pyclass you can use from Python. (SDF already requires Send + Sync, so no extra + Send bound is needed.)

    Give it one constructor per struct that implements the trait. For Sphere, that is either a constructor taking a Sphere directly (if Sphere is itself a #[pyclass]; in that case this constructor can also be DynSDF's #[new]), or a sphere() method taking the parameters needed to build a Sphere.

    If you also need to call the trait methods from Python, add a #[pymethods] block to DynSDF whose methods forward to the trait methods.

    Then make Union::new() take DynSDF arguments (for example as PyRef<DynSDF> borrows) instead of Box<dyn SDF>. Python can pass them because DynSDF is a Python class.