How to override in Python a method that is called by another method in Rust?


I have a struct in Rust with two methods, return_modified_value and modify_value, where return_modified_value calls modify_value. The Rust implementation of modify_value doesn't matter, because I want to implement this method in a Python class that inherits from my struct. Can I make some sort of placeholder for modify_value, or reimplement it in Python?

I tried this in Rust:

use pyo3::prelude::*;

#[pyclass(subclass)]
pub struct ListElement {
    #[pyo3(get,set)]
    pub value_sum: f32
}

#[pymethods]
impl ListElement {
    #[new]
    fn new(value_sum: f32) -> Self {
        ListElement { value_sum }
    }
    
    fn return_modified_value(&mut self) -> f32 {
        self.modify_value();
        self.value_sum
    }

    fn modify_value(&mut self) {
        self.value_sum = self.value_sum.powf(2.0)
    }
}

#[pymodule]
fn my_rust_module(_py: Python, m: &PyModule) -> PyResult<()> {
    m.add_class::<ListElement>()?;
    Ok(())
}

It compiles with maturin develop, and then I use this code in Python:

from my_rust_module import ListElement

class PythonListElement(ListElement):
    def __new__(cls, value_sum):
        return super().__new__(cls, value_sum)

    def modify_value(self):
        self.value_sum = 3.0*self.value_sum

ple = PythonListElement(2)
print(ple.return_modified_value())

It prints 4.0, but I want it to print 6.0 (3*2). This is a minimal reproducible example; my real problem is more complex, but I think this code is enough. In case it matters, the equivalent of modify_value in my real problem is a method that takes a custom object as an argument and modifies it.


Solution

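  • The call self.modify_value() is resolved by Rust at compile time, not by Python, so it always calls the Rust implementation of modify_value; the override defined in the Python subclass is never consulted.

    If you want the call to go through Python, so that a subclass override is picked up, use PyO3's call_method interface: call_method0 for a method without arguments, or call_method1 with a tuple of arguments (for example, to pass your custom object). Taking &PyCell<Self> instead of &mut self also means no Rust borrow is held while the Python override runs, which matters because the override assigns self.value_sum through the generated setter, and that setter needs its own mutable borrow: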

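    fn return_modified_value(this: &PyCell<Self>) -> PyResult<f32> {
        // Look up "modify_value" on the Python object at run time,
        // so a subclass override (e.g. PythonListElement.modify_value) is used.
        this.call_method0("modify_value")?;
    
        // The Python call has returned, so it is now safe to borrow;
        // read access is all that is needed here.
        let this = this.borrow();
        Ok(this.value_sum)
    }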
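The difference between the two calls can be mimicked in pure Python. The sketch below is only an analogy, not PyO3 code: the class names Base and Sub and the methods static_call and dynamic_call are hypothetical. static_call pins the call to the base implementation, much like self.modify_value() inside the Rust impl, while dynamic_call looks the name up on the object at run time, which is what call_method0 does (PyO3 documents it as equivalent to the Python expression obj.name()).

```python
class Base:
    """Hypothetical pure-Python stand-in for the Rust ListElement."""

    def __init__(self, value_sum):
        self.value_sum = float(value_sum)

    def modify_value(self):
        self.value_sum = self.value_sum ** 2

    def static_call(self):
        # Like self.modify_value() in the Rust impl: the target is fixed
        # to Base.modify_value, so a subclass override is ignored.
        Base.modify_value(self)
        return self.value_sum

    def dynamic_call(self):
        # Like this.call_method0("modify_value"): the attribute is looked up
        # on the object at run time, so a subclass override wins.
        getattr(self, "modify_value")()
        return self.value_sum


class Sub(Base):
    """Hypothetical stand-in for PythonListElement."""

    def modify_value(self):
        self.value_sum = 3.0 * self.value_sum


print(Sub(2).static_call())   # 4.0
print(Sub(2).dynamic_call())  # 6.0
```

By the same reasoning, with the call_method0 version of return_modified_value the original Python script should print 6.0, while ListElement(2).return_modified_value() should still print 4.0, because for the base class the lookup finds the Rust modify_value.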