Tags: rust, rayon, rust-ndarray

Rust: returning an iterator with a trait bound on its items


I am trying to write a common interface for different types of matrices that provides a way to mutably iterate their rows and modify them. I have the following matrix types:

/// Matrix held entirely in RAM as a 2-D `ndarray` array of `f32`.
struct NdArrayMatrix {
    matrix: Array2<f32>,
}

/// Matrix backed by a borrowed byte buffer (memory-mapped in the real code;
/// the mapping details are omitted here).
struct ByteMatrix<'a> {
    data: &'a mut [u8],
    // Matrix dimensions; how `data` is laid out is not shown in this excerpt.
    rows: usize,
    cols: usize,
}

The first one is just a RAM-stored matrix, and the second is memory-mapped using the MMap library; for convenience, I omit those details. First, I made a trait so that both of them can be modified through the same interface:

/// Common element-access interface shared by the matrix types.
trait ReadWrite
{
    /// Reads the element at position `(i, j)`.
    fn rw_read(&self, i: usize, j: usize) -> f32;
    /// Writes `val` to the element at position `(i, j)`.
    fn rw_write(&mut self, i: usize, j: usize, val: f32);
}

Then I created a trait that produces a `rayon::iter::IndexedParallelIterator` from both of these:

/// Types that can hand out their rows as a rayon indexed parallel iterator.
///
/// `'a` is the duration of the mutable borrow of `self`, so the returned
/// iterator may borrow from the matrix.
trait Sliceable<'a>
{
    // Bounded only by `IndexedParallelIterator`: nothing here says that the
    // items implement `ReadWrite`.
    type Output: IndexedParallelIterator;

    fn rows_par_iter(&'a mut self ) -> Self::Output;
}

Up to this point, everything works fine. But when I want to use these in a generic context, such as:

// Sketch of the intended generic use. This is pseudocode: `...` marks omitted
// arguments/calls, and `fn<'a, T>` / `T.rows_par_iter()` are not valid Rust
// as written (generics follow the function name, and the method is called on
// the value `matrix`, not on the type `T`).
fn<'a, T> slice_and_write(matrix: T)
where T: Sliceable<'a>
{
    T.rows_par_iter()
     .map(|mut row| {
          row.rw_write(...);
     })
     ...
}

I run into problems. Obviously, `row` doesn't implement `ReadWrite` in this case, so no surprise there. So I tried to create an iterator trait based on `IndexedParallelIterator`:

// First attempt: this declares a *new* associated type `Item` that is not
// tied to `ParallelIterator::Item`, so the `ReadWrite` bound does not apply
// to the items the iterator actually yields.
trait RwIterator: IndexedParallelIterator {
    type Item: ReadWrite;
}

and modify Sliceable:

/// Second attempt: `Output` is now bounded by the custom `RwIterator` trait.
trait Sliceable<'a>
{
    // `RwIterator::Item: ReadWrite` alone does not constrain the
    // `ParallelIterator::Item` that `map` passes to its closure.
    type Output: RwIterator;

    fn rows_par_iter(&'a mut self ) -> Self::Output;
}

Running this I get the error:

   |  row.rw_write(...);
   |      ^^^^^^^^ method not found in `<<T as Sliceable<'a>>::Output as ParallelIterator>::Item`

Which is, again, fairly obvious. I suspect that the `map` function only requires the trait bound `ParallelIterator` and hence can't take advantage of the trait `RwIterator`.

My question is: Is there any way around this problem, or an alternate way for doing this?

EDIT: Here is a minimal reproducible code example, only using one of the matrix structures.

use ndarray::Array2;
use rayon::prelude::*;
use ndarray::Axis;
use ndarray::parallel::Parallel;
use ndarray::Dim;
use ndarray::iter::AxisIterMut;
use rayon::iter::ParallelIterator;
use ndarray::ViewRepr;
use ndarray::ArrayBase;

/// Matrix held in RAM as a 2-D `ndarray` array of `f32`.
struct NdArrayMatrix {
    matrix: Array2<f32>,
}

impl NdArrayMatrix {
    /// Creates a 10×10 matrix with every element set to zero.
    pub fn new() -> Self {
        Self {
            matrix: Array2::zeros((10, 10)),
        }
    }
}

/// Element access by `(i, j)`, implemented both for whole matrices and for
/// single row views.
trait ReadWrite
{
    /// Returns the element at `(i, j)`.
    fn rw_read(&self, i: usize, j: usize) -> f32;
    /// Stores `val` at `(i, j)`.
    fn rw_write(&mut self, i: usize, j: usize, val: f32);
}

/// Element access on the owned matrix, addressed as `[row, column]`.
impl ReadWrite for NdArrayMatrix {
    /// Returns a copy of the element at row `i`, column `j`.
    fn rw_read(&self, i: usize, j: usize) -> f32 {
        let idx = [i, j];
        self.matrix[idx]
    }

    /// Overwrites the element at row `i`, column `j` with `val`.
    fn rw_write(&mut self, i: usize, j: usize, val: f32) {
        let slot = &mut self.matrix[[i, j]];
        *slot = val;
    }
}

/// `ReadWrite` for one mutable 1-D row view (the item type of the parallel
/// row iterator). A view holds a single row, so the row index is ignored
/// (named `_i` to avoid an unused-variable warning) and only the column
/// index `j` is used.
impl ReadWrite for ArrayBase<ViewRepr<&mut f32>, Dim<[usize; 1]>> {
    fn rw_read(&self, _i: usize, j: usize) -> f32 {
        self[j]
    }

    fn rw_write(&mut self, _i: usize, j: usize, val: f32) {
        self[j] = val;
    }
}

// As posted, `RwIterator::Item` is a *separate* associated type from
// `ParallelIterator::Item`. Nothing says the two are the same type, so the
// `ReadWrite` bound does not reach the items the iterator actually yields.
trait RwIterator: IndexedParallelIterator {
    type Item: ReadWrite;
}

// Items of this parallel axis iterator are mutable 1-D `f32` array views.
impl<'a> RwIterator for Parallel<AxisIterMut<'a, f32, Dim<[usize; 1]>>> {
    type Item =  ArrayBase<ViewRepr<&'a mut f32>, Dim<[usize; 1]>>;
}

/// Types that can expose their rows as a parallel iterator, borrowing `self`
/// mutably for `'a`.
trait Sliceable<'a>
{
    type Output: RwIterator;

    fn rows_par_iter(&'a mut self ) -> Self::Output;
}

impl<'a> Sliceable<'a> for NdArrayMatrix {
    type Output = Parallel<AxisIterMut<'a, f32, Dim<[usize; 1]>>>;

    /// Splits the matrix along axis 0 and hands each row out as a mutable
    /// view on a rayon parallel iterator.
    fn rows_par_iter(&'a mut self) -> Self::Output {
        let rows = self.matrix.axis_iter_mut(Axis(0));
        rows.into_par_iter()
    }
}

/// Entry point: creates a matrix with `NdArrayMatrix::new` and hands it to `test`.
fn main() {
    // `matrix` is only moved into `test`, never mutated here, so no `mut`
    // (avoids the `unused_mut` warning).
    let matrix: NdArrayMatrix = NdArrayMatrix::new();

    test(matrix);
}

// NOTE: as posted, this does not compile. `row.rw_write` fails with "method
// not found" because nothing ties `RwIterator::Item` to the iterator's
// `ParallelIterator::Item`. In addition, `matrix` would have to be declared
// `mut` to call `rows_par_iter(&'a mut self)`. Because the caller chooses
// `'a`, the local `matrix` also cannot be borrowed for it.
fn test<'a, T> (matrix: T)
where T: Sliceable<'a> + ReadWrite
{
    matrix.rows_par_iter()
        .map(|mut row| {
            row.rw_write(0, 0, 0.0);
        }).count();
}

Solution

  • Your code is 90% there.

    The problem you are facing is that `RwIterator::Item` is bounded by `ReadWrite`, but nowhere does your code constrain `RwIterator::Item` to be the same type as the `ParallelIterator::Item` of the same object.

    To fix this, you can annotate it manually:

    // The supertrait bound `Item = <Self as RwIterator>::Item` ties this
    // trait's `Item` (which must be `ReadWrite`) to the iterator's
    // `ParallelIterator::Item`, so callers know the yielded items are `ReadWrite`.
    trait RwIterator: IndexedParallelIterator<Item = <Self as RwIterator>::Item> {
        type Item: ReadWrite;
    }
    
    // Items of this parallel axis iterator are mutable 1-D `f32` array views.
    impl<'a> RwIterator for Parallel<AxisIterMut<'a, f32, Dim<[usize; 1]>>> {
        type Item = ArrayBase<ViewRepr<&'a mut f32>, Dim<[usize; 1]>>;
    }
    

    With that, Rust now understands the connection.

    Some other minor adjustments were necessary as well. Here is a version that compiles:

    use ndarray::iter::AxisIterMut;
    use ndarray::parallel::Parallel;
    use ndarray::Array2;
    use ndarray::ArrayBase;
    use ndarray::Axis;
    use ndarray::Dim;
    use ndarray::ViewRepr;
    use rayon::iter::ParallelIterator;
    use rayon::prelude::*;
    
    /// Matrix held in RAM as a 2-D `ndarray` array of `f32`.
    struct NdArrayMatrix {
        matrix: Array2<f32>,
    }
    
    impl NdArrayMatrix {
        /// Builds a zero-filled 10×10 matrix.
        pub fn new() -> Self {
            Self {
                matrix: Array2::zeros((10, 10)),
            }
        }
    }
    
    /// Element access by `(i, j)`, implemented both for whole matrices and
    /// for single row views.
    trait ReadWrite {
        /// Returns the element at `(i, j)`.
        fn rw_read(&self, i: usize, j: usize) -> f32;
        /// Stores `val` at `(i, j)`.
        fn rw_write(&mut self, i: usize, j: usize, val: f32);
    }
    
    /// Element access on the owned matrix, addressed as `[row, column]`.
    impl ReadWrite for NdArrayMatrix {
        /// Returns a copy of the element at row `i`, column `j`.
        fn rw_read(&self, i: usize, j: usize) -> f32 {
            let idx = [i, j];
            self.matrix[idx]
        }

        /// Overwrites the element at row `i`, column `j` with `val`.
        fn rw_write(&mut self, i: usize, j: usize, val: f32) {
            let slot = &mut self.matrix[[i, j]];
            *slot = val;
        }
    }
    
    /// `ReadWrite` for one mutable 1-D row view (the item type of the parallel
    /// row iterator). A view holds a single row, so the row index is ignored
    /// (named `_i` to avoid an unused-variable warning) and only the column
    /// index `j` is used.
    impl ReadWrite for ArrayBase<ViewRepr<&mut f32>, Dim<[usize; 1]>> {
        fn rw_read(&self, _i: usize, j: usize) -> f32 {
            self[j]
        }

        fn rw_write(&mut self, _i: usize, j: usize, val: f32) {
            self[j] = val;
        }
    }
    
    // The supertrait bound `Item = <Self as RwIterator>::Item` ties this
    // trait's `Item` (which must be `ReadWrite`) to the iterator's
    // `ParallelIterator::Item`, so callers know the yielded items are `ReadWrite`.
    trait RwIterator: IndexedParallelIterator<Item = <Self as RwIterator>::Item> {
        type Item: ReadWrite;
    }
    
    // Items of this parallel axis iterator are mutable 1-D `f32` array views.
    impl<'a> RwIterator for Parallel<AxisIterMut<'a, f32, Dim<[usize; 1]>>> {
        type Item = ArrayBase<ViewRepr<&'a mut f32>, Dim<[usize; 1]>>;
    }
    
    /// Types that can expose their rows as a parallel iterator whose items
    /// implement `ReadWrite` (via the `RwIterator` bound on `Output`).
    trait Sliceable<'a> {
        type Output: RwIterator;
    
        /// Returns a parallel iterator over the rows, borrowing `self` mutably
        /// for `'a`.
        fn rows_par_iter(&'a mut self) -> Self::Output;
    }
    
    impl<'a> Sliceable<'a> for NdArrayMatrix {
        type Output = Parallel<AxisIterMut<'a, f32, Dim<[usize; 1]>>>;

        /// Splits the matrix along axis 0 and hands each row out as a mutable
        /// view on a rayon parallel iterator.
        fn rows_par_iter(&'a mut self) -> Self::Output {
            let rows = self.matrix.axis_iter_mut(Axis(0));
            rows.into_par_iter()
        }
    }
    
    fn main() {
        let matrix: NdArrayMatrix = NdArrayMatrix::new();
    
        test(matrix);
    }
    
    /// Calls `rw_write(0, 0, 0.0)` on every row yielded by `rows_par_iter`,
    /// processing the rows in parallel.
    ///
    /// The higher-ranked bound `for<'a> T: Sliceable<'a>` lets the function
    /// borrow its own local `matrix` for a lifetime chosen inside the body.
    fn test<T>(mut matrix: T)
    where
        for<'a> T: Sliceable<'a>,
    {
        // `for_each` expresses "run this for its side effect on every item".
        // `map(..).count()` only promises to count, so it is the wrong tool
        // for mutation.
        matrix
            .rows_par_iter()
            .for_each(|mut row| row.rw_write(0, 0, 0.0));
    }
    

    A Little Excursion

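    All of this was only necessary (at the time of writing) because a bound could not be placed directly on the associated type inside another bound, e.g. `Item: ReadWrite` inside `ParallelIterator<...>`.

    This might change if RFC 2289 (associated type bounds) gets stabilized at some point. Update: that feature has since been stabilized in Rust 1.79, so on recent compilers the version below may work without the helper trait.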

    Then, you might be able to completely delete the RwIterator trait and specify it like this instead:

    // Associated type bound (RFC 2289 syntax): requires every item of
    // `Output` to implement `ReadWrite`, with no helper trait needed.
    type Output: ParallelIterator<Item: ReadWrite>;
    

    No guarantees on that one, though. I didn't get it to work with the nightly compiler yet.

    You can already roughly emulate that behaviour with some boilerplate code:

    use ndarray::iter::AxisIterMut;
    use ndarray::parallel::Parallel;
    use ndarray::Array2;
    use ndarray::ArrayBase;
    use ndarray::Axis;
    use ndarray::Dim;
    use ndarray::ViewRepr;
    use rayon::iter::ParallelIterator;
    use rayon::prelude::*;
    
    /// Matrix held in RAM as a 2-D `ndarray` array of `f32`.
    struct NdArrayMatrix {
        matrix: Array2<f32>,
    }
    
    impl NdArrayMatrix {
        /// Builds a zero-filled 10×10 matrix.
        pub fn new() -> Self {
            Self {
                matrix: Array2::zeros((10, 10)),
            }
        }
    }
    
    /// Element access by `(i, j)`, implemented both for whole matrices and
    /// for single row views.
    trait ReadWrite {
        /// Returns the element at `(i, j)`.
        fn rw_read(&self, i: usize, j: usize) -> f32;
        /// Stores `val` at `(i, j)`.
        fn rw_write(&mut self, i: usize, j: usize, val: f32);
    }
    
    /// Element access on the owned matrix, addressed as `[row, column]`.
    impl ReadWrite for NdArrayMatrix {
        /// Returns a copy of the element at row `i`, column `j`.
        fn rw_read(&self, i: usize, j: usize) -> f32 {
            let idx = [i, j];
            self.matrix[idx]
        }

        /// Overwrites the element at row `i`, column `j` with `val`.
        fn rw_write(&mut self, i: usize, j: usize, val: f32) {
            let slot = &mut self.matrix[[i, j]];
            *slot = val;
        }
    }
    
    /// `ReadWrite` for one mutable 1-D row view (the `Item` type of the row
    /// iterator). A view holds a single row, so the row index is ignored
    /// (named `_i` to avoid an unused-variable warning) and only the column
    /// index `j` is used.
    impl ReadWrite for ArrayBase<ViewRepr<&mut f32>, Dim<[usize; 1]>> {
        fn rw_read(&self, _i: usize, j: usize) -> f32 {
            self[j]
        }

        fn rw_write(&mut self, _i: usize, j: usize, val: f32) {
            self[j] = val;
        }
    }
    
    /// Types that can expose their rows as a parallel iterator of `ReadWrite`
    /// items.
    trait Sliceable<'a> {
        // Naming the item type separately lets the `ReadWrite` bound sit on it;
        // `Item = Self::Item` below then ties it to the iterator's items.
        type Item: ReadWrite;
        type Output: ParallelIterator<Item = Self::Item>;
    
        fn rows_par_iter(&'a mut self) -> Self::Output;
    }
    
    impl<'a> Sliceable<'a> for NdArrayMatrix {
        // `Item` has to be written out by hand: it must equal the
        // `ParallelIterator::Item` of `Output`, i.e. one mutable row view.
        type Item = ArrayBase<ViewRepr<&'a mut f32>, Dim<[usize; 1]>>;
        type Output = Parallel<AxisIterMut<'a, f32, Dim<[usize; 1]>>>;

        /// Splits the matrix along axis 0 and hands each row out as a mutable
        /// view on a rayon parallel iterator.
        fn rows_par_iter(&'a mut self) -> Self::Output {
            let rows = self.matrix.axis_iter_mut(Axis(0));
            rows.into_par_iter()
        }
    }
    
    fn main() {
        let matrix: NdArrayMatrix = NdArrayMatrix::new();
    
        test(matrix);
    }
    
    /// Calls `rw_write(0, 0, 0.0)` on every row yielded by `rows_par_iter`,
    /// processing the rows in parallel.
    ///
    /// The higher-ranked bound `for<'a> T: Sliceable<'a>` lets the function
    /// borrow its own local `matrix` for a lifetime chosen inside the body.
    fn test<T>(mut matrix: T)
    where
        for<'a> T: Sliceable<'a>,
    {
        // `for_each` expresses "run this for its side effect on every item".
        // `map(..).count()` only promises to count, so it is the wrong tool
        // for mutation.
        matrix
            .rows_par_iter()
            .for_each(|mut row| row.rw_write(0, 0, 0.0));
    }