rust parallel-processing borrow-checker unsafe rayon

How to best parallelize code modifying several slices of the same Rust vector?


Let's say we want to double (in place) each element in each of several slices of a vector, where the slices are defined by a list of (start, end) index pairs. The following code expresses the intent idiomatically, but doesn't compile: the closure passed to the parallel for_each must be callable from several threads at once, so it cannot mutably borrow the vector:

use rayon::prelude::*;

fn main() {
    let mut data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
    let slice_pairs = vec![(0, 3), (4, 7), (8, 10)];

    slice_pairs.into_par_iter().for_each(|(start, end)| {
        let slice = &mut data[start..end];
        for elem in slice.iter_mut() {
            *elem *= 2;
        }
    });

    println!("{:?}", data);
}

There is real potential for data races here: to rule them out, you need to check that the slices don't overlap. The question is what's the best way to do this in Rust, either through unsafe code or a safe API. The following code uses unsafe to "go ahead and do this"; my question is whether there's a better way than the code below, which transmutes the vector's base pointer to an i64 and back to "blind" the borrow checker to the problem.

use rayon::prelude::*;
use std::mem;

fn main() {
    let mut data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
    let slice_pairs = vec![(0, 4), (4, 7), (7, 10)];

    let ptr_outer = data.as_mut_ptr();
    let ptr_int : i64 = unsafe { mem::transmute(ptr_outer) };

    slice_pairs.into_par_iter().for_each(|(start, end)| {
        unsafe {
            let ptr : *mut i32 = mem::transmute(ptr_int);
            let slice = std::slice::from_raw_parts_mut(ptr.add(start), end - start);

            for elem in slice.iter_mut() {
                *elem *= 2;
            }
        }
    });

    println!("{:?}", data);
}

Solution

  • You can use split_at_mut() to break up the slice into multiple slices using safe code:

    fn split_many<'a, T>(mut slice: &'a mut [T], regions: &[(usize, usize)]) -> Vec<&'a mut [T]> {
        let mut regions = regions.to_vec();
        regions.sort_by_key(|&(b, _e)| b);
        let mut ret = vec![];
        let mut offset = 0;
        for (b, e) in regions {
            assert!(b >= offset && e >= b); // prohibit overlaps
            let (chosen, rest) = slice.split_at_mut(e - offset);
            ret.push(&mut chosen[b - offset..]);
            offset = e;
            slice = rest;
        }
        ret
    }
    

    With that helper in place, you can express parallel in-place manipulation in the "obvious" way:

    fn main() {
        let mut data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let slice_pairs = vec![(0, 3), (4, 7), (8, 10)];
    
        split_many(&mut data, &slice_pairs)
            .into_par_iter()
            .for_each(|region| {
                for elem in region.iter_mut() {
                    *elem *= 2;
                }
            });
    
        println!("{:?}", data);
    }
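
    The slices returned by split_many() are ordinary disjoint `&mut [T]` borrows, so nothing about them is specific to rayon. As a minimal sketch of that point, the block below swaps rayon for the standard library's `std::thread::scope`; `double_regions` is a hypothetical name introduced only for this illustration, and `split_many` is repeated so the block compiles on its own. Spawning one thread per region is an assumption that only makes sense for a handful of regions; for many regions a thread pool such as rayon's is likely the better fit.

```rust
use std::thread;

/// Same helper as above: carve non-overlapping regions out of `slice`.
fn split_many<'a, T>(mut slice: &'a mut [T], regions: &[(usize, usize)]) -> Vec<&'a mut [T]> {
    let mut regions = regions.to_vec();
    regions.sort_by_key(|&(b, _e)| b);
    let mut ret = vec![];
    let mut offset = 0;
    for (b, e) in regions {
        assert!(b >= offset && e >= b); // prohibit overlaps
        let (chosen, rest) = slice.split_at_mut(e - offset);
        ret.push(&mut chosen[b - offset..]);
        offset = e;
        slice = rest;
    }
    ret
}

/// Hypothetical helper (not part of the answer): double every element of
/// each region, using one scoped thread per region.
fn double_regions(data: &mut [i32], regions: &[(usize, usize)]) {
    // Split first; each element of `parts` is an independent mutable borrow.
    let parts = split_many(data, regions);
    thread::scope(|s| {
        for region in parts {
            // `move` hands this region's `&mut [i32]` to exactly one thread.
            s.spawn(move || {
                for elem in region.iter_mut() {
                    *elem *= 2;
                }
            });
        }
        // All spawned threads are joined when the scope ends.
    });
}
```

    Calling `double_regions(&mut data, &[(0, 3), (4, 7), (8, 10)])` on the vector from the question doubles the elements at indices 0..3, 4..7 and 8..10 and leaves indices 3 and 7 untouched.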
    


    Note that as long as the regions are represented by arbitrary indices supplied at run-time, an initial pass through them is needed to ensure they don't overlap (split_many() panics if it detects overlap). Skipping that check would make the function unsound, because a caller could trigger undefined behavior from safe code simply by passing overlapping regions. However, if you control the code that generates the regions and know that they don't overlap, you could make a faster unsafe version of split_many(). Relying on external guarantees, it needs neither to sort nor to allocate a new set of regions, and can even return a ParallelIterator directly:

    /// Split `slice` into `regions` and iterate over them in parallel.
    /// Safety: regions must not overlap, and each `(b, e)` must satisfy
    /// `b <= e <= slice.len()`.
    unsafe fn split_many_unchecked<'a, T: Send + Sync>(
        slice: &'a mut [T],
        regions: &'a [(usize, usize)],
    ) -> impl ParallelIterator<Item = &'a mut [T]> + 'a {
        struct Wrap<T>(*mut T);
        unsafe impl<T> Sync for Wrap<T> {}
        unsafe impl<T> Send for Wrap<T> {}
        let slice = Wrap(slice.as_mut_ptr());
        regions.par_iter().map(move |&(b, e)| {
            let _ = &slice; // prevent closure from capturing slice.0
            std::slice::from_raw_parts_mut(slice.0.add(b), e - b)
        })
    }
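
    If you take the unchecked route, it may still be worth verifying the "external guarantee" in debug builds. The sketch below is one possible way to do that, not something the original answer prescribes: `regions_are_disjoint` is a hypothetical helper that performs the same kind of pass as split_many() (sort, then compare each start with the previous end) and additionally checks that every region lies within a slice of length `len`. It is conservative in that it also rejects an empty region that sits inside another region.

```rust
/// Hypothetical helper: returns true when every region `(b, e)` is
/// well-formed (`b <= e`), lies within a slice of length `len`, and no
/// region starts before the previous one (in sorted order) has ended.
fn regions_are_disjoint(regions: &[(usize, usize)], len: usize) -> bool {
    // Sort a copy so that the caller's ordering is left untouched.
    let mut sorted = regions.to_vec();
    sorted.sort_unstable();

    let mut prev_end = 0;
    for (b, e) in sorted {
        if b < prev_end || e < b || e > len {
            return false;
        }
        prev_end = e;
    }
    true
}
```

    A caller could then write something like `debug_assert!(regions_are_disjoint(regions, slice.len()))` just before invoking the unchecked function, so the sort and allocation are paid only in debug builds.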
    
