.net f# memoization

How can ConditionalWeakTable be used with multiple references?


ConditionalWeakTable provides a thread-safe way to attach values to objects without affecting the objects' lifetimes. The table holds its keys weakly, and it keeps each attached value alive only as long as its key is alive. However, it only accepts a single reference as the key. How can I use multiple references as a key?

For example, I have a function that takes three reference-type arguments and returns a value, and I want to memoize it based on the reference identity of all three inputs. How can I use ConditionalWeakTable for this?
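With a single key, the table can serve as a memoization cache directly. The following is a minimal sketch, assuming the standard GetValue(key, callback) overload. The names refEq1 and Person are hypothetical and used only for illustration.

```fsharp
open System.Runtime.CompilerServices

/// Hypothetical single-key memoizer keyed on the identity of its argument.
/// ConditionalWeakTable requires both 'k and 'v to be reference types.
let refEq1 (f: 'k -> 'v) =
    let cache = ConditionalWeakTable<'k, 'v>()
    // On a cache miss GetValue invokes the callback and stores its result;
    // the entry is kept only as long as k itself is reachable elsewhere.
    fun k -> cache.GetValue(k, fun k -> f k)

// Hypothetical key type for the usage example.
type Person(name: string) =
    member _.Name = name

let calls = ref 0
let greeting =
    refEq1 (fun (p: Person) ->
        calls.Value <- calls.Value + 1
        "Hello, " + p.Name)

let alice = Person("Alice")
let g1 = greeting alice
let g2 = greeting alice // cache hit: the function is not called again
```

The open question is how to extend this to three keys without losing the identity semantics.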


Solution

  • You can get the desired behavior by nesting ConditionalWeakTable instances, one level per key:

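    open System.Runtime.CompilerServices
    
    module Memoize =
    
        /// Memoizes the specified function using reference equality on the input arguments.
        /// A result stays cached as long as all three keys are alive.
        /// The keys and the result must be reference types (a ConditionalWeakTable constraint).
        ///
        /// Don't pass ad-hoc tuples or records as arguments: freshly created
        /// instances are never reference-equal, so every such call misses the cache.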
        let refEq3 (f: 'k1 -> 'k2 -> 'k3 -> 'v) =
            let cache = ConditionalWeakTable<'k1, ConditionalWeakTable<'k2, ConditionalWeakTable<'k3, 'v>>>()
    
            fun k1 k2 k3 ->
                let inner1 = cache.GetOrCreateValue(k1)
                let inner2 = inner1.GetOrCreateValue(k2)
    
                match inner2.TryGetValue(k3) with
                | true, v -> v
                | false, _ ->
                    let v = f k1 k2 k3
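                    // TryAdd returns false if another thread stored a value for k3 first;
                    // in that case this call still returns its own v.
                    inner2.TryAdd(k3, v) |> ignore<bool>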
                    v
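The doc comment's warning about ad-hoc tuples also explains why the three references can't simply be combined into one tuple key. This small sketch uses a hypothetical Key type. It shows that ConditionalWeakTable compares keys by object identity and ignores structural equality, so a structurally equal tuple built later misses the cache.

```fsharp
open System.Runtime.CompilerServices

// Hypothetical reference type used as a key.
type Key(name: string) =
    member _.Name = name

let a = Key("a")
let b = Key("b")

// One table keyed on an ad-hoc tuple of both references.
let table = ConditionalWeakTable<Key * Key, string>()

let t1 = (a, b)
table.Add(t1, "cached")

// Looking up with the very same tuple instance finds the entry...
let foundSame, _ = table.TryGetValue(t1)

// ...but a freshly built tuple is a different object, even though it is
// structurally equal (t1 = t2), so the lookup misses.
let t2 = (a, b)
let foundFresh, _ = table.TryGetValue(t2)
```

There is a second problem with tuple keys. If the tuple is built inside each call, nothing else references it, so its cache entry becomes collectible right away. Nesting avoids both issues because every level is keyed on an argument that the caller actually holds.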
    

    For completeness, here's a version that stores boxed results, so that 'v can also be a struct (the keys must still be reference types). Boxing a struct result costs an allocation, but for memoization that cost is likely small compared with the cost of the function being memoized.

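    /// Same as refEq3 above, but results are stored boxed, so 'v may be a struct.
    /// The keys must still be reference types.
    let refEq3 (f: 'k1 -> 'k2 -> 'k3 -> 'v) =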
        let cache =
            ConditionalWeakTable<'k1, ConditionalWeakTable<'k2, ConditionalWeakTable<'k3, obj>>>()
    
        fun k1 k2 k3 ->
            let inner1 = cache.GetOrCreateValue(k1)
            let inner2 = inner1.GetOrCreateValue(k2)
    
            match inner2.TryGetValue(k3) with
            | true, v -> unbox<'v> v
            | false, _ ->
                let v = f k1 k2 k3
                inner2.TryAdd(k3, box v) |> ignore<bool>
                v
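One possible refinement is to insert through ConditionalWeakTable.GetValue(key, callback) instead of TryGetValue followed by TryAdd. GetValue stores at most one value per key and returns the stored value to every caller. Threads that race on the same keys therefore all get the same result, although f may still run more than once. GetValue has also been available since .NET Framework 4, whereas TryAdd is a newer API that isn't present on every target framework.

The following sketch rewrites the boxed version this way and adds a usage example. Node is a hypothetical type.

```fsharp
open System.Runtime.CompilerServices

module Memoize =

    /// Boxed variant of refEq3 that inserts through GetValue instead of TryGetValue/TryAdd.
    /// Keys must be reference types; 'v may be a struct because results are boxed.
    let refEq3 (f: 'k1 -> 'k2 -> 'k3 -> 'v) =
        let cache =
            ConditionalWeakTable<'k1, ConditionalWeakTable<'k2, ConditionalWeakTable<'k3, obj>>>()

        fun k1 k2 k3 ->
            let inner2 = cache.GetOrCreateValue(k1).GetOrCreateValue(k2)
            // If two threads miss at once, both may run f, but GetValue keeps only
            // one result and returns that stored result to both callers.
            inner2.GetValue(k3, fun _ -> box (f k1 k2 k3)) |> unbox<'v>

// Hypothetical key type for the usage example.
type Node(id: int) =
    member _.Id = id

let calls = ref 0
let sumIds =
    Memoize.refEq3 (fun (a: Node) (b: Node) (c: Node) ->
        calls.Value <- calls.Value + 1
        a.Id + b.Id + c.Id)

let x, y, z = Node(1), Node(2), Node(3)
let r1 = sumIds x y z          // computed
let r2 = sumIds x y z          // cache hit
let r3 = sumIds x y (Node(3))  // new instance as third key: computed again
```

The function returns an int, a struct, which the unboxed first version could not cache. The third call shows that a different Node instance with the same Id is treated as a different key.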