
How to override in Python a method with mutable arguments that is called by another method in Rust?


I have two structs: Struct and InnerStruct. Struct has two methods: modify_object, which calls modify_object_inner. The Rust implementation of modify_object_inner doesn't matter, because I want to implement this method in a Python class that inherits from my Struct. The function modify_object modifies the field of Struct whose type is InnerStruct. I wrote this code, and it compiles:

use pyo3::prelude::*;
use pyo3::types::PyDict;

#[pyclass(subclass)]
#[derive(Clone)]
pub struct InnerStruct {
    #[pyo3(get,set)]
    pub field: i32
}

#[pyclass(subclass)]
pub struct Struct {
    #[pyo3(get,set)]
    pub inner_struct: InnerStruct
}

#[pymethods]
impl InnerStruct {
    #[new]
    fn new(field: i32) -> Self {
        InnerStruct {field}
    }
}

// I had to implement this because of error "error[E0277]: the trait bound `&mut InnerStruct: ToPyObject` is not satisfied"
impl ToPyObject for &mut InnerStruct {
    fn to_object(&self, py: Python<'_>) -> PyObject {
        let dict = PyDict::new(py);
        dict.set_item("field", self.field).expect("Failed to set field in dictionary");
        dict.into()
    }
}

#[pymethods]
impl Struct {
    #[new]
    fn new(inner_struct: InnerStruct) -> Self {
        Struct { inner_struct}
    }
    
    fn modify_object(this: &PyCell<Self>) -> () {
        Python::with_gil(|py| {
            let inner_struct = &mut this.borrow_mut().inner_struct;
            let kwargs = PyDict::new(py);
            kwargs.set_item("object_to_modify", inner_struct).expect("Error with set_item");
            this.call_method("modify_object_inner", (), Some(kwargs)).expect("Error with call_method");
        });
    }
    fn modify_object_inner(&mut self, object_to_modify: &mut InnerStruct) {
        object_to_modify.field = -1
    }
    
}

#[pymodule]
fn my_rust_module(_py: Python, m: &PyModule) -> PyResult<()> {
    m.add_class::<Struct>()?;
    m.add_class::<InnerStruct>()?;
    Ok(())
}

But when I tested it with this Python code:

from my_rust_module import InnerStruct, Struct
import os

os.environ['RUST_BACKTRACE'] = '1'


class PythonStruct(Struct):
    def __new__(cls, inner_struct):
        return super().__new__(cls, inner_struct)


inner_struct = InnerStruct(0)
ps = PythonStruct(inner_struct)
ps.modify_object()
print(ps.inner_struct.field)  # without an override, should print -1


class PythonListElement(Struct):
    def __new__(cls, inner_struct):
        return super().__new__(cls, inner_struct)

    def modify_object_inner(self, inner_struct):
        inner_struct.field = 1


inner_struct = InnerStruct(0)
ps = PythonStruct(inner_struct)
ps.modify_object()
print(ps.inner_struct.field)  # with the override, should print 1

I got:

thread '<unnamed>' panicked at 'Error with call_method: PyErr { type: <class 'RuntimeError'>, value: RuntimeError('Already borrowed'), traceback: None }', src\lib.rs:46:71

If someone knows the answer, please also include the source of your knowledge (for example, a link to the relevant part of the documentation), because I'm quite lost and don't know how to find the answers myself.


Solution

  • PyO3 still needs to uphold Rust's borrowing rules: it cannot let you hold two mutable references to the same object at the same time in Rust. But in Python, all references are mutable, so how does it manage that?

    The answer is PyCell. It is somewhat like RefCell: a type that keeps track of the borrowing state dynamically, at runtime. When you call a Rust method that takes &mut self or &self, PyO3 tries to borrow the value in the PyCell (mutably or immutably, respectively) and raises an exception if it cannot, because the value is already borrowed.
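
Running PyO3 code needs a Python interpreter, so here is a rough analogy that uses std::cell::RefCell in place of PyCell (an assumption: both track borrows at runtime in a similar way, but this is not PyO3 code). It shows that shared borrows can coexist, while a mutable borrow is refused as long as any other borrow is alive:

```rust
use std::cell::RefCell;

// Analogy only: std's RefCell stands in for PyO3's PyCell here.
#[derive(Debug)]
struct InnerStruct {
    field: i32,
}

/// While one shared borrow is held, report whether a second shared borrow
/// and a mutable borrow can be taken: (second_shared_ok, mutable_ok).
fn borrow_states(cell: &RefCell<InnerStruct>) -> (bool, bool) {
    let _shared = cell.borrow(); // like PyO3 borrowing for `&self`
    let second_shared_ok = cell.try_borrow().is_ok();
    let mutable_ok = cell.try_borrow_mut().is_ok();
    (second_shared_ok, mutable_ok)
} // `_shared` is released here

fn main() {
    let cell = RefCell::new(InnerStruct { field: 0 });
    let (shared_ok, mut_ok) = borrow_states(&cell);
    println!("second shared borrow: {shared_ok}, mutable borrow: {mut_ok}");
    // No borrow is active any more, so a mutable borrow succeeds.
    cell.borrow_mut().field = -1;
    println!("{:?}", cell.borrow());
}
```

Where RefCell's plain borrow_mut() would panic on a conflict, PyO3 turns the failed borrow into a Python exception instead.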

    In modify_object(), you take this: &PyCell<Self> and call borrow_mut() on it. While this mutable borrow is active, the object cannot be borrowed again. When you call modify_object_inner(), PyO3 tries to borrow the same object mutably for &mut self and fails with RuntimeError('Already borrowed'). So the call to call_method() returns an error, and because you call expect() on it, the method panics.
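
The same sequence can be mimicked with RefCell. In this sketch, the hypothetical call_modify_object_inner() stands in for PyO3's method wrapper, which tries to borrow self mutably before running a &mut self method; the real wrapper does more than this, and the error string only imitates the one from the question:

```rust
use std::cell::RefCell;

// Analogy only: RefCell<Struct> stands in for PyCell<Struct>.
struct InnerStruct {
    field: i32,
}

struct Struct {
    inner_struct: InnerStruct,
}

// Hypothetical stand-in for PyO3's wrapper around a `&mut self` method:
// borrow the cell mutably, or report "Already borrowed".
fn call_modify_object_inner(this: &RefCell<Struct>) -> Result<(), String> {
    let mut slf = this
        .try_borrow_mut()
        .map_err(|_| "Already borrowed".to_string())?;
    slf.inner_struct.field = -1;
    Ok(())
}

// Mirrors the question's `modify_object`: the mutable guard is still
// alive while the method is called, so the inner borrow fails.
fn modify_object(this: &RefCell<Struct>) -> Result<(), String> {
    let mut guard = this.borrow_mut();
    let _inner = &mut guard.inner_struct;
    call_modify_object_inner(this)
} // `guard` is only dropped here, after the call

fn main() {
    let s = RefCell::new(Struct {
        inner_struct: InnerStruct { field: 0 },
    });
    println!("{:?}", modify_object(&s)); // Err("Already borrowed")
}
```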

    There are two possible solutions. The first is to release the mutable borrow before calling the method:

    // `mut` is needed to take `&mut guard.inner_struct` through the guard.
    let mut guard = this.borrow_mut();
    let inner_struct = &mut guard.inner_struct;
    let kwargs = PyDict::new(py);
    kwargs
        .set_item("object_to_modify", inner_struct)
        .expect("Error with set_item");
    // Release the mutable borrow before PyO3 has to borrow `self` again.
    drop(guard);
    this.call_method("modify_object_inner", (), Some(kwargs))
        .expect("Error with call_method");
    

    Here, when you call the Rust method, the object is no longer mutably borrowed.
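
Continuing the RefCell analogy (again a sketch rather than PyO3 code, with the hypothetical call_modify_object_inner() standing in for PyO3's wrapper), dropping the guard before the call is what lets the second borrow succeed:

```rust
use std::cell::RefCell;

// Analogy only: RefCell<Struct> stands in for PyCell<Struct>.
struct InnerStruct {
    field: i32,
}

struct Struct {
    inner_struct: InnerStruct,
}

// Hypothetical stand-in for PyO3's wrapper around a `&mut self` method.
fn call_modify_object_inner(this: &RefCell<Struct>) -> Result<(), String> {
    let mut slf = this
        .try_borrow_mut()
        .map_err(|_| "Already borrowed".to_string())?;
    slf.inner_struct.field = -1;
    Ok(())
}

// Mirrors the fixed `modify_object`: the guard is dropped before the call.
fn modify_object(this: &RefCell<Struct>) -> Result<i32, String> {
    let guard = this.borrow_mut();
    // Work that needs the borrow (in the answer: building `kwargs`).
    let old_value = guard.inner_struct.field;
    drop(guard); // the borrow ends here
    call_modify_object_inner(this)?;
    Ok(old_value)
}

fn main() {
    let s = RefCell::new(Struct {
        inner_struct: InnerStruct { field: 0 },
    });
    println!("{:?}", modify_object(&s)); // Ok(0)
    println!("{}", s.borrow().inner_struct.field); // -1
}
```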

    However, your code has an unrelated error: you implemented ToPyObject for &mut InnerStruct, but this code just creates a dict. The dict has no correlation to the original object, and mutating either won't influence the other.
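
To see why the copy cannot work, here is a Rust analogue in which a HashMap plays the role of the dict. The hypothetical to_dict() copies the value in, much like the ToPyObject impl in the question, so changing the map leaves the struct untouched:

```rust
use std::collections::HashMap;

struct InnerStruct {
    field: i32,
}

// Rough analogue of the question's `ToPyObject` impl: the field is copied
// into a brand-new map, so the map is a snapshot, not a view.
fn to_dict(inner: &InnerStruct) -> HashMap<&'static str, i32> {
    let mut dict = HashMap::new();
    dict.insert("field", inner.field);
    dict
}

fn main() {
    let inner = InnerStruct { field: 0 };
    let mut dict = to_dict(&inner);
    // Mutate the copy only.
    dict.insert("field", 1);
    println!("dict: {}, struct: {}", dict["field"], inner.field); // dict: 1, struct: 0
}
```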

    What you need is to store a Py<InnerStruct> in your struct instead of a plain InnerStruct. Py<T> implements ToPyObject (and so does a reference to it), so you can pass it to Python and observe the changes.
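
By contrast, a shared handle sees every change. In this sketch, Rc<RefCell<InnerStruct>> is only an analogy for Py<InnerStruct> (both are reference-counted handles to a single object), and the hypothetical modify_object_inner() function stands in for the Python override:

```rust
use std::cell::RefCell;
use std::rc::Rc;

struct InnerStruct {
    field: i32,
}

// Analogy only: the Rc plays the role of Py<InnerStruct>, a counted
// handle to one shared object rather than a copy of it.
struct Struct {
    inner_struct: Rc<RefCell<InnerStruct>>,
}

// Stand-in for the Python override: it receives another handle.
fn modify_object_inner(object_to_modify: Rc<RefCell<InnerStruct>>) {
    object_to_modify.borrow_mut().field = 1;
}

fn main() {
    let s = Struct {
        inner_struct: Rc::new(RefCell::new(InnerStruct { field: 0 })),
    };
    // Pass a second handle to the same object, not a copy of its contents.
    modify_object_inner(Rc::clone(&s.inner_struct));
    println!("{}", s.inner_struct.borrow().field); // 1
}
```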

    You also no longer need borrow_mut(), only borrow(). If you change modify_object_inner() to take &self as well (or even &PyCell<Self>, which doesn't borrow the object at all), you can keep its Rust implementation, because two shared borrows are allowed to coexist. The &mut InnerStruct argument borrows a different object, the InnerStruct, so it doesn't conflict with either of them.
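
The borrow pattern this relies on can be sketched with RefCell as well (an analogy with hypothetical helper names, not PyO3 code): the caller keeps a shared borrow of Struct, the method takes a second shared borrow for &self, and the only mutable borrow goes to the separate InnerStruct cell, so nothing conflicts:

```rust
use std::cell::RefCell;
use std::rc::Rc;

struct InnerStruct {
    field: i32,
}

struct Struct {
    inner_struct: Rc<RefCell<InnerStruct>>,
}

// Hypothetical stand-in for PyO3 calling a `&self` method whose argument is
// `&mut InnerStruct`: a shared borrow of Struct, a mutable borrow of the inner cell.
fn call_modify_object_inner(
    this: &RefCell<Struct>,
    object_to_modify: &RefCell<InnerStruct>,
) -> Result<(), String> {
    let _slf = this
        .try_borrow()
        .map_err(|_| "Already borrowed (Struct)".to_string())?;
    let mut inner = object_to_modify
        .try_borrow_mut()
        .map_err(|_| "Already borrowed (InnerStruct)".to_string())?;
    inner.field = -1;
    Ok(())
}

// Only a shared borrow of Struct is held across the call, as with `this.borrow()`.
fn modify_object(this: &RefCell<Struct>) -> Result<(), String> {
    let guard = this.borrow();
    let inner_struct = Rc::clone(&guard.inner_struct);
    call_modify_object_inner(this, &inner_struct)
}

fn main() {
    let s = RefCell::new(Struct {
        inner_struct: Rc::new(RefCell::new(InnerStruct { field: 0 })),
    });
    println!("{:?}", modify_object(&s)); // Ok(())
    println!("{}", s.borrow().inner_struct.borrow().field); // -1
}
```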

    Two more points. First, you don't need Python::with_gil(): you can (and should) take py: Python<'_> as a parameter instead, and PyO3 will automatically do everything required.

    Second, using expect() and panicking is a bad idea; you should return PyResult and propagate the error. If you panic, Python sees a special PanicException that shouldn't be caught, but if you propagate the error, you get a normal exception. Panicking on the dictionary set_item() may be fine (it can only fail if the key is not hashable, and a string key is), but you don't want to panic if, for example, someone removed the modify_object_inner() method in a subclass; you want a normal exception that can be caught.
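
The difference between panicking and propagating can be shown in plain Rust. In this sketch, a HashMap of function pointers is a hypothetical stand-in for Python's method lookup, and Result<(), String> stands in for PyResult<()>; the ? operator hands the error back to the caller instead of aborting with expect():

```rust
use std::collections::HashMap;

// Hypothetical method table standing in for attribute lookup on `self`.
type Method = fn(&mut i32);

fn call_method(
    methods: &HashMap<&str, Method>,
    name: &str,
    arg: &mut i32,
) -> Result<(), String> {
    let method = methods
        .get(name)
        .copied()
        .ok_or_else(|| format!("no method named {name:?}"))?;
    method(arg);
    Ok(())
}

// Like `modify_object(...) -> PyResult<()>`: `?` propagates the error.
fn modify_object(methods: &HashMap<&str, Method>, field: &mut i32) -> Result<(), String> {
    call_method(methods, "modify_object_inner", field)?;
    Ok(())
}

fn set_to_one(x: &mut i32) {
    *x = 1;
}

fn main() {
    let mut methods: HashMap<&str, Method> = HashMap::new();
    methods.insert("modify_object_inner", set_to_one);
    let mut field = 0;
    let result = modify_object(&methods, &mut field);
    println!("{result:?}, field = {field}"); // Ok(()), field = 1

    // A "subclass" without the method: the caller gets an Err it can handle.
    let empty: HashMap<&str, Method> = HashMap::new();
    match modify_object(&empty, &mut field) {
        Ok(()) => println!("unexpected success"),
        Err(e) => println!("recovered from error: {e}"),
    }
}
```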

    Here's the proposed Rust code:

    use pyo3::prelude::*;
    use pyo3::types::PyDict;
    
    #[pyclass(subclass)]
    pub struct InnerStruct {
        #[pyo3(get, set)]
        pub field: i32,
    }
    
    #[pyclass(subclass)]
    pub struct Struct {
        #[pyo3(get, set)]
        pub inner_struct: Py<InnerStruct>,
    }
    
    #[pymethods]
    impl InnerStruct {
        #[new]
        fn new(field: i32) -> Self {
            InnerStruct { field }
        }
    }
    
    #[pymethods]
    impl Struct {
        #[new]
        fn new(inner_struct: Py<InnerStruct>) -> Self {
            Struct { inner_struct }
        }
    
        fn modify_object(this: &PyCell<Self>, py: Python<'_>) -> PyResult<()> {
            let inner_struct = &this.borrow().inner_struct;
            let kwargs = PyDict::new(py);
            kwargs.set_item("object_to_modify", inner_struct)?;
            this.call_method("modify_object_inner", (), Some(kwargs))?;
            Ok(())
        }
        fn modify_object_inner(&self, object_to_modify: &mut InnerStruct) {
            object_to_modify.field = -1;
        }
    }
    
    #[pymodule]
    fn my_rust_module(_py: Python, m: &PyModule) -> PyResult<()> {
        m.add_class::<Struct>()?;
        m.add_class::<InnerStruct>()?;
        Ok(())
    }
    

    And the Python code:

    from my_rust_module import InnerStruct, Struct
    import os
    
    os.environ['RUST_BACKTRACE'] = '1'
    
    
    class PythonStruct(Struct):
        def __new__(cls, inner_struct):
            return super().__new__(cls, inner_struct)
    
    
    inner_struct = InnerStruct(0)
    ps = PythonStruct(inner_struct)
    ps.modify_object()
    print(ps.inner_struct.field)  # Prints -1
    
    
    class PythonListElement(Struct):
        def __new__(cls, inner_struct):
            return super().__new__(cls, inner_struct)
    
        # Notice the name! It should match the name of the keyword argument you're passing!
        def modify_object_inner(self, object_to_modify):
            object_to_modify.field = 1
    
    
    inner_struct = InnerStruct(0)
    # And here there was a typo, it should be `PythonListElement`, not `PythonStruct`!
    ps = PythonListElement(inner_struct)
    ps.modify_object()
    print(ps.inner_struct.field)  # Prints 1