
Metal - free function to get `[[thread_position_in_grid]]` values


In CUDA the (1D) global thread index is obtained with `blockDim.x*blockIdx.x+threadIdx.x`; in OpenCL it is obtained with `get_global_id(0)`. Both compile to calls to intrinsic ("magic") functions in LLVM: `llvm.nvvm.read.ptx.sreg.tid.x` and friends for CUDA, and `_Z13get_global_idj` for OpenCL.

Metal instead uses a magic attribute: `kernel void foo( uint arg [[thread_position_in_grid]] )`. Through compiler magic (attributes and metadata), this tells the Metal runtime that `arg` should receive the thread's position in the grid.

I'm trying to extend my compiler from OpenCL and CUDA to Metal, so I want Metal to behave like OpenCL and CUDA, and I don't want to change the function signatures of my existing kernels.

My question: is it possible to write a free function (in LLVM IR) that returns the thread position, and likewise the values associated with the other magic attributes?


Solution

  • What you are looking for is called Program Scope Global Built-ins and Bindings. You can read about them in section 5.9 of the Metal Shading Language Specification, which contains exactly the example you want:

    uint2 gid [[thread_position_in_grid]];
    float4 get_color(texture2d<float> texInput, sampler s) {
        return texInput.sample(s, float2(gid));
    }
    [[kernel]] void my_kernel(texture2d<float> texInput, sampler s, ...) {
       auto color = get_color(texInput, s); 
       ...
    }
    

    Basically, declare the variable at program (global) scope and attach the attribute to it. Any function, not just the kernel, can then read it.
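For the 1D case in the question, a minimal sketch along the same lines, assuming a Metal version that supports program scope global built-ins (check the spec section cited above for the exact version requirement). The helper `get_global_id_x()` is a hypothetical name that plays the role of OpenCL's `get_global_id(0)`, so the kernel signature no longer needs an attributed parameter. The kernel `scale` and its buffer layout are illustrative only.

```metal
#include <metal_stdlib>
using namespace metal;

// Program scope global built-in (assumes a Metal version that supports it):
// each thread sees its own position in the grid without needing a
// [[thread_position_in_grid]] kernel parameter.
uint gid [[thread_position_in_grid]];

// Hypothetical free function, analogous to OpenCL's get_global_id(0).
uint get_global_id_x() {
    return gid;
}

// Illustrative kernel: doubles each element of a buffer.
// Assumes the host dispatches exactly one thread per element.
[[kernel]] void scale(device float *data [[buffer(0)]]) {
    uint i = get_global_id_x();
    data[i] *= 2.0f;
}
```

This works at the MSL source level. Whether a front end emitting LLVM IR directly can reproduce the same attributes and metadata is a separate question that the spec example does not address.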