
Rust ndarray: Mutating axis_iters


I am trying to implement a kind of "batched" matrix multiplication in Rust using ndarray. To do that, I am trying to combine several .axis_iter()s and, in particular, to update the axis iterators of a "result" tensor accordingly. The problem already occurs when I try something as simple as mutating the "tensor slices" of a small array like this:

let mut c: ArrayBase<ndarray::OwnedRepr<i32>, Dim<[usize; 2]>> = array![
    [1, 2],
    [1, 2]
];
let d: ArrayBase<ndarray::OwnedRepr<i32>, Dim<[usize; 1]>> = array![1, 1];

c.axis_iter_mut(Axis(1)).for_each(|x| x = d);

The compiler complains about d in the last line:

mismatched types
expected struct `ndarray::ArrayBase<ViewRepr<&mut i32>, _>`
   found struct `ndarray::ArrayBase<OwnedRepr<i32>, _>`

I am new to Rust and not sure what to do here. I can see that the types do not match, but I do not know how to set this up so that they do and the columns of c are updated/replaced as intended.

Dereferencing x with *x in the last line also does not work.

I also had a look at "Updating a row of a matrix in rust ndarray", but I could not figure out how to make that approach work with .axis_iter_mut().

Please note that I explicitly want these updates to happen via axis_iter_mut(), because that is what I need for my actual goal of batched matmul.


Solution

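  • There is a method that does what you want: assign(). The closure parameter x is an ArrayViewMut1<i32>, a mutable view into one column of c, not the column itself. Writing x = d would only rebind that local variable (and the types differ), and *x does not work either because a view is a struct, not a reference to an array. assign() instead copies the elements of d through the view into c. Because it takes &mut self, the closure parameter has to be bound as mut x: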

    c.axis_iter_mut(Axis(1)).for_each(|mut x| x.assign(&d));