I am trying to implement a somewhat "batched" matrix multiplication in Rust using ndarray. To do this I am combining several .axis_iters and, in particular, updating the axis iterators of a "result" tensor accordingly. I already run into a problem when trying something as simple as mutating the "tensor slices" of a small array:
let mut c: ArrayBase<ndarray::OwnedRepr<i32>, Dim<[usize; 2]>> = array![
[1, 2],
[1, 2]
];
let d: ArrayBase<ndarray::OwnedRepr<i32>, Dim<[usize; 1]>> = array![1, 1];
c.axis_iter_mut(Axis(1)).for_each(|x| x = d);
The compiler complains at d's position in the last line:
mismatched types
expected struct `ndarray::ArrayBase<ViewRepr<&mut i32>, _>`
found struct `ndarray::ArrayBase<OwnedRepr<i32>, _>`
I am new to Rust and not sure what to do here. I can see that the types do not match, but I do not know how to set things up so that they do and c's columns get updated (replaced) as intended.
Dereferencing x with *x in the last line does not work either.
I also had a look at "Updating a row of a matrix in rust ndarray", but I could not figure out how to make it work with .axis_iter_mut.
Please note that I explicitly want these updates to happen via axis_iter, because that is what my actual goal, batched matmul, requires.
There is a method that does what you want, assign():
c.axis_iter_mut(Axis(1)).for_each(|mut x| x.assign(&d));