
How to get value from HashMap with two keys via references to both keys?


HashMap's get method takes a single reference to the key. But I want an implementation (for a future trait interface) that accepts the two keys separately, like this:

pub struct Table<A: Eq + Hash, B: Eq + Hash> {
    map: HashMap<(A, B), f64>,
}

impl<A: Eq + Hash, B: Eq + Hash> Memory<A, B> for Table<A, B> {
    fn get(&self, a: &A, b: &B) -> f64 {
        let key: &(A, B) = ??;
        *self.map.get(key).unwrap()
    }

    fn set(&mut self, a: A, b: B, v: f64) {
        self.map.insert((a, b), v);
    }
}

The problem is that I don't know how to construct a &(A, B) from &A and &B, or whether that's even possible.


Solution

  • This is certainly possible. The signature of get is

    fn get<Q: ?Sized>(&self, k: &Q) -> Option<&V> 
    where
        K: Borrow<Q>,
        Q: Hash + Eq, 
    

    The problem here is to find a type Q such that

    1. (A, B): Borrow<Q>
    2. Q implements Hash + Eq

    To satisfy condition (1), we need to think of how to write

    fn borrow(self: &(A, B)) -> &Q
    

    The trick is that &Q does not need to be a simple pointer, it can be a trait object! The idea is to create a trait Q which will have two implementations:

    impl Q for (A, B)
    impl Q for (&A, &B)
    

    The Borrow implementation will simply return self, and at the call site we can construct a &dyn Q trait object from a temporary (&A, &B) tuple built from the two separate references.


    The full implementation is like this:

    use std::borrow::Borrow;
    use std::collections::HashMap;
    use std::hash::{Hash, Hasher};
    
    // See explanation (1).
    trait KeyPair<A, B> {
        /// Obtains the first element of the pair.
        fn a(&self) -> &A;
        /// Obtains the second element of the pair.
        fn b(&self) -> &B;
    }
    
    // See explanation (2).
    impl<'a, A, B> Borrow<dyn KeyPair<A, B> + 'a> for (A, B)
    where
        A: Eq + Hash + 'a,
        B: Eq + Hash + 'a,
    {
        fn borrow(&self) -> &(dyn KeyPair<A, B> + 'a) {
            self
        }
    }
    
    // See explanation (3).
    impl<A: Hash, B: Hash> Hash for dyn KeyPair<A, B> + '_ {
        fn hash<H: Hasher>(&self, state: &mut H) {
            self.a().hash(state);
            self.b().hash(state);
        }
    }
    
    impl<A: Eq, B: Eq> PartialEq for dyn KeyPair<A, B> + '_ {
        fn eq(&self, other: &Self) -> bool {
            self.a() == other.a() && self.b() == other.b()
        }
    }
    
    impl<A: Eq, B: Eq> Eq for dyn KeyPair<A, B> + '_ {}
    
    // OP's Table struct
    pub struct Table<A: Eq + Hash, B: Eq + Hash> {
        map: HashMap<(A, B), f64>,
    }
    
    impl<A: Eq + Hash, B: Eq + Hash> Table<A, B> {
        fn new() -> Self {
            Table {
                map: HashMap::new(),
            }
        }
    
        fn get(&self, a: &A, b: &B) -> f64 {
            *self.map.get(&(a, b) as &dyn KeyPair<A, B>).unwrap()
        }
    
        fn set(&mut self, a: A, b: B, v: f64) {
            self.map.insert((a, b), v);
        }
    }
    
    // Boring stuff below.
    
    impl<A, B> KeyPair<A, B> for (A, B) {
        fn a(&self) -> &A {
            &self.0
        }
        fn b(&self) -> &B {
            &self.1
        }
    }
    impl<A, B> KeyPair<A, B> for (&A, &B) {
        fn a(&self) -> &A {
            self.0
        }
        fn b(&self) -> &B {
            self.1
        }
    }
    
    //----------------------------------------------------------------
    
    #[derive(Eq, PartialEq, Hash)]
    struct A(&'static str);
    
    #[derive(Eq, PartialEq, Hash)]
    struct B(&'static str);
    
    fn main() {
        let mut table = Table::new();
        table.set(A("abc"), B("def"), 4.0);
        table.set(A("123"), B("456"), 45.0);
        println!("{:?} == 45.0?", table.get(&A("123"), &B("456")));
        println!("{:?} == 4.0?", table.get(&A("abc"), &B("def")));
        // Should panic below.
        println!("{:?} == NaN?", table.get(&A("123"), &B("def")));
    }
    

    Explanation:

    1. The KeyPair trait takes the role of Q we mentioned above. We need to implement Eq and Hash for dyn KeyPair, but neither Eq nor Hash is object safe, so they cannot be supertraits of KeyPair. Instead, we add the a() and b() methods so that we can implement Eq and Hash manually.

    2. Now we implement the Borrow trait from (A, B) to dyn KeyPair + 'a. Note the 'a: this is a subtle bit that is needed to make Table::get actually work. The arbitrary 'a allows us to say that an (A, B) can be borrowed as the trait object for any lifetime. If we omit the 'a, the trait object type defaults to dyn KeyPair + 'static, so the Borrow implementation could only be used when the implementing type, such as (&A, &B), outlives 'static, which is certainly not the case here.

    3. Finally, we implement Eq and Hash. For the same reason as in point 2, we implement them for dyn KeyPair + '_ instead of dyn KeyPair (which means dyn KeyPair + 'static in this context). The '_ here is syntactic sugar for an arbitrary lifetime.
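
    To make points 2 and 3 concrete, here is a condensed sketch of the same machinery with a free function in place of Table::get. The name lookup and the Option<f64> return type are assumptions made for illustration and are not part of the answer above. The point is that a and b can borrow from short-lived locals, so the temporary (a, b) tuple is not 'static and the lookup only type-checks because of the 'a and '_ bounds.

```rust
use std::borrow::Borrow;
use std::collections::HashMap;
use std::hash::{Hash, Hasher};

// Same trait as in the full implementation above.
trait KeyPair<A, B> {
    fn a(&self) -> &A;
    fn b(&self) -> &B;
}

impl<A, B> KeyPair<A, B> for (A, B) {
    fn a(&self) -> &A { &self.0 }
    fn b(&self) -> &B { &self.1 }
}

impl<A, B> KeyPair<A, B> for (&A, &B) {
    fn a(&self) -> &A { self.0 }
    fn b(&self) -> &B { self.1 }
}

// The `+ 'a` lets the owned key be viewed as a trait object of any lifetime.
impl<'a, A, B> Borrow<dyn KeyPair<A, B> + 'a> for (A, B)
where
    A: Eq + Hash + 'a,
    B: Eq + Hash + 'a,
{
    fn borrow(&self) -> &(dyn KeyPair<A, B> + 'a) { self }
}

// Hash the two parts in order, matching how the (A, B) tuple hashes itself.
impl<A: Hash, B: Hash> Hash for dyn KeyPair<A, B> + '_ {
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.a().hash(state);
        self.b().hash(state);
    }
}

impl<A: Eq, B: Eq> PartialEq for dyn KeyPair<A, B> + '_ {
    fn eq(&self, other: &Self) -> bool {
        self.a() == other.a() && self.b() == other.b()
    }
}

impl<A: Eq, B: Eq> Eq for dyn KeyPair<A, B> + '_ {}

// Hypothetical helper (not in the original answer): `a` and `b` may borrow
// from short-lived locals, so the temporary `(a, b)` tuple is not `'static`.
fn lookup<A: Eq + Hash, B: Eq + Hash>(
    map: &HashMap<(A, B), f64>,
    a: &A,
    b: &B,
) -> Option<f64> {
    map.get(&(a, b) as &dyn KeyPair<A, B>).copied()
}
```

    Returning Option<f64> instead of unwrapping lets the caller decide what a missing key means, rather than panicking as Table::get does.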


    Using trait objects incurs an indirection cost when computing the hash and checking equality in get(). The optimizer may eliminate this cost by devirtualizing the calls, but there is no guarantee that LLVM will do so.

    An alternative is to store the map as HashMap<(Cow<'a, A>, Cow<'a, B>), f64>, which requires A and B to implement Clone. This requires less "clever code", but there is now a memory cost to store the owned/borrowed flag as well as a runtime cost in both get() and set().
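
    A minimal sketch of the Cow alternative might look like the following. The CowTable name, the explicit 'a lifetime parameter, the Clone bounds, and the Option<f64> return type are assumptions made for illustration. Keys are stored as Cow::Owned in set() and wrapped as Cow::Borrowed in get(), so no trait object is involved.

```rust
use std::borrow::Cow;
use std::collections::HashMap;
use std::hash::Hash;

// Hypothetical variant of the Table struct. `Cow<'a, A>` needs `A: ToOwned`,
// which the `Clone` bound provides.
pub struct CowTable<'a, A: Clone + Eq + Hash, B: Clone + Eq + Hash> {
    map: HashMap<(Cow<'a, A>, Cow<'a, B>), f64>,
}

impl<'a, A: Clone + Eq + Hash, B: Clone + Eq + Hash> CowTable<'a, A, B> {
    pub fn new() -> Self {
        CowTable { map: HashMap::new() }
    }

    // Builds a temporary key of borrowed Cows. A Cow hashes and compares
    // through the value it wraps, so a Borrowed key matches an Owned one.
    pub fn get(&self, a: &A, b: &B) -> Option<f64> {
        self.map
            .get(&(Cow::Borrowed(a), Cow::Borrowed(b)))
            .copied()
    }

    // Stored keys are always owned, so the table does not borrow from callers.
    pub fn set(&mut self, a: A, b: B, v: f64) {
        self.map.insert((Cow::Owned(a), Cow::Owned(b)), v);
    }
}
```

    The temporary key in get() is accepted because Cow is covariant in its lifetime, so the shared borrow of the map can be viewed with a shorter key lifetime for the duration of the call.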

    Unless you fork the standard HashMap and add a method to look up an entry via Hash + Eq alone, there is no guaranteed-zero-cost solution.