
Is a different template parameter required for return type deduction in CRTP?


I have a general understanding of how the following code works, but I am not sure whether parts of it are redundant or whether I am just not getting why it's done this way. This example is taken from here: https://github.com/pytorch/pytorch/blob/31f311a816c026bbfca622d6121d6a7fab44260d/torch/csrc/autograd/custom_function.h#L96

Consider the following CRTP scenario: I want to instantiate various functions whose calls are tracked elsewhere (e.g. capturing the parameters to build a computation graph):

#include <iostream>
#include <utility>  // std::declval, std::forward

template <typename X, typename... Args>
using forward_t = decltype(X::forward(std::declval<Args>()...));

template <typename T>
struct Function {
    template <typename... Args>
    static auto apply(Args &&...args) {
        using forward_return_t = forward_t<T, Args...>;
        // Do something with params before calling forward ...
        forward_return_t output = T::forward(std::forward<Args>(args)...);
        return output;
    }
};

struct Addition : public Function<Addition> {
    static int forward(int lhs, int rhs) {
        return lhs + rhs;
    }
};

struct Subtraction : public Function<Subtraction> {
    static double forward(double lhs, double rhs) {
        return lhs - rhs;
    }
};

int main() {
    auto add_value = Addition::apply(10, 15);
    std::cout << add_value << std::endl;    // 25
    auto sub_value = Subtraction::apply(10.5, 15.0);
    std::cout << sub_value << std::endl;    // -4.5
}

This compiles (-std=c++14/17/20) and works as expected. However, the signature in the libtorch codebase is as follows:

template <class T>
struct Function {
  // We need to use a different template parameter than T here because T will
  // inherit from Function, and when Function<T> is instantiated, T::forward
  // is not declared yet.
  // The enable_if check is to ensure that the user doesn't explicitly provide
  // the parameter X.
  template <typename X = T, typename... Args>
  static auto apply(Args&&... args)
      -> std::enable_if_t<std::is_same<X, T>::value, forward_t<X, Args...>>;
};

The reasoning (from my understanding) is that we need the type T to compute forward_t, so it's part of the template type arguments in front.

So my questions are: 1) Is the added typename X = T redundant, given that my code works fine without it? 2) Why is the std::enable_if check there, if forward_t would already fail to compile without it? I can't think of a scenario where either is required, but I'm sure the extra code wasn't added for no reason.

Thanks!


Solution

  • Thanks to @igor-tandetnik, the question was answered in the comments, but I'll close it out here.

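    Both approaches are fine: it comes down to whether the return type has to be spelled out in a declaration. libtorch declares apply with an explicit trailing return type, and that declaration is instantiated as soon as Function<T> is instantiated, i.e. at T's base-class list, before T::forward has been declared. Using the member template's own parameter X instead of T keeps forward_t<X, Args...> dependent, so it is not formed until apply is actually called and T is complete. In my version the return type is deduced from the body (plain auto, no trailing return type), and the body is only instantiated when apply is called, so the extra parameter is not needed. The std::enable_if answers the second question: once X is an ordinary template parameter, a caller could supply it explicitly (e.g. Addition::apply<Subtraction>(...)) and get a return type computed from the wrong class's forward; as the libtorch comment says, the check ensures the user doesn't explicitly provide an X other than T.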