c++ cuda cublas

Using cublas<T>gemmStridedBatched with row-major matrices


I'm trying to use cublasSgemmStridedBatched in C++ to compute batched matrix multiplications between two sets of matrices (inputs x1 and x2), but I am struggling to match the expected output. I cannot figure out what the parameters should be. I must have tried every possible combination.

Here is my setup (both inputs are stored in row-major format):

x1 shape: (B, 1, M) — batch size B, 1 row, M columns.

x2 shape: (B, R, M) — batch size B, R rows, M columns.

I want to compute x1.(x2)^T for each batch, which gives a (1, R) result per batch.

Example inputs:

B = 2, M = 3, R = 2

x1 = [ [1 2 1],
       [0 0 1] ]

x2 = [[[1 1 1],
       [2 2 1]],
      [[0 1 2],
       [1 1 1]]]

expected output = [ [4 7], [2 1]]

Using this call:

    cublasSgemmStridedBatched(
        handle,
        CUBLAS_OP_T, CUBLAS_OP_N,
        M, 1, M,
        &alpha,
        x2, M, strideB,
        x1, M, strideA,
        &beta,
        output, M, strideC,
        B
    )

with:

strideA = M

strideB = R * M

strideC = R

alpha = 1.0f

beta = 0.0f

gives me [[4, 7], [4, 1]].


Solution

  • Let's look at the documentation. The declaration of the function looks like this:

    cublasStatus_t cublasSgemmStridedBatched(cublasHandle_t handle,
                                      cublasOperation_t transa,
                                      cublasOperation_t transb,
                                      int m, int n, int k,
                                      const float           *alpha,
                                      const float           *A, int lda,
                                      long long int          strideA,
                                      const float           *B, int ldb,
                                      long long int          strideB,
                                      const float           *beta,
                                      float                 *C, int ldc,
                                      long long int          strideC,
                                      int batchCount);
    

    The dimension parameters m, n, k mean that op(A) is m x k, op(B) is k x n and C is m x n, where op(A) and op(B) are A and B after applying transa and transb respectively. Unless you use submatrices, lda should be m if transa is CUBLAS_OP_N and k if it is CUBLAS_OP_T; ldb should be k if transb is CUBLAS_OP_N and n if it is CUBLAS_OP_T; and ldc should be m.

    Since your matrices are stored row-major and cuBLAS expects column-major order, each one is interpreted as its transpose (given consistent dimension and ld arguments). So, from the point of view of the function, you have an M x 1 matrix (x1) and an M x R matrix (x2) per batch, each with leading dimension M.

    Your call matches the second option below, except that it passes M rather than R for m and ldc. With M > R, as in your example, each batch then computes M values instead of R, reading past the end of its x2 slice and writing into the next batch's part of output.

    Now, there are 2 options:

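    // Option 1: C = x1 * x2^T, a 1 x R result per batch
    cublasSgemmStridedBatched(
        handle,
        CUBLAS_OP_T, CUBLAS_OP_N,
        1, R, M,            // m, n, k
        &alpha,
        x1, M, M,           // A, lda, strideA
        x2, M, R * M,       // B, ldb, strideB
        &beta,
        output, 1, R,       // C, ldc, strideC
        B                   // batchCount
    );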
    
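    // Option 2: C = x2 * x1^T, an R x 1 result per batch
    cublasSgemmStridedBatched(
        handle,
        CUBLAS_OP_T, CUBLAS_OP_N,
        R, 1, M,            // m, n, k
        &alpha,
        x2, M, R * M,       // A, lda, strideA
        x1, M, M,           // B, ldb, strideB
        &beta,
        output, R, R,       // C, ldc, strideC
        B                   // batchCount
    );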
    

    Note that the results of these two options are, logically, transposes of one another (1 x R versus R x 1), but they are stored the same way because one of the dimensions is 1, in which case the row-major and column-major representations also coincide. Either way, output holds B * R contiguous floats, [4 7 2 1] for your example. How you interpret them in further computations is up to you (e.g. through the corresponding dimension parameters of later cuBLAS calls).