rust interior-mutability

How to mutably reference one of several things at a time in Rust?


Say I want a train that can switch between two tracks at any time and write a u8 at its current position. Naively, something like this:

struct Train<'a> {
  track_a: &'a mut [u8],
  track_b: &'a mut [u8],
  current_track: &'a mut [u8], // either track_a or track_b
  idx: usize,
}

impl<'a> Train<'a> {
  pub fn new(track_a: &'a mut [u8], track_b: &'a mut [u8]) -> Self {
    Self {
     track_a,
     track_b,
     idx: 0,
     current_track: track_a,
    }
  }

  pub fn toggle_track(&mut self) {
    if self.current_track == self.track_a {
      self.current_track = self.track_b;
    } else {
      self.current_track = self.track_a;
    }
  }

  pub fn write(&mut self, byte: u8) {
    // must be fast - can't waste time choosing track here
    self.current_track[self.idx] = byte;
    self.idx += 1;
  }
}

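Importantly, we can't waste time in write deciding which track we're currently on. Of course, the above code doesn't compile: current_track would have to be a second mutable reference to the same slice as track_a or track_b, and Rust does not allow two live mutable references to the same data.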

How might I get something like this to work in Rust? I tried using RefCell for track_a and track_b, but realized that doesn't make sense either, since even an immutable Ref would give mutable access to the underlying bytes.

Is unsafe Rust the only way to implement this data structure?


Solution

  • You can do this without interior mutability, without unsafe, and without even a separate current_track field if you do it this way:

    struct Train<'a> {
        track_a: &'a mut [u8], // a.k.a. the current track
        track_b: &'a mut [u8],
        idx: usize,
    }
    
    impl<'a> Train<'a> {
        pub fn new(track_a: &'a mut [u8], track_b: &'a mut [u8]) -> Self {
            Self {
                track_a,
                track_b,
                idx: 0,
            }
        }
    
        pub fn toggle_track(&mut self) {
            std::mem::swap(&mut self.track_a, &mut self.track_b);
        }
    
        pub fn write(&mut self, byte: u8) {
            // must be fast - can't waste time choosing track here
            self.track_a[self.idx] = byte;
            self.idx += 1;
        }
    }
    

    The trick is to toggle by swapping the two references with std::mem::swap, treating track_a as the "current" track. Only the two references are exchanged; the underlying bytes are not copied. The one downside is that you lose track of which slice was originally passed as the first argument and which as the second. If you need that, keep a separate boolean that toggle_track flips, which still leaves write free of any track-selection branch.
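
    As a rough sketch of the boolean variant mentioned above, something like the following should work. The field names current, other and swapped, and the helper on_original_track, are made up for illustration and are not part of the question's code.

```rust
struct Train<'a> {
    current: &'a mut [u8], // the track that write() targets
    other: &'a mut [u8],   // the track we are not on right now
    idx: usize,
    swapped: bool, // hypothetical flag: true when `current` is the second track passed to new()
}

impl<'a> Train<'a> {
    pub fn new(track_a: &'a mut [u8], track_b: &'a mut [u8]) -> Self {
        Self {
            current: track_a,
            other: track_b,
            idx: 0,
            swapped: false,
        }
    }

    pub fn toggle_track(&mut self) {
        // Swaps the two references only; the slice contents stay where they are.
        std::mem::swap(&mut self.current, &mut self.other);
        self.swapped = !self.swapped;
    }

    /// Hypothetical helper: true while we are on the track passed first to new().
    pub fn on_original_track(&self) -> bool {
        !self.swapped
    }

    pub fn write(&mut self, byte: u8) {
        // No track selection here. Note that slice indexing is still
        // bounds-checked and panics if idx runs past the end of the track.
        self.current[self.idx] = byte;
        self.idx += 1;
    }
}
```

    The flag is only touched in toggle_track, so write stays exactly as in the version without it. Since idx is shared, toggling continues at the same position on the other track, as in the original design.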