
The meaning of curly braces around a register in PTX assembly loads/stores


Below is some apparently legitimate PTX assembly produced by the Triton compiler. I'm puzzled by the { %r1 } and { %r2 } operands in the load and store instructions. Judging by the PTX ISA documentation, they look like initializer lists, but that does not make sense. It's not only that the initializer spec never mentions registers, and not only that an initializer is useless in a load/store (there is nothing to initialize). What confuses me most is that wrapping the operand in {} seems to change its meaning from a scalar to a pointer immediate.

Perhaps a bored dev just wanted to make everyone's assembly experience more confusing. Does anyone have a better explanation?

.version 7.5
.target sm_35
.address_size 64

        // .globl       E__01

.visible .entry E__01(
        .param .u64 E__01_param_0,
        .param .u64 E__01_param_1
)
.maxntid 128, 1, 1
{
        .reg .pred      %p<3>;
        .reg .b32       %r<4>;
        .reg .b64       %rd<3>;
        .loc    1 6 0
$L__func_begin0:
        .loc    1 6 0

        ld.param.u64    %rd2, [E__01_param_0];
        ld.param.u64    %rd1, [E__01_param_1];
        mov.pred        %p1, -1;
$L__tmp0:
        .loc    1 7 19
        mov.u32 %r1, 0x0;
        @%p1 ld.global.b32 { %r1 }, [ %rd1 + 0 ];
        .loc    1 8 18
        shl.b32         %r2, %r1, 1;
        .loc    1 9 22
        mov.u32         %r3, %tid.x;
        setp.eq.s32     %p2, %r3, 0;
        @%p2 st.global.b32 [ %rd2 + 0 ], { %r2 };
        .loc    1 9 2
        ret;
$L__tmp1:
$L__func_end0:

}
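To make the data flow easier to follow, here is a rough CUDA equivalent of E__01, reconstructed from the PTX above. It is a hypothetical sketch, not the Triton source that produced the listing, and the names e01_equivalent and run_e01 are made up. Parameter order follows the PTX: param_0 is the store target and param_1 is the load source. Each comment names the PTX instruction the line roughly corresponds to; unsigned is used because .b32 is untyped and it keeps the shift well defined.

```cuda
#include <cuda_runtime.h>

// Hypothetical CUDA equivalent of E__01 (not the original Triton source).
__global__ void e01_equivalent(unsigned* out, const unsigned* in) {
    unsigned x = *in;             // @%p1 ld.global.b32 { %r1 }, [ %rd1 + 0 ];  (%p1 is always true)
    unsigned y = x << 1u;         // shl.b32 %r2, %r1, 1;
    if (threadIdx.x == 0) {       // mov.u32 %r3, %tid.x;  setp.eq.s32 %p2, %r3, 0;
        *out = y;                 // @%p2 st.global.b32 [ %rd2 + 0 ], { %r2 };
    }
}

// Hypothetical host helper: launches one block of 128 threads (the .maxntid
// limit) and copies the stored value back. Returns false if any CUDA call failed.
bool run_e01(unsigned value, unsigned* result) {
    unsigned *d_in = nullptr, *d_out = nullptr;
    cudaMalloc(&d_in, sizeof(unsigned));
    cudaMalloc(&d_out, sizeof(unsigned));
    cudaMemcpy(d_in, &value, sizeof(unsigned), cudaMemcpyHostToDevice);
    e01_equivalent<<<1, 128>>>(d_out, d_in);
    // cudaMemcpy on the default stream waits for the kernel to finish.
    cudaMemcpy(result, d_out, sizeof(unsigned), cudaMemcpyDeviceToHost);
    cudaFree(d_in);
    cudaFree(d_out);
    // cudaGetLastError reports the most recent error from any runtime call above.
    return cudaGetLastError() == cudaSuccess;
}
```

With an input of 21, thread 0 should store 42, since shifting left by one doubles the value.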

Solution

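  • This syntax appears to be a valid generalization of the vector forms of ld and st, where a brace-enclosed list of registers is loaded from memory and unpacked, or packed and stored. Here it simply treats a single register as a vector of size 1: { %r1 } is the destination of an ordinary scalar load, and the address is still the [ ... ] operand, so nothing turns into a pointer immediate.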

    The same syntax with a genuine multi-element vector shows up in an example like this:

    __global__ void foo(float2* arr, int n) {
        int tid = blockDim.x * blockIdx.x + threadIdx.x;
        if (tid < n) {
            float2 x = arr[tid];
            arr[tid] = make_float2(x.x * x.y, x.x + x.y);
        }
    }
    

    The body of that if compiles to:

            ld.global.v2.f32        {%f1, %f2}, [%rd4];
            mul.f32         %f3, %f1, %f2;
            add.f32         %f4, %f1, %f2;
            st.global.v2.f32        [%rd4], {%f3, %f4};
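A minimal sketch to check the "vector of size 1" reading, assuming a CUDA toolchain and a GPU are available. The hypothetical kernel brace_forms uses inline PTX to load the same 32-bit word with and without braces, stores each result with the one-element brace form, and then repeats the float2 multiply/add with an explicit .v2 list. The names brace_forms and run_brace_demo are illustrative and do not come from Triton or the CUDA toolkit. Like the Triton listing, it passes cudaMalloc pointers straight to ld.global/st.global.

```cuda
#include <cuda_runtime.h>

// Hypothetical kernel: only thread 0 does any work.
__global__ void brace_forms(const int* in, int* out, float2* pair) {
    if (threadIdx.x != 0) return;

    // Same load written two ways: plain scalar operand vs. one-element brace list.
    int plain, braced;
    asm volatile("ld.global.b32 %0, [%1];" : "=r"(plain) : "l"(in));
    asm volatile("ld.global.b32 { %0 }, [%1];" : "=r"(braced) : "l"(in));

    // One-element lists on the store side, like "st.global.b32 [ %rd2 + 0 ], { %r2 };".
    asm volatile("st.global.b32 [%0], { %1 };" :: "l"(out), "r"(plain) : "memory");
    asm volatile("st.global.b32 [%0+4], { %1 };" :: "l"(out), "r"(braced) : "memory");

    // Two-element list with .v2: the braces group the registers of one vector access.
    float a, b;
    asm volatile("ld.global.v2.f32 { %0, %1 }, [%2];" : "=f"(a), "=f"(b) : "l"(pair));
    asm volatile("st.global.v2.f32 [%0], { %1, %2 };"
                 :: "l"(pair), "f"(a * b), "f"(a + b) : "memory");
}

// Hypothetical host helper: runs brace_forms once and copies results back.
// Returns false if any CUDA runtime call (or the kernel launch) failed.
bool run_brace_demo(int value, float2 pair_in, int out[2], float2* pair_out) {
    int *d_in = nullptr, *d_out = nullptr;
    float2* d_pair = nullptr;
    cudaMalloc(&d_in, sizeof(int));
    cudaMalloc(&d_out, 2 * sizeof(int));
    cudaMalloc(&d_pair, sizeof(float2));
    cudaMemcpy(d_in, &value, sizeof(int), cudaMemcpyHostToDevice);
    cudaMemcpy(d_pair, &pair_in, sizeof(float2), cudaMemcpyHostToDevice);
    brace_forms<<<1, 32>>>(d_in, d_out, d_pair);
    // cudaMemcpy on the default stream waits for the kernel to finish.
    cudaMemcpy(out, d_out, 2 * sizeof(int), cudaMemcpyDeviceToHost);
    cudaMemcpy(pair_out, d_pair, sizeof(float2), cudaMemcpyDeviceToHost);
    cudaFree(d_in);
    cudaFree(d_out);
    cudaFree(d_pair);
    return cudaGetLastError() == cudaSuccess;
}
```

If { %r } really is just a one-register vector, both entries of out should hold the input value, and the pair (3, 4) should come back as (12, 7).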