operator-overloading rust function-call-operator

How do I implement the Fn trait for one struct for different types of arguments?


I have a simple classifier:

struct Clf {
    x: f64,
}

The classifier returns 0 if the observed value is smaller than x and 1 if bigger than x.

I want to implement the call operator for this classifier. However, the function should be able to take either a float or a vector as its argument. In the case of a vector, the output is a vector of 0s and 1s with the same length as the input vector:

let c = Clf { x: 0.0 };
let v = vec![-1.0, 0.5, 1.0];
println!("{}", c(0.5));   // prints 1
println!("{:?}", c(v));   // prints [0, 1, 1]

How can I write the implementation of Fn in this case?

impl Fn for Clf {
    extern "rust-call" fn call(/*...*/) {
        // ...
    }
}

Solution

  • The short answer is: you can't, or at least it won't work the way you want. I think the best way to show that is to walk through an implementation and see what happens, but the general idea is that Rust doesn't support function overloading.

    For this example, we will be implementing FnOnce, because Fn requires FnMut which requires FnOnce. So, if we were to get this all sorted, we could do it for the other function traits.
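
The supertrait chain mentioned above can be observed on stable Rust with ordinary closures. In this minimal sketch, `apply_once` and `apply_twice` are hypothetical helper names (not from the original post): a closure that only reads its captures implements all three traits, so it satisfies an `FnOnce` bound as well as an `Fn` bound.

```rust
// Hypothetical helper: needs a callable that can be invoked at least once.
fn apply_once<F>(f: F, value: f64) -> i32
where
    F: FnOnce(f64) -> i32,
{
    f(value)
}

// Hypothetical helper: needs a callable that can be invoked repeatedly
// through a shared reference.
fn apply_twice<F>(f: F, value: f64) -> (i32, i32)
where
    F: Fn(f64) -> i32,
{
    (f(value), f(value))
}

fn main() {
    let threshold = 0.0;
    // Only reads `threshold`, so this closure implements Fn, FnMut, and FnOnce.
    let classify = |v: f64| if v > threshold { 1 } else { 0 };

    // A reference to an Fn closure is itself callable, so `classify` can be reused.
    println!("{:?}", apply_twice(&classify, 0.5)); // prints (1, 1)
    println!("{}", apply_once(&classify, -1.0));   // prints 0
}
```

Going the other way does not work: a callable that only implements `FnOnce` cannot be passed where `Fn` is required, which is why implementing `FnOnce` alone is the natural first step.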

    First, this is unstable, so we need a nightly compiler and some feature flags:

    #![feature(unboxed_closures, fn_traits)]
    

    Then, let's do the impl for taking an f64:

    impl FnOnce<(f64,)> for Clf {
        type Output = i32;
        extern "rust-call" fn call_once(self, args: (f64,)) -> i32 {
            if args.0 > self.x {
                1
            } else {
                0
            }
        }
    }
    

    The arguments to the Fn family of traits are supplied via a tuple, which is why the type parameter is written (f64,): it's a tuple with just one element.

    This is all well and good, and we can now do c(0.5), although it will consume c until we implement the other traits.

    Now let's do the same thing for Vecs:

    impl FnOnce<(Vec<f64>,)> for Clf {
        type Output = Vec<i32>;
        extern "rust-call" fn call_once(self, args: (Vec<f64>,)) -> Vec<i32> {
            args.0
                .iter()
                .map(|&f| if f > self.x { 1 } else { 0 })
                .collect()
        }
    }
    

    Before Rust 1.33 nightly, you could not directly call c(v) or even c(0.5) (which worked before we added the second impl); we'd get an error about the type of the function not being known. Basically, those versions of Rust didn't support function overloading. But we could still call the functions using fully qualified syntax, where c(0.5) becomes FnOnce::call_once(c, (0.5,)).
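
What lets the two impls above coexist is that the argument type is a type parameter of the trait, so each argument type gets its own impl and its own Output. A minimal sketch of the same idea on stable Rust follows, assuming a hypothetical user-defined trait `Classify<Arg>` (not part of the standard library) in place of FnOnce, and method-call syntax in place of c(...):

```rust
// Hypothetical trait: generic over the argument type, with an associated
// output type, mirroring the shape of FnOnce<Args>.
trait Classify<Arg> {
    type Output;
    fn classify(&self, arg: Arg) -> Self::Output;
}

struct Clf {
    x: f64,
}

// One impl per argument type; the compiler picks the impl from the argument.
impl Classify<f64> for Clf {
    type Output = u32;
    fn classify(&self, val: f64) -> u32 {
        if val > self.x { 1 } else { 0 }
    }
}

impl Classify<Vec<f64>> for Clf {
    type Output = Vec<u32>;
    fn classify(&self, vals: Vec<f64>) -> Vec<u32> {
        vals.into_iter()
            .map(|v| if v > self.x { 1 } else { 0 })
            .collect()
    }
}

fn main() {
    let c = Clf { x: 0.0 };
    let v = vec![-1.0, 0.5, 1.0];
    println!("{}", c.classify(0.5)); // prints 1
    println!("{:?}", c.classify(v)); // prints [0, 1, 1]
}
```

Because the methods take &self, `c` is not consumed by either call, unlike the FnOnce version.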


    Not knowing your bigger picture, I would solve this simply by giving Clf two methods, like so:

    impl Clf {
        fn classify(&self, val: f64) -> u32 {
            if val > self.x {
                1
            } else {
                0
            }
        }
    
        fn classify_vec(&self, vals: Vec<f64>) -> Vec<u32> {
            vals.into_iter().map(|v| self.classify(v)).collect()
        }
    }
    

    Then your usage example becomes:

    let c = Clf { x: 0.0 };
    let v = vec![-1.0, 0.5, 1.0];
    println!("{}", c.classify(0.5));     // prints 1
    println!("{:?}", c.classify_vec(v)); // prints [0, 1, 1]
    

    I would actually rename the second function to classify_slice and have it take &[f64], which is a bit more general: you could still use it with Vecs by passing a reference, as in c.classify_slice(&v).
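
A minimal sketch of that slice-based variant, assuming the same Clf struct and classify method as above (classify_slice is the name suggested in this answer; the bodies are one possible implementation):

```rust
struct Clf {
    x: f64,
}

impl Clf {
    fn classify(&self, val: f64) -> u32 {
        if val > self.x { 1 } else { 0 }
    }

    // Borrows the input instead of taking ownership, so it accepts
    // slices, array references, and `&Vec<f64>` (via deref coercion).
    fn classify_slice(&self, vals: &[f64]) -> Vec<u32> {
        vals.iter().map(|&v| self.classify(v)).collect()
    }
}

fn main() {
    let c = Clf { x: 0.0 };
    let v = vec![-1.0, 0.5, 1.0];

    println!("{:?}", c.classify_slice(&v));           // prints [0, 1, 1]
    println!("{:?}", c.classify_slice(&[2.0, -3.0])); // prints [1, 0]

    // `v` was only borrowed, so it is still usable here.
    println!("{}", v.len()); // prints 3
}
```

Taking a slice also avoids forcing callers to give up or clone their Vec just to classify its contents.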