rust

Pass-by-value structs in Rust


I am working with a large volume of card data, so I need to keep a massive number of cards (billions) in memory at a time. To keep memory usage down, I want to store each card as a u8 where the most significant four bits are the rank and the least significant four bits are the suit, which lets me sort cards by rank without extracting the rank. When I do need to extract ranks and suits, each of these can also be stored as a u8.

However, I don't want to be passing around just u8s because this loses me the benefits of Rust's type system. What I'd like is to have different types that are u8s under the hood, but are named Card, Suit, and Rank, and fail type checking if you assign the wrong one.

The solution I came up with is this:

struct Rank { internal: u8 }
struct Suit { internal: u8 }
struct Card { internal: u8 }

This allows me to also do stuff like this:

impl Card {
  pub fn rank(&self) -> Rank {
    Rank { internal: self.internal >> 4 }
  }
}

The problem is, I believe Rust structs are pass-by-reference, which is going to massively increase my memory usage if I'm storing billions of cards, because a 64-bit pointer is much larger than an 8-bit u8.

My question is, is simply adding #[derive(Clone, Copy)] sufficient to make these structures pass-by-value? Is there some other, better way of achieving what I want to achieve?


Solution

  • tl;dr: Write clean code that best represents your intent, and trust the optimizer to handle common cases well. Make further optimizations only for specific hotspots, guided by data.

    The problem is, I believe Rust structs are pass-by-reference, which is going to massively increase my memory usage if I'm storing billions of cards, because a 64-bit pointer is much larger than an 8-bit u8.

    The author of a function gets to decide how a struct is passed, at a language level. I'm careful to distinguish the language level from the resulting codegen for reasons we'll see later in this answer.

    fn takes_by_value(c: Card) { /* ... */ }
    fn takes_by_reference(c: &Card) { /* ... */ }
    
    impl Card {
      // Takes by reference at a language level
      pub fn rank_by_ref(&self) -> Rank {
        Rank { internal: self.internal >> 4 }
      }
    
      // Takes by value at a language level (note lack of &)
      pub fn rank_by_value(self) -> Rank {
        Rank { internal: self.internal >> 4 }
      }
    }
    

    is simply adding #[derive(Clone, Copy)] sufficient to make these structures pass-by-value?

    The structure could have been passed by value regardless. One of Rust's fundamental low-level design decisions is that any object can be moved[1] cheaply by value, whether it is an enum, an int, a struct of ints, a Vec, a map, and so on, and that a move is always a plain bitwise move (unlike C++, where a move constructor containing user code may get involved).

    However, without deriving Copy, passing an object by value will move it, meaning that you "give it away" to the function:

    fn main() {
      let c = Card { internal: 16 };
      // suppose that Rank implements Debug so we can print it with {:?}
      println!("{:?}", c.rank_by_value()); // This call moves out of `c`
      println!("{:?}", c.rank_by_value()); // This will fail if Copy is not implemented for Card
    }
    

    If we try to compile this, the first call to rank_by_value consumes c, so c is no longer valid for the second call:

    error[E0382]: use of moved value: `c`
      --> src/main.rs:22:20
       |
    20 |   let c = Card { internal: 16 };
       |       - move occurs because `c` has type `Card`, which does not implement the `Copy` trait
    21 |   println!("{:?}", c.rank_by_value()); // This call moves out of `c`
       |                      --------------- `c` moved due to this method call
    22 |   println!("{:?}", c.rank_by_value()); // This will fail if Copy is not implemented for Card
       |                    ^ value used here after move
       |
    note: `Card::rank_by_value` takes ownership of the receiver `self`, which moves `c`
      --> src/main.rs:14:24
       |
    14 |   pub fn rank_by_value(self) -> Rank {
       |                        ^^^^
    

    It is this that deriving Copy solves: each pass-by-value no longer invalidates the previous value.
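As a minimal sketch of that fix, here is the same example with Copy derived. The extra Debug and PartialEq derives are assumptions added only so the result can be printed and compared; they are not required for pass-by-value.

```rust
// Deriving Copy (which also requires Clone) means passing by value
// copies the byte instead of moving it out of the variable.
#[derive(Debug, Clone, Copy, PartialEq)]
struct Rank { internal: u8 }

#[derive(Debug, Clone, Copy)]
struct Card { internal: u8 }

impl Card {
    // Takes `self` by value; with Copy, the caller keeps its own copy.
    pub fn rank_by_value(self) -> Rank {
        Rank { internal: self.internal >> 4 }
    }
}

fn main() {
    let c = Card { internal: 16 };
    println!("{:?}", c.rank_by_value()); // copies `c`, does not move it
    println!("{:?}", c.rank_by_value()); // still valid: `c` was never moved
}
```

With the derive in place, both calls compile, and `c` remains usable afterwards.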


    a 64-bit pointer is much larger than an 8-bit u8.

    A pointer is, but a reference need not be: that depends on what the optimizer does. For this example code in particular, it is very easy for the optimizer to produce by-value code. It looks at the caller and the function itself and emits code with the same end-result behavior, but without ever needing to create a pointer, or even to put the card in memory (it could live in a CPU register). For example, this function:

    #[no_mangle]
    pub fn card_rank(c: Card) -> Rank {
        c.rank_by_ref()
    }
    

    (even without Card: Copy) becomes nothing more than:

    card_rank:
            mov     eax, edi  # move input (customarily in EDI/RDI) to EAX
            shr     al, 4     # shift it right
            ret               # and return it, it's already in AL and the caller expects it in AL
    

    Notice no pointers or indirect loads. For what it's worth, even if you did have a 64-bit pointer, it would only take up a register during a call to rank, as opposed to there being a 64-bit pointer taking up 64 bits of space for every single Card in memory.
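The storage side of this can be checked directly. The following is a small sketch, assuming the single-field Card struct from the question; the sizes shown are what current compilers produce in practice, since the default struct representation makes no formal layout promise.

```rust
use std::mem::{size_of, size_of_val};

#[allow(dead_code)] // the field is never read in this sketch
#[derive(Clone, Copy)]
pub struct Card { internal: u8 }

fn main() {
    // In practice the wrapper adds no overhead: a Card occupies one byte.
    assert_eq!(size_of::<Card>(), 1);

    // A reference is pointer-sized, but it only exists transiently
    // (typically in a register) while a function is being called.
    assert_eq!(size_of::<&Card>(), size_of::<usize>());

    // The elements of a Vec<Card> are stored inline, one byte per card,
    // with no per-element pointer.
    let cards = vec![Card { internal: 0 }; 1_000];
    assert_eq!(size_of_val(cards.as_slice()), 1_000);
}
```

So a collection of a billion cards needs roughly a gigabyte for its elements, regardless of whether individual methods take `&self` or `self`.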

    In fact, idiomatic Rust code passes references like these all the time, which makes this a very valuable optimization. Even when you use standard library iterators and similar abstractions, you will often pass references at the language level, yet get efficient by-value code generated. For example, the following code goes through multiple structs representing iterators, adapters, etc. while passing &Card and &Rank references:

    impl Rank {
        pub fn is_rank_1(&self) -> bool {
            self.internal == 1
        }
    }
    
    #[no_mangle]
    pub fn count_rank_1(cards: &[Card]) -> usize {
        cards.iter()
             .map(|c: &Card| c.rank_by_ref())
             .filter(|r: &Rank| r.is_rank_1())
             .count()
    }
    

    yet it becomes efficient, vectorized assembly looking at the input slice and treating each card as a value, rather than painstakingly creating a pointer for each element and following it.
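For completeness, here is a self-contained, runnable version of that iterator chain. It is only a sketch: the sample card values are made up for illustration, and the `#[no_mangle]` attribute is left out because it matters only when inspecting the generated assembly.

```rust
#[derive(Clone, Copy)]
pub struct Rank { internal: u8 }

#[derive(Clone, Copy)]
pub struct Card { internal: u8 }

impl Card {
    // Takes the card by reference at the language level.
    pub fn rank_by_ref(&self) -> Rank {
        Rank { internal: self.internal >> 4 }
    }
}

impl Rank {
    pub fn is_rank_1(&self) -> bool {
        self.internal == 1
    }
}

pub fn count_rank_1(cards: &[Card]) -> usize {
    cards.iter()
         .map(|c: &Card| c.rank_by_ref())
         .filter(|r: &Rank| r.is_rank_1())
         .count()
}

fn main() {
    // High nibble is the rank, low nibble is the suit (hypothetical sample data).
    let cards = [
        Card { internal: 0x10 }, // rank 1
        Card { internal: 0x21 }, // rank 2
        Card { internal: 0x13 }, // rank 1
        Card { internal: 0xA2 }, // rank 10
    ];
    println!("{}", count_rank_1(&cards)); // two of the sample cards have rank 1
}
```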

    [1] Of course, not while it is borrowed, since moving it would invalidate the references.