I'm trying to implement a vector class in Rust for my math library.
#[pyclass]
struct Vec2d {
    #[pyo3(get, set)]
    x: f64,
    #[pyo3(get, set)]
    y: f64,
}
But I can't figure out how to overload the standard operators (+, -, *, /).
I tried implementing the Add trait from std::ops, with no luck:
use std::ops::Add;

impl Add for Vec2d {
    type Output = Vec2d;

    fn add(self, other: Vec2d) -> Vec2d {
        Vec2d { x: self.x + other.x, y: self.y + other.y }
    }
}
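Note that a std::ops impl only changes what + means in Rust code; Python never sees Rust trait impls. For completeness, here is a minimal std-only sketch of the Rust side for all four operators, with the pyo3 attributes left out so it compiles without the crate. Treating * and / as scaling by an f64 is an assumption, since the question doesn't say what * should mean for two vectors; the derived Copy is also an addition so the by-value operators don't consume the caller's vectors.

```rust
use std::ops::{Add, Div, Mul, Sub};

// Plain-Rust Vec2d (pyo3 attributes omitted so this builds with std only).
// Copy lets `v1 + v2` take operands by value without moving v1 and v2 away.
#[derive(Debug, Clone, Copy, PartialEq)]
struct Vec2d {
    x: f64,
    y: f64,
}

// v + w: component-wise addition
impl Add for Vec2d {
    type Output = Vec2d;
    fn add(self, other: Vec2d) -> Vec2d {
        Vec2d { x: self.x + other.x, y: self.y + other.y }
    }
}

// v - w: component-wise subtraction
impl Sub for Vec2d {
    type Output = Vec2d;
    fn sub(self, other: Vec2d) -> Vec2d {
        Vec2d { x: self.x - other.x, y: self.y - other.y }
    }
}

// v * k: scale by a scalar (assumed meaning of `*`)
impl Mul<f64> for Vec2d {
    type Output = Vec2d;
    fn mul(self, k: f64) -> Vec2d {
        Vec2d { x: self.x * k, y: self.y * k }
    }
}

// v / k: divide by a scalar (assumed meaning of `/`)
impl Div<f64> for Vec2d {
    type Output = Vec2d;
    fn div(self, k: f64) -> Vec2d {
        Vec2d { x: self.x / k, y: self.y / k }
    }
}

fn main() {
    let v1 = Vec2d { x: 3.0, y: 4.0 };
    let v2 = Vec2d { x: 6.0, y: 7.0 };
    println!("{:?}", v1 + v2); // Vec2d { x: 9.0, y: 11.0 }
    println!("{:?}", v2 - v1); // Vec2d { x: 3.0, y: 3.0 }
    println!("{:?}", v1 * 2.0); // Vec2d { x: 6.0, y: 8.0 }
    println!("{:?}", v2 / 2.0); // Vec2d { x: 3.0, y: 3.5 }
}
```

With this in place, any Python-facing wrapper methods can simply delegate to the Rust operators (e.g. returning *self + *other), so the arithmetic lives in one place.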
I also tried adding an __add__ method to the #[pymethods] block:
fn __add__(&self, other: &Vec2d) -> PyResult<Vec2d> {
    Ok(Vec2d { x: self.x + other.x, y: self.y + other.y })
}
but it still doesn't work.
With the second approach I can see that the method is there, but Python doesn't recognize it as an operator overload:
In [2]: v1 = Vec2d(3, 4)
In [3]: v2 = Vec2d(6, 7)
In [4]: v1 + v2
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-4-08104d7e1232> in <module>()
----> 1 v1 + v2
TypeError: unsupported operand type(s) for +: 'Vec2d' and 'Vec2d'
In [5]: v1.__add__(v2)
Out[5]: <Vec2d object at 0x0000026B74C2B6F0>
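One way to read this session: for extension types, CPython evaluates a + b through the type's number slot (nb_add), not by looking up a method named __add__, so a method can be callable explicitly while staying invisible to the operator. The sketch below is a deliberately simplified toy model of that idea, not CPython or PyO3 code; Obj, TypeObject, BinaryFn, call_method and binary_add are hypothetical names, and it ignores the reflected __radd__ fallback and subclass rules.

```rust
use std::collections::HashMap;

// Hypothetical stand-in for a Python object holding a 2-D vector.
#[derive(Debug, Clone, Copy, PartialEq)]
struct Obj {
    x: f64,
    y: f64,
}

type BinaryFn = fn(&Obj, &Obj) -> Obj;

// Hypothetical stand-in for a type object.
struct TypeObject {
    name: &'static str,
    // Ordinary methods found by name: what `v1.__add__(v2)` uses.
    methods: HashMap<&'static str, BinaryFn>,
    // Number slot (modelled on CPython's nb_add): what `v1 + v2` uses.
    nb_add: Option<BinaryFn>,
}

fn vec_add(a: &Obj, b: &Obj) -> Obj {
    Obj { x: a.x + b.x, y: a.y + b.y }
}

// Explicit method call: look the name up in the method table.
fn call_method(ty: &TypeObject, name: &str, a: &Obj, b: &Obj) -> Result<Obj, String> {
    match ty.methods.get(name) {
        Some(f) => Ok(f(a, b)),
        None => Err(format!("'{}' object has no attribute '{}'", ty.name, name)),
    }
}

// The `+` operator in this model: only the slot is consulted
// (both operands share one type; no __radd__ fallback).
fn binary_add(ty: &TypeObject, a: &Obj, b: &Obj) -> Result<Obj, String> {
    match ty.nb_add {
        Some(f) => Ok(f(a, b)),
        None => Err(format!(
            "unsupported operand type(s) for +: '{}' and '{}'",
            ty.name, ty.name
        )),
    }
}

fn main() {
    let v1 = Obj { x: 3.0, y: 4.0 };
    let v2 = Obj { x: 6.0, y: 7.0 };

    // __add__ registered only as a named method; slot left empty.
    let mut methods: HashMap<&'static str, BinaryFn> = HashMap::new();
    methods.insert("__add__", vec_add);
    let method_only = TypeObject { name: "Vec2d", methods, nb_add: None };
    println!("{:?}", call_method(&method_only, "__add__", &v1, &v2)); // Ok
    println!("{:?}", binary_add(&method_only, &v1, &v2)); // Err (TypeError-like)

    // Same function also installed in the slot: now `+` works.
    let with_slot = TypeObject { nb_add: Some(vec_add as BinaryFn), ..method_only };
    println!("{:?}", binary_add(&with_slot, &v1, &v2)); // Ok
}
```

In this model the fix is not a different function body but getting the function registered in the slot, which is what a protocol-aware registration mechanism is for.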
As per the PyO3 documentation:
Python's object model defines several protocols for different object behavior, like sequence, mapping or number protocols. PyO3 defines separate traits for each of them. To provide specific python object behavior you need to implement the specific trait for your struct.
Important note, each protocol implementation block has to be annotated with #[pyproto] attribute.
__add__, __sub__, etc. are defined within the PyNumberProtocol trait.
So you can implement PyNumberProtocol for your Vec2d struct to overload the standard operators:
#[pyproto]
impl PyNumberProtocol for Vec2d {
    fn __add__(&self, other: &Vec2d) -> PyResult<Vec2d> {
        Ok(Vec2d { x: self.x + other.x, y: self.y + other.y })
    }
}
This solution is not tested; for a complete working solution, see @Neven V's answer.