Tags: rust, reference, mutable

Abstracting over mutable/immutable references in Rust


I need to get rid of the duplication in this code:

/// Memory made of a ROM region and a RAM region, each backed by its own byte vector.
pub struct Memory {
    // NOTE(review): presumably describes where ROM and RAM sit in the address space
    // (the rom_start/rom_end and ram_start/ram_end bounds used below) — confirm.
    layout: MemoryLayout,
    /// Backing bytes for the ROM address range.
    rom: Vec<u8>,
    /// Backing bytes for the RAM address range.
    ram: Vec<u8>,
}

impl Memory {
    /// Returns a shared reference to the backing vector (ROM or RAM) whose
    /// address range contains `address`.
    ///
    /// Ranges are half-open: `start` is included, `end` is excluded.
    ///
    /// # Errors
    ///
    /// Returns `RiscvError::MemoryAlignmentError(address)` when `address` lies in
    /// neither the ROM range nor the RAM range.
    // NOTE(review): an out-of-range address is reported as an *alignment* error;
    // a dedicated out-of-range variant may be more accurate — confirm intent.
    pub fn get_mem_vec_ref(&self, address: u32) -> Result<&Vec<u8>, RiscvError> {
        // ...

        // The match is already a `Result`, so it is the return value directly;
        // no `?` followed by re-wrapping in `Ok`, and no trailing `return`.
        match address {
            addr if (rom_start..rom_end).contains(&addr) => Ok(&self.rom),
            addr if (ram_start..ram_end).contains(&addr) => Ok(&self.ram),
            addr => Err(RiscvError::MemoryAlignmentError(addr)),
        }
    }

    /// Returns a mutable reference to the backing vector (ROM or RAM) whose
    /// address range contains `address`.
    ///
    /// # Errors
    ///
    /// Returns `RiscvError::MemoryAlignmentError(address)` when `address` lies in
    /// neither the ROM range nor the RAM range.
    pub fn get_mem_vec_mut_ref(&mut self, address: u32) -> Result<&mut Vec<u8>, RiscvError> {
        // ...

        match address {
            addr if (rom_start..rom_end).contains(&addr) => Ok(&mut self.rom),
            addr if (ram_start..ram_end).contains(&addr) => Ok(&mut self.ram),
            addr => Err(RiscvError::MemoryAlignmentError(addr)),
        }
    }
}

How can I abstract over whether `self` is borrowed mutably or immutably? Can `Box` or `RefCell` help in this case?


Solution

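  • Since you're dealing with references in both cases, you can define a generic function where `T` is either `&Vec<u8>` or `&mut Vec<u8>`. `get_mem()` never inspects `T`; it only hands back one of its two arguments, so the same code works for shared and mutable borrows. Neither `Box` nor `RefCell` is needed. For example: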

    /// Selects `rom` or `ram` according to which region contains `address`.
    ///
    /// Because the function is generic over `T` and only moves one of its
    /// arguments out, the same code serves `&Vec<u8>` and `&mut Vec<u8>`.
    /// ROM is checked first; an address in neither region yields
    /// `RiscvError::MemoryAlignmentError(address)`.
    fn get_mem<T>(address: u32, rom: T, ram: T) -> Result<T, RiscvError> {
        // ...
    
        if (rom_start..rom_end).contains(&address) {
            Ok(rom)
        } else if (ram_start..ram_end).contains(&address) {
            Ok(ram)
        } else {
            Err(RiscvError::MemoryAlignmentError(address))
        }
    }
    
    impl Memory {
        /// Returns a shared reference to the ROM or RAM vector containing `address`.
        ///
        /// # Errors
        ///
        /// Returns whatever error `get_mem` produces when `address` is in neither region.
        pub fn get_mem_vec_ref(&self, address: u32) -> Result<&Vec<u8>, RiscvError> {
            // ...
    
            get_mem(address, &self.rom, &self.ram)
        }
    
        /// Returns a mutable reference to the ROM or RAM vector containing `address`.
        ///
        /// Borrowing `self.rom` and `self.ram` mutably at the same time is accepted
        /// because they are disjoint fields; `get_mem` returns only one of them.
        ///
        /// # Errors
        ///
        /// Returns whatever error `get_mem` produces when `address` is in neither region.
        pub fn get_mem_vec_mut_ref(&mut self, address: u32) -> Result<&mut Vec<u8>, RiscvError> {
            // ...
    
            get_mem(address, &mut self.rom, &mut self.ram)
        }
    }
    

    You still need to adapt `get_mem()` so that it has access to `rom_start`, `rom_end`, `ram_start` and `ram_end`. To avoid passing all of these to `get_mem()` as extra arguments, it may be worth introducing a dedicated address type (a newtype) that carries this information instead, for example:

    /// Address type intended to carry what `get_mem` needs (the address and,
    /// presumably, the ROM/RAM bounds), so those need not be passed as separate arguments.
    struct Addr {
        // Fields intentionally left open in this sketch.
        // ...
    }
    
    impl Addr {
        /// Intended to pick `rom` or `ram` like the free `get_mem` function, but
        /// reading the address and region bounds from `self` instead of parameters.
        /// Still generic over `T`, so it works with `&Vec<u8>` and `&mut Vec<u8>` alike.
        fn get_mem<T>(&self, rom: T, ram: T) -> Result<T, RiscvError> {
            // Body left as an exercise in this sketch.
            // ...
        }
    }