Tags: c++, templates, cuda, mixed-mode

How do I pass a template type argument from a .cpp file to a .cu file?


I found that CUDA supports the `template` keyword, and now I would like to use templates across code compiled by nvcc and code compiled by g++. I could not find a proper way to do that, so I pass the name of the data type as a string to select the type. Is there a better way to do it?

//in .cpp

// Entry point implemented in the .cu file: the element type is passed by
// name and dispatched at run time. `string` is presumably std::string; the
// #include/using that bring it into scope are not shown.
// NOTE(review): extern "C" only suppresses name mangling here. std::string is
// not a C type, so this relies on both sides being compiled as C++.
extern "C" void function(string T);

int main(){
    // Request the float instantiation by its type name.
    function("float");
}

//in .cu

// Run-time dispatch from a type name to the matching func<T> instantiation.
// Every call pays for string comparisons. In the visible branches, a name
// that matches nothing simply falls through without calling anything.
extern "C" void function(string T){
    // Independent ifs (not else-if): later comparisons still run after a match.
    if(T == "short")
        func<short>(...);
    if(T == "int")
        func<int>(...);
    // Remaining types are elided in the original.
    .......
}

Solution

  • C has no function overloading, but you can emulate it the C way: give each supported type its own `extern "C"` function name.

    // in .cpp
    
    // One C-linkage entry point per supported element type. Each is defined in
    // the .cu file and forwards to the corresponding func<T> instantiation.
    extern "C" {
      void func_short();
      void func_int();
      void func_float();
    }

    int main() {
      // The element type is selected at compile time by choosing the name.
      func_float();
      return 0;
    }
    
    // in .cu
    
    // Implementation template. It is used only inside this translation unit.
    // The anonymous namespace gives it internal linkage. Without that, if the
    // C++ side also defines a function template named `func` (for example, a
    // wrapper whose func<float>() calls func_float()), both files would emit
    // the same external symbol for func<float>(). That violates the ODR, and
    // the linker may bind the call below to the wrapper, which calls straight
    // back into func_float().
    namespace {
    template <typename T>
    void func() {
      // ...
    }
    }  // namespace

    // C-linkage entry points: one unmangled name per supported type, so code
    // built by another compiler can reach each instantiation without seeing
    // the template.
    extern "C" void func_short() {
      func<short>();
    }

    extern "C" void func_int() {
      func<int>();
    }

    extern "C" void func_float() {
      func<float>();
    }
    

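    This avoids comparing strings on every call, because the type is selected at compile time by the function name. If you want, you can create a wrapper template function on the C++ side to make the usage a bit cleaner. Because that wrapper is also named `func`, give the template in the .cu file internal linkage (wrap it in an anonymous namespace) or a different name. Otherwise both files define `func<float>()` with external linkage, which violates the one-definition rule. The linker may then bind the .cu call to the wrapper, which calls straight back into `func_float()`.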

    // in .cpp
    
    // C-linkage entry points, defined in the .cu file.
    extern "C" {
      void func_short();
      void func_int();
      void func_float();
    }

    // Type-safe front end: func<T>() forwards to the C entry point for T.
    // The primary template is deleted, so asking for an unsupported type is a
    // compile-time error rather than a link-time one.
    // NOTE: the .cu file must not also define an externally visible template
    // named `func` (give it internal linkage or another name). Otherwise both
    // files provide func<float>() and the program violates the ODR.
    template <typename T>
    void func() = delete;

    template <>
    void func<short>() { func_short(); }

    template <>
    void func<int>() { func_int(); }

    template <>
    void func<float>() { func_float(); }

    int main() {
      func<float>();
      return 0;
    }
    

    To make maintenance a little easier, you could define some macros.

    // in .cpp
    
    // Primary template is deleted: only types registered below are usable.
    template <typename T>
    void func() = delete;

    // Declares the C-linkage entry point func_<TYPE> (defined in the .cu file)
    // and a func<TYPE> specialization that forwards to it. TYPE must be a
    // single identifier, because it is token-pasted into the function name
    // (use a typedef for e.g. `unsigned int`). The final line has no trailing
    // backslash, so the definition ends here instead of relying on the
    // following line being blank.
    #define DECLARE_FUNC(TYPE)                                                      \
      extern "C" void func_##TYPE();                                                \
      template <>                                                                   \
      void func<TYPE>() {                                                           \
        func_##TYPE();                                                              \
      }

    DECLARE_FUNC(short)
    DECLARE_FUNC(int)
    DECLARE_FUNC(float)

    // Keep the helper macro from leaking into later code in this file.
    #undef DECLARE_FUNC

    int main(){
      func<float>();
    }
    
    //in .cu
    
    // Implementation template with internal linkage (anonymous namespace). This
    // prevents a collision with a wrapper template also named `func` on the C++
    // side. Otherwise both files would emit func<float>() with external linkage:
    // an ODR violation that the linker resolves silently, possibly to the
    // wrapper, which calls back into func_float().
    namespace {
    template <typename T>
    void func() {
      // ...
    }
    }  // namespace

    // Defines the C-linkage entry point func_<TYPE>, which forwards to
    // func<TYPE>. TYPE must be a single identifier (it is token-pasted).
    #define DECLARE_FUNC(TYPE)                                                      \
      extern "C" void func_##TYPE() {                                               \
        func<TYPE>();                                                               \
      }

    DECLARE_FUNC(short)
    DECLARE_FUNC(int)
    DECLARE_FUNC(float)

    // Keep the helper macro from leaking into later code in this file.
    #undef DECLARE_FUNC
    

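    You could put those DECLARE_FUNC lines in a common header (the "X-macro" technique), so that you only have to update the list in one place. To add a double function, just add DECLARE_FUNC(double) to the header. Because the type name is token-pasted into the function name, each entry must be a single identifier; for a type such as `unsigned int`, add a typedef first.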

    // in declare_func.hpp
    
    // X-macro list of supported element types. The includer must define
    // DECLARE_FUNC(TYPE) before including this file; each line below then
    // expands once per type. There is no include guard, so the list can be
    // expanded again with a different DECLARE_FUNC definition. To support
    // another type, add a line here, e.g. DECLARE_FUNC(double).
    DECLARE_FUNC(short)
    DECLARE_FUNC(int)
    DECLARE_FUNC(float)
    
    // in .cpp
    
    // Primary template is deleted: only types listed in declare_func.hpp are usable.
    template <typename T>
    void func() = delete;

    // Declares the C-linkage entry point func_<TYPE> (defined in the .cu file)
    // and a func<TYPE> specialization that forwards to it. The final line has
    // no trailing backslash. With one, the macro would continue onto the next
    // line and would swallow the #include below if the blank line were removed.
    #define DECLARE_FUNC(TYPE)                                                      \
      extern "C" void func_##TYPE();                                                \
      template <>                                                                   \
      void func<TYPE>() {                                                           \
        func_##TYPE();                                                              \
      }

    #include "declare_func.hpp"

    // Keep the helper macro from leaking into later code in this file.
    #undef DECLARE_FUNC

    int main(){
      func<float>();
    }
    
    //in .cu
    
    // Implementation template with internal linkage (anonymous namespace). This
    // prevents a collision with the wrapper template also named `func` on the
    // C++ side. Otherwise both files would emit func<float>() with external
    // linkage: an ODR violation that the linker resolves silently, possibly to
    // the wrapper, which calls back into func_float().
    namespace {
    template <typename T>
    void func() {
      // ...
    }
    }  // namespace

    // Defines the C-linkage entry point func_<TYPE>, which forwards to func<TYPE>.
    #define DECLARE_FUNC(TYPE)                                                      \
      extern "C" void func_##TYPE() {                                               \
        func<TYPE>();                                                               \
      }

    #include "declare_func.hpp"

    // Keep the helper macro from leaking into later code in this file.
    #undef DECLARE_FUNC
    

    These options progress from easy to set up to easy to maintain; choose whichever suits your situation.