function c++11 lambda variadic-functions arity

C++11: how to accept user-supplied functions elegantly?


I am trying to let users of my library supply their own calculation functions, which must follow these rules:

  • they return a float;
  • they take between 0 and 7 arguments;
  • all arguments are floats.

I have a solution (see below), but it is bulky and ugly. I would bet there is a more elegant way to solve this with templates and type_traits, but I cannot wrap my head around it.

// Copyright 2022 bla 
#include <cstdint>
#include <iostream>
#include <functional>
#include <math.h>

using std::cout;
using std::endl;

class D {
public:
  using Dlambda0 = std::function<float()>;
  using Dlambda1 = std::function<float(float)>;
  using Dlambda2 = std::function<float(float, float)>;
  using Dlambda3 = std::function<float(float, float, float)>;
  using Dlambda4 = std::function<float(float, float, float, float)>;
  using Dlambda5 = std::function<float(float, float, float, float, float)>;
  using Dlambda6 = std::function<float(float, float, float, float, float, float)>;
  using Dlambda7 = std::function<float(float, float, float, float, float, float, float)>;

  D(const char *name, Dlambda0 dFn) : n(name), fn0 {std::move(dFn)} { cout << name << ":" << __PRETTY_FUNCTION__ << endl; arity = 0; }
  D(const char *name, Dlambda1 dFn) : n(name), fn1 {std::move(dFn)} { cout << name << ":" << __PRETTY_FUNCTION__ << endl; arity = 1; }
  D(const char *name, Dlambda2 dFn) : n(name), fn2 {std::move(dFn)} { cout << name << ":" << __PRETTY_FUNCTION__ << endl; arity = 2; }
  D(const char *name, Dlambda3 dFn) : n(name), fn3 {std::move(dFn)} { cout << name << ":" << __PRETTY_FUNCTION__ << endl; arity = 3; }
  D(const char *name, Dlambda4 dFn) : n(name), fn4 {std::move(dFn)} { cout << name << ":" << __PRETTY_FUNCTION__ << endl; arity = 4; }
  D(const char *name, Dlambda5 dFn) : n(name), fn5 {std::move(dFn)} { cout << name << ":" << __PRETTY_FUNCTION__ << endl; arity = 5; }
  D(const char *name, Dlambda6 dFn) : n(name), fn6 {std::move(dFn)} { cout << name << ":" << __PRETTY_FUNCTION__ << endl; arity = 6; }
  D(const char *name, Dlambda7 dFn) : n(name), fn7 {std::move(dFn)} { cout << name << ":" << __PRETTY_FUNCTION__ << endl; arity = 7; }

  float operator() () { return callForArity(0, nanf(""), nanf(""), nanf(""), nanf(""), nanf(""), nanf(""), nanf("")); }
  float operator() (float a) { return callForArity(1, a, nanf(""), nanf(""), nanf(""), nanf(""), nanf(""), nanf("")); }
  float operator() (float a, float b) { return callForArity(2, a, b, nanf(""), nanf(""), nanf(""), nanf(""), nanf("")); }
  float operator() (float a, float b, float c) { return callForArity(3, a, b, c, nanf(""), nanf(""), nanf(""), nanf("")); }
  float operator() (float a, float b, float c, float d) { return callForArity(4, a, b, c, d, nanf(""), nanf(""), nanf("")); }
  float operator() (float a, float b, float c, float d, float e) { return callForArity(5, a, b, c, d, e, nanf(""), nanf("")); }
  float operator() (float a, float b, float c, float d, float e, float f) { return callForArity(6, a, b, c, d, e, f, nanf("")); }
  float operator() (float a, float b, float c, float d, float e, float f, float g) { return callForArity(7, a, b, c, d, e, f, g); }

protected:
  const char *n;
  Dlambda0 fn0;
  Dlambda1 fn1;
  Dlambda2 fn2;
  Dlambda3 fn3;
  Dlambda4 fn4;
  Dlambda5 fn5;
  Dlambda6 fn6;
  Dlambda7 fn7;
  uint8_t arity;

  float callForArity(uint8_t givenArgs, float a, float b, float c, float d, float e, float f, float g) {
    switch (arity) {
    case 0:
      return fn0();
      break;
    case 1:
      return fn1(a);
      break;
    case 2:
      return fn2(a, b);
      break;
    case 3:
      return fn3(a, b, c);
      break;
    case 4:
      return fn4(a, b, c, d);
      break;
    case 5:
      return fn5(a, b, c, d, e);
      break;
    case 6:
      return fn6(a, b, c, d, e, f);
      break;
    case 7:
      return fn7(a, b, c, d, e, f, g);
      break;
    }
    return nanf("");
  }
};

float fix1(float a) {
  return a * a;
}

int main(int argc, char **argv) {

  // register some functions
  D function("bla", { []() { return 5; }});
  D function2("blubb", { [](float x, float y) { return x + y; }});
  D function3("Blort", fix1);
  
  cout << function() << endl;                     // Okay
  cout << function2(1.0, 5.5) << endl;            // Okay
  cout << function2(5.5) << endl;                 // Aw! compiles, but silently returns NaN
  cout << function3(5.5) << endl;                 // okay

  return 0;
}

Solution

  • You have exactly the right idea with sizeof... and static_assert. We can have D take a parameter pack and build the function type from that.

    template <typename... Ts>
    class D {
    public:
      using function_type = std::function<float(Ts...)>;
      // ... constructor and call operator go here
    };
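    For reference, here is one way the elided body might be filled in. It is a sketch that assumes D keeps the question's const char * name and simply forwards its arguments to the stored std::function; the member names n and fn and the name() accessor are illustrative, not part of the original answer.

```cpp
#include <cassert>
#include <functional>
#include <utility>

// Sketch of the elided class body; n, fn and name() are illustrative names.
template <typename... Ts>
class D {
public:
  using function_type = std::function<float(Ts...)>;

  // Any lambda, functor or function pointer callable as float(Ts...)
  // converts implicitly to function_type.
  D(const char *name, function_type dFn) : n(name), fn(std::move(dFn)) {}

  // A single call operator whose parameter list is exactly Ts..., so a call
  // with the wrong number of arguments does not compile.
  float operator()(Ts... args) const { return fn(args...); }

  const char *name() const { return n; }

private:
  const char *n;
  function_type fn;
};
```

Because operator() is declared with exactly the parameter list Ts..., there is no shorter overload to fall back on, which is what turns the question's silent NaN into a compile error.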
    

    Simple enough. Now the assertions, which go straight into the class body. It sounds like you already figured out most of the sizeof... part.

    static_assert(sizeof...(Ts) <= 7, "Too many arguments");
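    A quick way to convince yourself that sizeof...(Ts) counts the types in the pack (it has nothing to do with byte sizes) and is usable at compile time. Arity is a hypothetical helper for illustration only, separate from D:

```cpp
#include <cstddef>

// Hypothetical helper: sizeof...(Ts) is the number of types in the pack,
// and it is a constant expression, usable both in static_assert and as a value.
template <typename... Ts>
struct Arity {
  static_assert(sizeof...(Ts) <= 7, "Too many arguments");
  static constexpr std::size_t value = sizeof...(Ts);
};
```

An empty pack gives 0, so a zero-argument function passes the same check.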
    

    Now, for the float requirement, it sounds like you want every type in Ts to be exactly float (i.e. std::is_same). If "convertible to float" is good enough for your use case, consider replacing is_same with is_convertible in these examples (both live in <type_traits>). You tagged C++11, but if C++17 is available, this assertion is very easy to write with a fold expression.

    static_assert((std::is_same<Ts, float>::value && ...), "All args must be floats");
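    A sketch of the same check pulled out into reusable variable templates, assuming C++17; all_float_v and all_float_like_v are hypothetical names. The second one is the looser is_convertible variant mentioned above. A && fold over an empty pack evaluates to true, so a zero-argument D<> passes.

```cpp
#include <type_traits>

// Requires C++17 (fold expressions). Hypothetical name.
// Unary right fold over &&; an empty pack yields true.
template <typename... Ts>
constexpr bool all_float_v = (std::is_same<Ts, float>::value && ...);

// Looser variant: accept any type implicitly convertible to float.
template <typename... Ts>
constexpr bool all_float_like_v = (std::is_convertible<Ts, float>::value && ...);
```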
    

    If you don't have C++17, we can use @Columbo's excellent all_true trick.

    template <bool...>
    struct bool_pack;
    
    template <bool... v>
    using all_true = std::is_same<bool_pack<true, v...>, bool_pack<v..., true>>;
    
    static_assert(all_true<std::is_same<Ts, float>::value...>::value, "All args must be floats");
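    To see why the all_true trick works: the two bool_pack types are identical only if true == v1, v1 == v2, ..., vn == true, which forces every value to be true. A self-contained C++11 sketch, where all_float is a hypothetical convenience wrapper:

```cpp
#include <type_traits>

// bool_pack is only declared, never defined: it is only compared as a type.
template <bool...>
struct bool_pack;

// Same type only when true == v1, v1 == v2, ..., vn == true,
// i.e. when every value in v... is true (an empty pack also counts).
template <bool... v>
using all_true = std::is_same<bool_pack<true, v...>, bool_pack<v..., true>>;

// Hypothetical wrapper for the "all arguments are floats" rule.
template <typename... Ts>
struct all_float : all_true<std::is_same<Ts, float>::value...> {};
```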
    

    Now we can construct functions by listing the float parameters they take:

    D<> function("bla", []() { return 5; });
    D<float, float> function2("blubb", [](float x, float y) { return x + y; });
    D<float> function3("Blort", fix1);
    

    and we can call them:

    std::cout << function() << std::endl;
    std::cout << function2(1.0, 5.5) << std::endl;
    std::cout << function3(5.5) << std::endl;
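    Putting the pieces together, a C++11 sketch of the whole class might look like the following. It assumes the question's constructor signature (a name plus the callable) and reuses fix1 from the question; the name() accessor is an illustrative addition.

```cpp
#include <cassert>
#include <functional>
#include <type_traits>
#include <utility>

// Columbo's all_true trick (C++11).
template <bool...>
struct bool_pack;

template <bool... v>
using all_true = std::is_same<bool_pack<true, v...>, bool_pack<v..., true>>;

template <typename... Ts>
class D {
  // Checked when D<...> is instantiated.
  static_assert(sizeof...(Ts) <= 7, "Too many arguments");
  static_assert(all_true<std::is_same<Ts, float>::value...>::value,
                "All args must be floats");

public:
  using function_type = std::function<float(Ts...)>;

  D(const char *name, function_type dFn) : n(name), fn(std::move(dFn)) {}

  // Exactly one call signature: wrong argument counts fail to compile.
  float operator()(Ts... args) const { return fn(args...); }

  const char *name() const { return n; }

private:
  const char *n;
  function_type fn;
};

// Free function from the question.
float fix1(float a) { return a * a; }
```

With this class, function3(5.5) converts 5.5 to float and evaluates fix1, giving 30.25f.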
    

    If we try function2(5.5), we get a compile error, and by C++ template standards not a bad error message:

    so_vv.cpp: In function ‘int main()’:
    so_vv.cpp:47:29: error: no match for call to ‘(D<float, float>) (double)’
       47 |   std::cout << function2(5.5) << std::endl;               // Aw!
          |                             ^
    so_vv.cpp:30:9: note: candidate: ‘float D<Ts>::operator()(Ts ...) [with Ts = {float, float}]’
       30 |   float operator()(Ts... args) {
          |         ^~~~~~~~
    so_vv.cpp:30:9: note:   candidate expects 2 arguments, 1 provided
    

    The other assertions catch invalid instantiations too:

    D<int> not_allowed("foobar", [](int x) { return 10; });
    D<float, float, float, float, float, float, float, float, float> not_allowed2("foobar", [](...) { return 10; });
    

    The first fails the "All args must be floats" check and the second fails the "Too many arguments" check, both at compile time.

    If you don't like the explicit D<float, float, float> template arguments, you can get it down to D<3> with some clever use of std::integer_sequence, although that needs C++14. As for whether C++ can infer the argument count entirely on its own, I'm not sure; I didn't have much luck getting type inference to go that far, but you might have better luck.
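    A minimal sketch of the D<3> idea, assuming C++14 (std::make_index_sequence). The helper always and the name DImpl are hypothetical; each index in 0, 1, ..., N-1 is mapped to float, so D<N> ends up with N float parameters. Because the parameter types are generated, the all-floats assertion is no longer needed, only the count check.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <utility>

// Hypothetical helper: maps any index to T, ignoring the index itself.
template <std::size_t, typename T>
struct always {
  using type = T;
};

// Primary template is only declared; the partial specialization unpacks
// the index sequence 0, 1, ..., N-1 into N float parameters.
template <typename Seq>
class DImpl;

template <std::size_t... Is>
class DImpl<std::index_sequence<Is...>> {
  static_assert(sizeof...(Is) <= 7, "Too many arguments");

public:
  using function_type = std::function<float(typename always<Is, float>::type...)>;

  DImpl(const char *name, function_type dFn) : n(name), fn(std::move(dFn)) {}

  float operator()(typename always<Is, float>::type... args) const {
    return fn(args...);
  }

  const char *name() const { return n; }

private:
  const char *n;
  function_type fn;
};

// D<3> is shorthand for a function of three floats.
template <std::size_t N>
using D = DImpl<std::make_index_sequence<N>>;
```

For example, D<2> add("add", [](float x, float y) { return x + y; }); behaves like D<float, float> above. This only shortens the spelling of the type; the argument count still has to be written by hand.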
