I'm trying to use cublasSgemmStridedBatched in C++ to compute batched matrix multiplications between two sets of matrices (inputs x1 and x2), but I am struggling to match the expected output. I cannot figure out what the parameters should be. I must have tried every possible combination.
Here is my setup (both are stored in row-major format):
x1 shape: (B, 1, M), i.e. batch size B, 1 row, M columns.
x2 shape: (B, R, M), i.e. batch size B, R rows, M columns.
I want to compute x1 * x2^T for each batch.
Example inputs:
B = 2, M = 3, R = 2
x1 = [ [1 2 1],
       [0 0 1] ]
x2 = [ [ [1 1 1],
         [2 2 1] ],
       [ [0 1 2],
         [1 1 1] ] ]
expected output = [ [4 7], [2 1] ]
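For reference, the intended result can be pinned down with a plain CPU loop over the row-major buffers. This is a minimal sketch: the function name batched_x1_x2T is hypothetical, and the buffers are assumed to be tightly packed in the (B, 1, M) and (B, R, M) layouts described above.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical row-major reference for the shapes above:
// x1 is (B, 1, M), x2 is (B, R, M), the result is (B, 1, R) with
// out[b][r] = sum over m of x1[b][0][m] * x2[b][r][m]  (x1 * x2^T per batch).
std::vector<float> batched_x1_x2T(const std::vector<float>& x1,
                                  const std::vector<float>& x2,
                                  int B, int R, int M) {
    // Assumes tightly packed buffers with exactly these sizes.
    assert(x1.size() == static_cast<std::size_t>(B) * M);
    assert(x2.size() == static_cast<std::size_t>(B) * R * M);
    std::vector<float> out(static_cast<std::size_t>(B) * R, 0.0f);
    for (int b = 0; b < B; ++b) {
        for (int r = 0; r < R; ++r) {
            float acc = 0.0f;
            for (int m = 0; m < M; ++m) {
                // Row r of batch b in x2 starts at (b * R + r) * M.
                acc += x1[b * M + m] * x2[(b * R + r) * M + m];
            }
            out[b * R + r] = acc;
        }
    }
    return out;
}
```

With the example inputs flattened to x1 = {1, 2, 1, 0, 0, 1} and x2 = {1, 1, 1, 2, 2, 1, 0, 1, 2, 1, 1, 1}, batched_x1_x2T(x1, x2, 2, 2, 3) returns {4, 7, 2, 1}, i.e. the expected [[4 7], [2 1]].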
Using this call:
cublasSgemmStridedBatched(
    handle,
    CUBLAS_OP_T, CUBLAS_OP_N,
    M, 1, M,
    &alpha,
    x2, M, strideB,
    x1, M, strideA,
    &beta,
    output, M, strideC,
    B
);
with:
strideA = M
strideB = R * M
strideC = R
alpha = 1.0f
beta = 0.0f
gives me [[4, 7], [4, 1]].
Let's look at the documentation. The declaration of the function looks like this:
cublasStatus_t cublasSgemmStridedBatched(cublasHandle_t handle,
cublasOperation_t transa,
cublasOperation_t transb,
int m, int n, int k,
const float *alpha,
const float *A, int lda,
long long int strideA,
const float *B, int ldb,
long long int strideB,
const float *beta,
float *C, int ldc,
long long int strideC,
int batchCount);
The meaning of the dimension parameters m, n, k is that opA(A) is m x k, opB(B) is k x n, and C is m x n (where opA and opB are the operations selected by the transa and transb arguments, respectively). Unless you use submatrices, lda should be m or k (if opA isn't or is a transpose, respectively), ldb should be k or n (if opB isn't or is a transpose, respectively), and ldc should be m.

Since your matrices are stored row-major and cuBLAS expects column-major order, they are interpreted as transposed (assuming correct dimension and ld arguments). So, from the point of view of the function, you have an M x 1 matrix (x1) and an M x R matrix (x2). Now, there are 2 options:
1. x1^T * x2 to get a 1 x R result (which corresponds to an R x 1 matrix in row-major order). Here m = 1, n = R, k = M, lda = M, ldb = M, ldc = 1:

cublasSgemmStridedBatched(
    handle,
    CUBLAS_OP_T, CUBLAS_OP_N,
    1, R, M,              // m, n, k
    &alpha,
    x1, M, M,             // A = x1, lda, strideA
    x2, M, R * M,         // B = x2, ldb, strideB
    &beta,
    output, 1, R,         // C = output, ldc, strideC
    B                     // batchCount
);
2. x2^T * x1 to get an R x 1 result (which corresponds to a 1 x R matrix in row-major order). Here m = R, n = 1, k = M, lda = M, ldb = M, ldc = R. This seems to be what you tried to do, but with the wrong m and ldc (you passed M for both):

cublasSgemmStridedBatched(
    handle,
    CUBLAS_OP_T, CUBLAS_OP_N,
    R, 1, M,              // m, n, k
    &alpha,
    x2, M, R * M,         // A = x2, lda, strideA
    x1, M, M,             // B = x1, ldb, strideB
    &beta,
    output, R, R,         // C = output, ldc, strideC
    B                     // batchCount
);
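Both options can be checked without a GPU by emulating the documented column-major semantics on the host. The sketch below is a stand-in, not cuBLAS: ref_sgemm_strided_batched and the Op enum are hypothetical names, alpha and beta are passed by value rather than by pointer, and the leading-dimension checks mirror the rules described above.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical host-side emulation of cublasSgemmStridedBatched (not part of cuBLAS).
// For each batch i: C_i = alpha * op(A_i) * op(B_i) + beta * C_i, where op(A) is m x k,
// op(B) is k x n, C is m x n, and every matrix is column-major: element (row, col)
// of a matrix with leading dimension ld sits at offset row + col * ld.
enum class Op { N, T };

void ref_sgemm_strided_batched(Op transa, Op transb, int m, int n, int k,
                               float alpha,
                               const float* A, int lda, long long strideA,
                               const float* B, int ldb, long long strideB,
                               float beta,
                               float* C, int ldc, long long strideC,
                               int batchCount) {
    // Stored shapes: A is m x k (N) or k x m (T); B is k x n (N) or n x k (T); C is m x n.
    assert(lda >= std::max(1, transa == Op::N ? m : k));
    assert(ldb >= std::max(1, transb == Op::N ? k : n));
    assert(ldc >= std::max(1, m));
    for (int i = 0; i < batchCount; ++i) {
        const float* Ai = A + i * strideA;
        const float* Bi = B + i * strideB;
        float* Ci = C + i * strideC;
        for (int col = 0; col < n; ++col) {
            for (int row = 0; row < m; ++row) {
                float acc = 0.0f;
                for (int p = 0; p < k; ++p) {
                    // op(A)(row, p) and op(B)(p, col), reading the stored layout.
                    float a = (transa == Op::N) ? Ai[row + p * lda] : Ai[p + row * lda];
                    float b = (transb == Op::N) ? Bi[p + col * ldb] : Bi[col + p * ldb];
                    acc += a * b;
                }
                float& c = Ci[row + col * ldc];
                // When beta == 0, C is not read, so it may start out uninitialized.
                c = (beta == 0.0f) ? alpha * acc : alpha * acc + beta * c;
            }
        }
    }
}
```

Called on the example inputs with the option 1 arguments (Op::T, Op::N, m = 1, n = R, k = M, A = x1, B = x2, ldc = 1, strideC = R) or the option 2 arguments (Op::T, Op::N, m = R, n = 1, k = M, A = x2, B = x1, ldc = R, strideC = R), it fills the output with {4, 7, 2, 1} in both cases, matching the expected result.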
Note that the results of these two options are, logically, transposes of one another, but they are stored the same way because one of the dimensions is 1 (and row/column-major representations in this case are also the same). How you interpret them in further computations is up to you (e.g. corresponding dimension parameters for cublas functions).
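To see concretely why the original call's m = M and ldc = M go wrong, one can compute how far each batch reaches into each buffer. The following is a hypothetical diagnostic (footprint, matrix_span, Footprint and Op are illustrative names, not cuBLAS APIs); it assumes tightly packed buffers of B*R*M floats for x2, B*M for x1 and B*R for the output.

```cpp
#include <cassert>

// Hypothetical diagnostic, not a cuBLAS API: from the dimension, leading-dimension
// and stride arguments of a strided-batched GEMM call, compute how far into each
// buffer the call reaches and whether the output ranges of consecutive batches overlap.
enum class Op { N, T };

// Number of elements spanned by one column-major rows x cols matrix with leading dimension ld.
long long matrix_span(int rows, int cols, int ld) {
    return (rows - 1) + static_cast<long long>(cols - 1) * ld + 1;
}

struct Footprint {
    long long a_end;         // one past the last element of A that is read
    long long b_end;         // one past the last element of B that is read
    long long c_end;         // one past the last element of C that is written
    bool c_batches_overlap;  // true if batch i+1's output range starts inside batch i's
};

Footprint footprint(Op transa, Op transb, int m, int n, int k,
                    int lda, long long strideA, int ldb, long long strideB,
                    int ldc, long long strideC, int batchCount) {
    assert(m > 0 && n > 0 && k > 0 && batchCount > 0);
    // Stored shapes: A is m x k (N) or k x m (T); B is k x n (N) or n x k (T); C is m x n.
    long long a_span = (transa == Op::N) ? matrix_span(m, k, lda) : matrix_span(k, m, lda);
    long long b_span = (transb == Op::N) ? matrix_span(k, n, ldb) : matrix_span(n, k, ldb);
    long long c_span = matrix_span(m, n, ldc);
    long long last = batchCount - 1;  // index of the final batch
    return {last * strideA + a_span, last * strideB + b_span,
            last * strideC + c_span, batchCount > 1 && c_span > strideC};
}
```

For the question's call with B = 2, M = 3, R = 2, each batch reads a 3 x 3 block of x2 (9 floats) while its stride is only 6, so x2 is read up to index 14 although it holds 12 floats; the output is written up to index 4 although it presumably holds B*R = 4 floats; and the output ranges of the two batches overlap. That is consistent with the corrupted second row in the reported result. For option 2 (m = R, ldc = R) the reads end exactly at 12 and 6 elements, the writes end at 4, and the batches do not overlap.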