c++ templates cublas

How to use CUBLAS library within a template function?


CUBLAS has a separate function for each type of data, but I want to call CUBLAS from within a template, e.g.:

template <typename T> void foo(...) {
    ...
    cublas<S/D/C/Z>geam(..., const T* A, ...);
    ...
}

How do I trigger the correct function call?


Solution

  • I wrote cuBLAS wrapper functions for the different types, all with the same function name.

    #include <cublas_v2.h>

    // Same name for every precision. The argument order follows cublas<t>geam:
    // alpha, A, lda, beta, B, ldb, C, ldc.
    inline cublasStatus_t cublasGgeam(cublasHandle_t handle,
            cublasOperation_t transa, cublasOperation_t transb,
            int m, int n,
            const float *alpha,
            const float *A, int lda,
            const float *beta,
            const float *B, int ldb,
            float *C, int ldc)
    {
        // float -> single-precision routine
        return cublasSgeam(handle, transa, transb, m, n, alpha, A, lda, beta, B, ldb, C, ldc);
    }
    
    inline cublasStatus_t cublasGgeam(cublasHandle_t handle,
            cublasOperation_t transa, cublasOperation_t transb,
            int m, int n,
            const double *alpha,
            const double *A, int lda,
            const double *beta,
            const double *B, int ldb,
            double *C, int ldc)
    {
        // double -> double-precision routine
        return cublasDgeam(handle, transa, transb, m, n, alpha, A, lda, beta, B, ldb, C, ldc);
    }
    

    After that, you can call cublasGgeam() for any supported type using the same function name; the C++ compiler chooses the right overload from the parameter types. In your case it would look like this:

    template <typename T> void foo(...) {
        ...
        cublasGgeam(..., A, ...);
        ...
    }
    

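    This is compile-time overload resolution, so there is no runtime cost at all, although you have to write a long list of wrapper functions: one per routine and element type, including cuComplex/cuDoubleComplex overloads forwarding to cublasCgeam/cublasZgeam if you need the complex cases.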