I am working on parallel matrix multiplication code in Rust, where I want to compute every element of the product in parallel. I use ndarrays to store my data, so my code would be something along the lines of:
fn mul(lhs: &Array2<f32>, rhs: &Array2<f32>) -> Array2<f32> {
    let N = lhs.raw_dim()[0];
    let M = rhs.raw_dim()[1];
    let mut result = Array2::zeros((N, M));
    range_2d(0..N, 0..M).par_iter().map(|(i, j)| {
        // store the result for the (i, j) element in `result`
    }).count();
    result
}
Is there any way to achieve this?
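You can create a parallel iterator this way (ndarray's `rayon` feature must be enabled so that `axis_iter_mut` can be turned into a parallel iterator):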
use ndarray::{Array2, Axis};
use rayon::prelude::*;

pub fn mul(lhs: &Array2<f32>, rhs: &Array2<f32>) -> Array2<f32> {
    let n = lhs.raw_dim()[0];
    let m = rhs.raw_dim()[1];
    let mut result = Array2::zeros((n, m));
    result
        .axis_iter_mut(Axis(0))
        .into_par_iter()
        .enumerate()
        .flat_map(|(i, row)| {
            // Rows of a freshly allocated array are contiguous, so `into_slice` succeeds.
            row.into_slice()
                .unwrap()
                .par_iter_mut()
                .enumerate()
                .map(move |(j, item)| (i, j, item))
        })
        .for_each(|(i, j, item)| {
            // Do the multiplication. This placeholder only uses the indices;
            // a real product would store lhs.row(i).dot(&rhs.column(j)) here.
            *item = i as f32 * j as f32;
        });
    result
}
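
To make the per-element computation concrete, here is a minimal sketch of the same decomposition using only the standard library. It is not the ndarray/rayon code above: it swaps in `std::thread::scope` for rayon and assumes the matrices are stored row-major in flat slices. The name `mul_flat` and its parameters are made up for illustration. It also spawns one thread per output row, which rayon's thread pool would avoid.

```rust
use std::thread;

/// Multiplies an `n x k` matrix by a `k x m` matrix, both stored row-major
/// in flat slices (an assumption of this sketch). Each output row is filled
/// by its own scoped thread.
fn mul_flat(lhs: &[f32], rhs: &[f32], n: usize, k: usize, m: usize) -> Vec<f32> {
    assert_eq!(lhs.len(), n * k, "lhs must hold n * k elements");
    assert_eq!(rhs.len(), k * m, "rhs must hold k * m elements");

    let mut result = vec![0.0_f32; n * m];
    if m == 0 {
        // `chunks_mut` panics on a chunk size of zero.
        return result;
    }

    thread::scope(|s| {
        // `chunks_mut` yields disjoint mutable rows, so no locking is needed.
        for (i, row) in result.chunks_mut(m).enumerate() {
            s.spawn(move || {
                for (j, item) in row.iter_mut().enumerate() {
                    // Element (i, j) is the dot product of row i of `lhs`
                    // and column j of `rhs`.
                    *item = (0..k).map(|p| lhs[i * k + p] * rhs[p * m + j]).sum();
                }
            });
        }
    });

    result
}
```

Calling `mul_flat(&a, &b, n, k, m)` returns the product as a flat row-major `Vec<f32>` of length `n * m`. The point of the sketch is that handing out disjoint mutable rows lets every element be written without synchronization, which is the same idea the `axis_iter_mut` approach relies on.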