
How are FP6 and FP4 supported on NVIDIA Blackwell Tensor Cores?


I am writing inline PTX assembly in CUDA C++ for research. This is my setup:

I am trying to write PTX code for the Tensor Cores on Blackwell. The goal is to perform matrix multiplication with multiplicands in FP8 (E4M3, E5M2), FP6 (E3M2, E2M3), and FP4 (E2M1).
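For reference, these five formats differ only in how a sign bit, E exponent bits, and M mantissa bits are laid out. The host-side sketch below decodes them to float, assuming the OCP FP8/MX conventions: exponent bias 2^(E-1)-1, E4M3 reserves only S.1111.111 for NaN (no Inf), E5M2 has IEEE-style Inf/NaN, and the FP6/FP4 formats have neither. MiniFloatFormat and decode_minifloat are hypothetical names, not CUDA APIs.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <limits>

// 1 sign bit, exp_bits exponent bits, man_bits mantissa bits.
// Assumed conventions (OCP FP8 / MX): bias = 2^(exp_bits-1) - 1.
struct MiniFloatFormat {
    int exp_bits;
    int man_bits;
    bool ieee_specials;  // E5M2: all-ones exponent encodes Inf/NaN
    bool e4m3_nan_only;  // E4M3: only S.1111.111 is NaN, no Inf
};

constexpr MiniFloatFormat kE4M3{4, 3, false, true};   // FP8
constexpr MiniFloatFormat kE5M2{5, 2, true, false};   // FP8
constexpr MiniFloatFormat kE3M2{3, 2, false, false};  // FP6, no Inf/NaN
constexpr MiniFloatFormat kE2M3{2, 3, false, false};  // FP6, no Inf/NaN
constexpr MiniFloatFormat kE2M1{2, 1, false, false};  // FP4, no Inf/NaN

// Decodes the low (1 + exp_bits + man_bits) bits of `bits` to float.
float decode_minifloat(std::uint32_t bits, MiniFloatFormat f) {
    const int width = 1 + f.exp_bits + f.man_bits;
    assert(width <= 8);
    const std::uint32_t sign = (bits >> (width - 1)) & 1u;
    const std::uint32_t exp = (bits >> f.man_bits) & ((1u << f.exp_bits) - 1u);
    const std::uint32_t man = bits & ((1u << f.man_bits) - 1u);
    const std::uint32_t exp_all_ones = (1u << f.exp_bits) - 1u;
    const std::uint32_t man_all_ones = (1u << f.man_bits) - 1u;
    const int bias = (1 << (f.exp_bits - 1)) - 1;

    if (f.ieee_specials && exp == exp_all_ones) {
        if (man != 0) return std::numeric_limits<float>::quiet_NaN();
        const float inf = std::numeric_limits<float>::infinity();
        return sign ? -inf : inf;
    }
    if (f.e4m3_nan_only && exp == exp_all_ones && man == man_all_ones)
        return std::numeric_limits<float>::quiet_NaN();

    const float frac = static_cast<float>(man) / static_cast<float>(1u << f.man_bits);
    const float mag = (exp == 0)
        ? std::ldexp(frac, 1 - bias)                               // subnormal
        : std::ldexp(1.0f + frac, static_cast<int>(exp) - bias);   // normal
    return sign ? -mag : mag;
}
```

Under these assumptions the largest finite values come out as 448 (E4M3), 57344 (E5M2), 28 (E3M2), 7.5 (E2M3), and 6 (E2M1).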

First, compilation succeeds for mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32, which is also supported on pre-Blackwell generations (Ada and Hopper). I have not run it yet, but at least the compiler reports no error.

Now I want to write similar code for FP6 and FP4. According to the PTX ISA documentation (section 9.7.14.5.14), the instruction format is:

mma.sync.aligned.m16n8k4.row.col.f32.tf32.tf32.f32        d, a, b, c;
mma.sync.aligned.m16n8k8.row.col.f32.atype.btype.f32      d, a, b, c;
mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32       d, a, b, c;
mma.sync.aligned.shape.row.col.dtype.f8type.f8type.ctype  d, a, b, c;
mma.sync.aligned.m16n8k32.row.col.kind.dtype.f8f6f4type.f8f6f4type.ctype d, a, b, c;

.atype      = {.bf16, .tf32};
.btype      = {.bf16, .tf32};
.f8type     = {.e4m3, .e5m2};
.f8f6f4type = {.e4m3, .e5m2, .e3m2, .e2m3, .e2m1};
.ctype      = {.f16, .f32};
.dtype      = {.f16, .f32};
.shape      = {.m16n8k16, .m16n8k32};
.kind       = {.kind::f8f6f4};
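Reading the grammar above, the FP6/FP4 variants differ from the FP8 one only in their type qualifiers, which appear in the order .kind, .dtype, .atype, .btype, .ctype. The host-side helper below (mma_f8f6f4_mnemonic, a hypothetical name) assembles the m16n8k32 mnemonic and validates it against the quoted type sets. It only builds the string; it says nothing about which compilation targets accept it.

```cpp
#include <cassert>
#include <set>
#include <stdexcept>
#include <string>

// Builds "mma.sync.aligned.m16n8k32.row.col.kind::f8f6f4.<dtype>.<atype>.<btype>.<ctype>"
// after checking each qualifier against the sets quoted from the PTX ISA grammar.
std::string mma_f8f6f4_mnemonic(const std::string& dtype, const std::string& atype,
                                const std::string& btype, const std::string& ctype) {
    static const std::set<std::string> f8f6f4 = {"e4m3", "e5m2", "e3m2", "e2m3", "e2m1"};
    static const std::set<std::string> accum = {"f16", "f32"};  // .dtype / .ctype
    if (!f8f6f4.count(atype) || !f8f6f4.count(btype) ||
        !accum.count(dtype) || !accum.count(ctype)) {
        throw std::invalid_argument("type not in the documented .f8f6f4type/.dtype/.ctype sets");
    }
    return "mma.sync.aligned.m16n8k32.row.col.kind::f8f6f4." + dtype + "." + atype + "." +
           btype + "." + ctype;
}
```

For example, mma_f8f6f4_mnemonic("f32", "e3m2", "e2m1", "f32") returns "mma.sync.aligned.m16n8k32.row.col.kind::f8f6f4.f32.e3m2.e2m1.f32", which could replace the e4m3 string in an asm volatile template, assuming mixed A/B types are accepted as the grammar suggests.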

Here is my kernel code (it is launched with <<<1, 32>>>, a single warp, so no warp index is needed):

#include <cuda_fp8.h>  // __nv_fp8_e4m3

__global__ void wmma_gemm_kernel_m16n8k32_8bit_fp32(const __nv_fp8_e4m3 *A, const __nv_fp8_e4m3 *B, float *C) {
    // Manual fragment declaration
    unsigned int a_regs[4]; // 16x FP8 values for A (4 per register, 16 elements/thread)
    unsigned int b_regs[2]; // 8x FP8 values for B (4 per register, 8 elements/thread)
    float        c_regs[4]; // 4x float accumulator (C)
    float        d_regs[4]; // 4x float result (D)

    const unsigned int lane_id = threadIdx.x & 31;  // Warp lane index
    // Fragment layout for mma.m16n8k32 (PTX ISA, "Matrix Fragments for mma.m16n8k32")
    const unsigned int group = lane_id >> 2;  // groupID: 0-7
    const unsigned int tig   = lane_id & 3;   // threadID_in_group: 0-3

    // View the FP8 inputs as 32-bit words (4 elements per word)
    const unsigned int* pA = reinterpret_cast<const unsigned int*>(A);
    const unsigned int* pB = reinterpret_cast<const unsigned int*>(B);
    float* pC = C;

    // Matrix A (16x32 FP8, row-major): 32 bytes = 8 words per row
    a_regs[0] = pA[group * 8 + tig];            // row group,   cols 4*tig .. 4*tig+3
    a_regs[1] = pA[(group + 8) * 8 + tig];      // row group+8, cols 4*tig .. 4*tig+3
    a_regs[2] = pA[group * 8 + 4 + tig];        // row group,   cols 16+4*tig .. 16+4*tig+3
    a_regs[3] = pA[(group + 8) * 8 + 4 + tig];  // row group+8, cols 16+4*tig .. 16+4*tig+3

    // Matrix B (32x8 FP8, col-major): 32 bytes = 8 words per column
    b_regs[0] = pB[group * 8 + tig];            // col group, rows 4*tig .. 4*tig+3
    b_regs[1] = pB[group * 8 + 4 + tig];        // col group, rows 16+4*tig .. 16+4*tig+3

    // Matrix C (16x8 float, row-major): each lane holds 2 adjacent floats in 2 rows
    const unsigned int c_lo = group * 8 + tig * 2;        // row group,   cols 2*tig, 2*tig+1
    const unsigned int c_hi = (group + 8) * 8 + tig * 2;  // row group+8, cols 2*tig, 2*tig+1
    c_regs[0] = pC[c_lo];
    c_regs[1] = pC[c_lo + 1];
    c_regs[2] = pC[c_hi];
    c_regs[3] = pC[c_hi + 1];

    // Execute PTX MMA instruction
    asm volatile(
        "mma.sync.aligned.m16n8k32.row.col.kind::f8f6f4.f32.e4m3.e4m3.f32 "
        "{%0, %1, %2, %3}, "    // D matrix (output)
        "{%4, %5, %6, %7}, "    // A matrix (input)
        "{%8, %9}, "             // B matrix (input)
        "{%10, %11, %12, %13};" // C matrix (input)
        : "=f"(d_regs[0]), "=f"(d_regs[1]), "=f"(d_regs[2]), "=f"(d_regs[3])
        : "r"(a_regs[0]), "r"(a_regs[1]), "r"(a_regs[2]), "r"(a_regs[3]),
          "r"(b_regs[0]), "r"(b_regs[1]),
          "f"(c_regs[0]), "f"(c_regs[1]), "f"(c_regs[2]), "f"(c_regs[3])
    );

    // Store results using the same layout as C
    pC[c_lo]     = d_regs[0];
    pC[c_lo + 1] = d_regs[1];
    pC[c_hi]     = d_regs[2];
    pC[c_hi + 1] = d_regs[3];
}
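The fragment layout the kernel has to follow can be checked on the host. This sketch models the per-lane mapping as I read the PTX ISA section "Matrix Fragments for mma.m16n8k32" for 8-bit A/B elements and f32 C/D (groupID = lane >> 2, threadID_in_group = lane % 4), and verifies that the 32 lanes together cover every matrix element exactly once. The function names are hypothetical, and I am assuming, without having verified it, that .kind::f8f6f4 with .e3m2, .e2m3, or .e2m1 keeps the same one-element-per-byte register layout.

```cpp
#include <array>
#include <cassert>

struct RowCol { int row; int col; };

// A (16x32): 4 .b32 registers x 4 elements = 16 elements per lane, i = 0..15.
RowCol a_frag(int lane, int i) {
    const int group = lane >> 2, tig = lane % 4;
    const int row = (i < 4 || (i >= 8 && i < 12)) ? group : group + 8;
    const int col = tig * 4 + (i & 3) + (i >= 8 ? 16 : 0);
    return {row, col};
}

// B (32x8, K x N): 2 .b32 registers x 4 elements = 8 elements per lane, i = 0..7.
RowCol b_frag(int lane, int i) {
    const int group = lane >> 2, tig = lane % 4;
    return {tig * 4 + (i & 3) + (i >= 4 ? 16 : 0), group};
}

// C/D (16x8, f32): 4 registers per lane, i = 0..3.
RowCol c_frag(int lane, int i) {
    const int group = lane >> 2, tig = lane % 4;
    return {i < 2 ? group : group + 8, tig * 2 + (i & 1)};
}

// True if the 32 lanes x per_lane elements hit every cell of an R x C matrix exactly once.
template <int R, int C, typename F>
bool covers_exactly_once(F frag, int per_lane) {
    std::array<int, R * C> hits{};
    for (int lane = 0; lane < 32; ++lane) {
        for (int i = 0; i < per_lane; ++i) {
            const RowCol rc = frag(lane, i);
            if (rc.row < 0 || rc.row >= R || rc.col < 0 || rc.col >= C) return false;
            ++hits[rc.row * C + rc.col];
        }
    }
    for (int h : hits)
        if (h != 1) return false;
    return true;
}
```

For example, a_frag(5, 9) gives row 1, column 21: lane 5 is in group 1 with threadID_in_group 1, and element 9 sits in the third A register, which covers columns 16 to 31.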

According to the same PTX ISA documentation, FP8, FP6, and FP4 multiplicands with FP32 accumulation are all allowed.
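To check the Tensor Core output once it compiles, a CPU reference with the same shapes helps. This sketch computes D = A * B + C for A (16x32, row-major), B (32x8, column-major), and C/D (16x8, row-major), taking already-decoded float inputs and accumulating in float. The function name is hypothetical; the hardware's internal accumulation order and precision may differ, so compare against it with a tolerance rather than exact equality.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// CPU reference for one m16n8k32 tile: D = A * B + C.
// A: 16x32 row-major, B: 32x8 column-major, C and D: 16x8 row-major.
// Inputs are already-decoded floats; every product is accumulated in float.
std::vector<float> reference_mma_m16n8k32(const std::vector<float>& A,
                                          const std::vector<float>& B,
                                          const std::vector<float>& C) {
    constexpr std::size_t M = 16, N = 8, K = 32;
    assert(A.size() == M * K && B.size() == K * N && C.size() == M * N);
    std::vector<float> D(M * N);
    for (std::size_t m = 0; m < M; ++m) {
        for (std::size_t n = 0; n < N; ++n) {
            float acc = C[m * N + n];
            for (std::size_t k = 0; k < K; ++k)
                acc += A[m * K + k] * B[n * K + k];  // column n of B is contiguous
            D[m * N + n] = acc;
        }
    }
    return D;
}
```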

However, compiling with nvcc cuda_test.cu -o cuda_test -arch=sm_121a fails with this error:

ptxas /tmp/tmpxft_00004673_00000000-7_cuda_test.compute_121.ptx, line 53; error   : Feature '.kind::f8f6f4' not supported on .target 'sm_121'

Using the same asm volatile approach for FP6 and FP4 gives:

ptxas /tmp/tmpxft_0000426d_00000000-7_cuda_test.compute_121.ptx, line 59; error   : Instruction 'mma with with FP6/FP4 floating point type' not supported on .target 'sm_121'
ptxas /tmp/tmpxft_0000426d_00000000-7_cuda_test.compute_121.ptx, line 59; error   : Feature '.kind::f8f6f4' not supported on .target 'sm_121'
ptxas /tmp/tmpxft_0000426d_00000000-7_cuda_test.compute_121.ptx, line 62; error   : Instruction 'mma with with FP6/FP4 floating point type' not supported on .target 'sm_121'
ptxas /tmp/tmpxft_0000426d_00000000-7_cuda_test.compute_121.ptx, line 62; error   : Feature '.kind::f8f6f4' not supported on .target 'sm_121'
ptxas /tmp/tmpxft_0000426d_00000000-7_cuda_test.compute_121.ptx, line 65; error   : Instruction 'mma with with FP6/FP4 floating point type' not supported on .target 'sm_121'
ptxas /tmp/tmpxft_0000426d_00000000-7_cuda_test.compute_121.ptx, line 65; error   : Feature '.kind::f8f6f4' not supported on .target 'sm_121'

As described earlier, compiling the FP8 instruction without .kind::f8f6f4 succeeds, since PTX does not require the qualifier for FP8. According to the documentation, specifying it should also be allowed. More importantly, FP6 and FP4 require .kind::f8f6f4, so this is a blocking problem.


Solution

  • The problem was the compilation target, not the PTX code.

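    When specifying the compilation target, one can give both a virtual target and a real target: the virtual target (compute_XX) determines the generated PTX, and the real target (sm_XX) determines the SASS. Judging by the compute_121.ptx file name in the error messages, when I only specified -arch, the PTX was generated for the generic compute_121 target rather than the architecture-specific compute_121a, and ptxas does not support .kind::f8f6f4 on the generic sm_121 target. This was the failing command: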

    nvcc cuda_test.cu -o cuda_test -arch=sm_121a
    

    After specifying both targets explicitly, the compilation succeeds. The command is:

    nvcc cuda_test.cu -o cuda_test --generate-code arch=compute_121a,code=sm_121a
    

    Here arch and code are sub-options of --generate-code (short form -gencode): compute_XX names the virtual target and sm_XX the real target.