Tags: rust, unsafe, rayon, rust-ndarray

Rust get mutable reference to each element of an ndarray in parallel


I am working on parallel matrix multiplication code in Rust, where I want to compute every element of the product in parallel. I use ndarrays to store my data. Thus, my code would be something along the lines of:

fn mul(lhs: &Array2<f32>, rhs: &Array2<f32>) -> Array2<f32> {
   let N = lhs.raw_dim()[0];
   let M = rhs.raw_dim()[1];
   let mut result = Array2::zeros((N,M));
   
   range_2d(0..N,0..M).par_iter().map(|(i, j)| {
      // load the result for the (i,j) element into 'result'
   }).count();

   result
}

Is there any way to achieve this?


Solution

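  • You can create a parallel iterator over the rows of 'result', and then over the elements of each row, this way (ndarray must be built with its 'rayon' feature enabled):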

    use ndarray::{Array2, Axis};
    use rayon::prelude::*;
    
    pub fn mul(lhs: &Array2<f32>, rhs: &Array2<f32>) -> Array2<f32> {
        let n = lhs.raw_dim()[0];
        let m = rhs.raw_dim()[1];
        let mut result = Array2::zeros((n, m));
    
        result
            .axis_iter_mut(Axis(0))
            .into_par_iter()
            .enumerate()
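            .flat_map(|(i, row)| {
                // Rows of a freshly allocated (row-major) array are contiguous,
                // so into_slice() returns Some here.
                row.into_slice()
                    .unwrap()
                    .par_iter_mut()
                    .enumerate()
                    .map(move |(j, item)| (i, j, item))
            })
            .for_each(|(i, j, item)| {
                // Placeholder value; for the actual product, something like
                // lhs.row(i).dot(&rhs.column(j)) could be assigned here.
                *item = i as f32 * j as f32;
            });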
    
        result
    }
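
The reason this works without unsafe is that the result is split into non-overlapping mutable pieces before any work is handed out, so no two tasks ever hold a reference to the same element. As a rough illustration of that idea only, here is a minimal sketch that uses nothing but the standard library: it assumes flat, row-major slices instead of ndarray, uses std::thread::scope (Rust 1.63+) instead of rayon, and the name mul_rows_parallel is made up for this example. It spawns one thread per row, which is wasteful for large matrices; rayon's thread pool avoids that.

```rust
use std::thread;

/// Multiplies an n×k matrix by a k×m matrix, both stored row-major in flat slices.
/// Hypothetical helper for illustration; not part of ndarray or rayon.
fn mul_rows_parallel(lhs: &[f32], rhs: &[f32], n: usize, k: usize, m: usize) -> Vec<f32> {
    assert_eq!(lhs.len(), n * k);
    assert_eq!(rhs.len(), k * m);

    let mut result = vec![0.0_f32; n * m];
    if m == 0 {
        // chunks_mut panics on a chunk size of 0, and there is nothing to compute.
        return result;
    }

    thread::scope(|s| {
        // chunks_mut(m) yields one disjoint `&mut [f32]` per output row,
        // so each thread gets exclusive access to its own row.
        for (i, row) in result.chunks_mut(m).enumerate() {
            s.spawn(move || {
                for (j, item) in row.iter_mut().enumerate() {
                    // Dot product of row i of lhs with column j of rhs.
                    *item = (0..k).map(|p| lhs[i * k + p] * rhs[p * m + j]).sum();
                }
            });
        }
    });

    result
}
```

For example, calling it with lhs = [1, 2, 3, 4] and rhs = [5, 6, 7, 8] as 2×2 matrices gives [19, 22, 43, 50]. The borrow checker accepts this because chunks_mut, like axis_iter_mut in the ndarray version, guarantees the pieces do not overlap.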