
Performing a reduce operation with Metal


The Metal Shading Language Specification (PDF) includes a piece of sample code for performing a parallelized reduction (specifically, a summation) over an input array:

#error /!\ READER BEWARE - CONTAINS BUGS - READ ANSWER /!\

#include <metal_stdlib>

using namespace metal;


kernel void
reduce(const device int *input [[buffer(0)]],
       device atomic_int *output [[buffer(1)]],
       threadgroup int *ldata [[threadgroup(0)]],
       uint gid [[thread_position_in_grid]],
       uint lid [[thread_position_in_threadgroup]],
       uint lsize [[threads_per_threadgroup]],
       uint simd_size [[threads_per_simdgroup]],
       uint simd_lane_id [[thread_index_in_simdgroup]],
       uint simd_group_id [[simdgroup_index_in_threadgroup]])
{
    // Perform the first level of reduction.
    // Read from device memory, write to threadgroup memory.
    int val = input[gid] + input[gid + lsize];  // BUG 1
    for (uint s=lsize/simd_size; s>simd_size; s/=simd_size)  // BUG 2
    {
        // Perform per-SIMD partial reduction.
        for (uint offset=simd_size/2; offset>0; offset/=2)
            val += simd_shuffle_down(val, offset);

        // Write per-SIMD partial reduction value to threadgroup memory.
        if (simd_lane_id == 0)
            ldata[simd_group_id] = val;

        // Wait for all partial reductions to complete.
        threadgroup_barrier(mem_flags::mem_threadgroup);

        val = (lid < s) ? ldata[lid] : 0;
    }

    // Perform final per-SIMD partial reduction to calculate
    // the threadgroup partial reduction result.
    for (uint offset=simd_size/2; offset>0; offset/=2)
        val += simd_shuffle_down(val, offset);

    // Atomically update the reduction result.
    if (lid == 0)
        atomic_fetch_add_explicit(output, val, memory_order_relaxed);
}

Unfortunately, the kernel seems to produce invalid results, and I have a hard time understanding how it is supposed to work. The comments are not very illuminating.

What is causing it to produce invalid results?

How can this be adapted to other operations aside from summation?

What considerations are there for selecting the grid or threadgroup size?


Solution

  • Recall that in Metal, a compute pass runs the kernel function on many threads. These threads are organized into threadgroups, which in turn are subdivided into SIMD groups.

    To maximize throughput, we reduce at all three levels, from smallest to largest:

    1. Within a SIMD group
    2. Within a threadgroup, across SIMD groups
    3. Across threadgroups

    Within a SIMD group

    This is the primary building block of our kernel. I'll borrow the illustration from the NVIDIA Technical Blog, showing a SIMD group size of 8:

    [Figure: partial sum using log(n) SIMD shuffle-down operations]

    The "shuffle down" function, simd_shuffle_down() in Metal, lets a thread read the value of a variable held by another thread in the same SIMD group: simd_shuffle_down(val, offset) returns val from the lane whose index is simd_lane_id + offset. By adding this to the current thread's own partial sum, we halve the number of summands at each step, so it takes log2(simd_size) steps to sum across the SIMD group (5 steps for a SIMD group of 32).

    In Apple's example shader, that code is here:

    for (uint offset=simd_size/2; offset>0; offset/=2)
        val += simd_shuffle_down(val, offset);
    

    The final result ends up in the first lane (simd_lane_id == 0); the other lanes hold only partial sums. Therefore, any code that needs to read the SIMD group's sum from val must be conditioned on if (simd_lane_id == 0).
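    To make the lane bookkeeping concrete, here is a small CPU model of that loop in C++ (not Metal). It assumes a power-of-two SIMD width and follows our reading of the Metal specification, in which simd_shuffle_down() leaves the upper offset lanes with their own value instead of wrapping around. The names shuffle_down and simd_reduce_sum are hypothetical and exist only for this sketch.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// CPU model of Metal's simd_shuffle_down(): lane i receives the value held by
// lane i + delta. Lanes whose source would fall outside the SIMD group get
// their own value back, since the shuffle does not wrap around.
std::vector<int> shuffle_down(const std::vector<int>& lanes, std::size_t delta) {
    std::vector<int> out(lanes.size());
    for (std::size_t i = 0; i < lanes.size(); ++i)
        out[i] = (i + delta < lanes.size()) ? lanes[i + delta] : lanes[i];
    return out;
}

// Models every lane executing, in lockstep,
//   for (uint offset=simd_size/2; offset>0; offset/=2)
//       val += simd_shuffle_down(val, offset);
// and returns each lane's final val. Only lane 0 ends up with the full sum.
std::vector<int> simd_reduce_sum(std::vector<int> val) {
    const std::size_t simd_size = val.size();
    assert(simd_size > 0 && (simd_size & (simd_size - 1)) == 0);  // power of two
    for (std::size_t offset = simd_size / 2; offset > 0; offset /= 2) {
        const std::vector<int> shifted = shuffle_down(val, offset);
        for (std::size_t i = 0; i < simd_size; ++i)
            val[i] += shifted[i];
    }
    return val;
}
```

    With eight lanes holding 1 through 8, lane 0 ends with 36, while the other lanes hold partial sums that are never used.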

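    Of course, each thread's val must first be initialized with its own element of the input array. Uh oh, here is APPLE'S BUG #1. The sample also adds input[gid + lsize], which is the element that thread gid + lsize reads as its own input[gid]. Every element beyond the first threadgroup is therefore counted twice, and if the grid covers the whole array, the last threadgroup reads past the end of the buffer: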

    int val = input[gid] + input[gid + lsize];  // no!
    int val = input[gid];  // yes
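    A quick CPU tally in C++ shows what the extra read does to the total, assuming one thread per input element. total_loaded is a hypothetical helper, and reads past the end are modeled as 0 even though on the GPU they are undefined.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Total of every thread's starting `val` over a grid with one thread per
// input element, for the sample's load (buggy) and the corrected load.
// Reads past the end are counted as 0 here; on the GPU they are
// out-of-bounds accesses with undefined results.
long long total_loaded(const std::vector<int>& input, std::size_t lsize, bool buggy) {
    assert(lsize > 0);
    auto at = [&input](std::size_t i) { return i < input.size() ? input[i] : 0; };
    long long total = 0;
    for (std::size_t gid = 0; gid < input.size(); ++gid)
        total += buggy ? at(gid) + at(gid + lsize)  // input[gid] + input[gid + lsize]
                       : at(gid);                   // input[gid]
    return total;
}
```

    For eight ones and lsize = 4, the corrected load totals 8, while the sample's load totals 12 because elements 4 through 7 are each read twice.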
    

    Within a threadgroup, across SIMD groups

    A threadgroup is composed of multiple SIMD groups, so now that we have found the partial sum of each SIMD group, we must combine them.

    First we need a way for the threads to communicate their SIMD groups' partial sums. In the kernel's parameter list, we declare a buffer shared by all threads of the threadgroup, threadgroup int *ldata [[threadgroup(0)]]. When we set up our compute pass, we allocate this storage with MTLComputeCommandEncoder.setThreadgroupMemoryLength(_:index:), giving it room for one int per SIMD group.
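    As a sketch of the host-side arithmetic, the C++ helper below (the name ldata_bytes is hypothetical) computes a length with one 32-bit int per SIMD group. It rounds up to a multiple of 16 bytes, which is our understanding of the requirement Apple documents for setThreadgroupMemoryLength(_:index:); check the current documentation before relying on it.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical host-side helper: bytes to pass to
// setThreadgroupMemoryLength(_:index:) for `ldata`, which needs one 32-bit
// int per SIMD group, rounded up to a multiple of 16 bytes.
std::size_t ldata_bytes(std::size_t threads_per_threadgroup, std::size_t simd_size) {
    assert(simd_size > 0);
    const std::size_t simd_groups =
        (threads_per_threadgroup + simd_size - 1) / simd_size;  // round up
    const std::size_t raw = simd_groups * sizeof(std::int32_t);
    return (raw + 15) / 16 * 16;
}
```

    For 1024 threads in SIMD groups of 32, that is 32 ints, or 128 bytes; a 64-thread threadgroup needs only 8 bytes, which rounds up to 16.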

    After performing the SIMD group reduction, lane 0 of each SIMD group stores its result to threadgroup memory. A threadgroup_barrier follows, so that every thread waits until all SIMD groups' partial sums have been stored.

    // Write per-SIMD partial reduction value to threadgroup memory.
    if (simd_lane_id == 0)
        ldata[simd_group_id] = val;
    
    // Wait for all partial reductions to complete.
    threadgroup_barrier(mem_flags::mem_threadgroup);
    

    At this point, the buffer ldata contains all of the SIMD groups' partial sums, and now we need to sum those. We already have a parallel sum reduction for within a SIMD group, so why not reuse it? If we copy the data back out of ldata, we can just repeat the earlier step:

    [Figure: gathering SIMD group partial sums via threadgroup storage]

    The kernel function does so with this conditional copy. It uses the within-threadgroup index lid so that the first thread reads the partial sum of the first SIMD group, the second thread reads the partial sum of the second SIMD group, and so on.

    val = (lid < s) ? ldata[lid] : 0;
    

    A tricky piece is the conditional lid < s. Note that s is initialized to lsize/simd_size, i.e., the number of SIMD groups in the threadgroup. The condition keeps threads from reading indices of ldata that were not populated and gives them 0 instead, which leaves the sum unchanged.

    With the SIMD groups loaded with new summands, we allow the loop to jump back to the top to repeat the process:

    for (uint s=lsize/simd_size; s>simd_size; s/=simd_size)  // BUG 2
    {
        // compute partial sum within SIMD group
        // store partial sums in threadgroup storage ldata
        // copy ldata storage back to SIMD group
    }
    

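    Each pass reduces the number of summands by a factor of simd_size, and the idea is to continue until only a single SIMD group's worth of values remains. However, this brings us to APPLE'S BUG #2. The variable s is the number of SIMD-group partial sums still to be combined, so the loop should keep going while s > 1! With the sample's condition s > simd_size, the loop stops while as many as simd_size partial sums are still uncombined. With typical sizes (simd_size = 32 and at most 1024 threads per threadgroup), s starts at 32 or less, so the loop body never runs at all, and the final step only sums the first SIMD group of each threadgroup.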

    for (uint s=lsize/simd_size; s>simd_size; s/=simd_size)  // no!
    for (uint s=lsize/simd_size; s>1; s/=simd_size)  // yes
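    The C++ model below runs one threadgroup on the CPU, executing all lanes of each step together, so the effect of the two loop conditions can be compared directly. It assumes a power-of-two SIMD width and a threadgroup size that is a multiple of it. simd_reduce_all and threadgroup_sum are hypothetical names, and because the model is sequential it says nothing about barriers or other GPU concurrency issues.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Shuffle-down sum inside every SIMD group of `val` (one entry per thread,
// indexed by lid). A lane whose source lane lies outside its SIMD group adds
// its own value, modeling simd_shuffle_down() not wrapping around; only
// lane 0 of each SIMD group ends up with that group's full sum.
void simd_reduce_all(std::vector<int>& val, std::size_t simd_size) {
    for (std::size_t offset = simd_size / 2; offset > 0; offset /= 2) {
        const std::vector<int> prev = val;  // all lanes shuffle at once
        for (std::size_t lid = 0; lid < val.size(); ++lid) {
            const std::size_t lane = lid % simd_size;
            val[lid] += (lane + offset < simd_size) ? prev[lid + offset] : prev[lid];
        }
    }
}

// Models one threadgroup with the corrected loop condition (fixed == true,
// s > 1) or the sample's (fixed == false, s > simd_size), and returns the
// value held by thread lid == 0 when it reaches the atomic add.
int threadgroup_sum(std::vector<int> val, std::size_t simd_size, bool fixed) {
    assert(simd_size > 0 && (simd_size & (simd_size - 1)) == 0);
    assert(val.size() % simd_size == 0);
    const std::size_t lsize = val.size();
    const std::size_t simd_groups = lsize / simd_size;
    std::vector<int> ldata(simd_groups);  // stands in for threadgroup memory
    for (std::size_t s = simd_groups; fixed ? s > 1 : s > simd_size; s /= simd_size) {
        simd_reduce_all(val, simd_size);
        // if (simd_lane_id == 0) ldata[simd_group_id] = val;  (then the barrier)
        for (std::size_t g = 0; g < simd_groups; ++g)
            ldata[g] = val[g * simd_size];
        // val = (lid < s) ? ldata[lid] : 0;
        for (std::size_t lid = 0; lid < lsize; ++lid)
            val[lid] = (lid < s) ? ldata[lid] : 0;
    }
    simd_reduce_all(val, simd_size);  // final per-SIMD reduction
    return val[0];
}
```

    With 1024 ones and a SIMD width of 32, the corrected condition yields 1024, while the sample's condition never enters the loop and yields 32, the sum of the first SIMD group only. With a SIMD width of 4 and 64 threads holding 0 through 63, the sample's loop runs once and stops early, returning 120 (the sum of the first 16 inputs) instead of 2016.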
    

    When the loop exits, we have one last within-SIMD group summation to perform:

    // Perform final per-SIMD partial reduction to calculate
    // the threadgroup partial reduction result.
    for (uint offset=simd_size/2; offset>0; offset/=2)
        val += simd_shuffle_down(val, offset);
    

    Now the first thread of the first SIMD group, which is also the first thread of the threadgroup (lid == 0), has the threadgroup partial sum in val.

    Across threadgroups

    Lastly, we need to take all of those partial sums known to the first threads of all the threadgroups and combine them.

    Apple's kernel does so with the atomic_fetch_add_explicit() function. This is a read-modify-write operation that reads the content of the output buffer, adds val to it, and writes the sum back. Because it is atomic, the three substeps happen as one indivisible operation, so concurrent updates from different threadgroups cannot interleave and overwrite each other. (Note that output has type atomic_int, and it must be set to 0 before the dispatch, since the kernel only ever adds to it.) Since summation is commutative and associative, the order in which the threadgroups call this function doesn't matter either.

    Again, the lid == 0 condition is because only the first thread of the threadgroup has the threadgroup's partial sum.

    // Atomically update the reduction result.
    if (lid == 0)
        atomic_fetch_add_explicit(output, val, memory_order_relaxed);
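    The same pattern can be tried on the CPU with C++'s std::atomic, where fetch_add with std::memory_order_relaxed plays the role of atomic_fetch_add_explicit(). In this sketch, each std::thread stands in for one threadgroup's lid == 0 thread; combine_partials is a hypothetical name.

```cpp
#include <atomic>
#include <cassert>
#include <thread>
#include <vector>

// CPU analog of the last step: each "threadgroup" (here a std::thread) adds
// its partial sum to a shared atomic counter. fetch_add performs the
// read-modify-write indivisibly, so no update is lost regardless of the
// order in which the threads run.
int combine_partials(const std::vector<int>& partials) {
    std::atomic<int> output{0};  // like the Metal output buffer, starts at 0
    std::vector<std::thread> workers;
    for (int p : partials)
        workers.emplace_back([&output, p] {
            output.fetch_add(p, std::memory_order_relaxed);
        });
    for (std::thread& t : workers)
        t.join();  // the total is read only after every thread has finished
    return output.load();
}
```

    Relaxed ordering is enough here because the total is only read after every thread has been joined, much as the Metal output is only read after the command buffer completes.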
    

    On threadgroup and SIMD group sizes

    Metal does not seem to offer a dedicated query or documentation for the number of threads per SIMD group (related question). On the host side, MTLComputePipelineState.threadExecutionWidth, documented as the number of threads the GPU executes simultaneously, reports that width in practice; on Apple GPUs (as of June 2024) it is 32. Within a shader, you can get it by adding a parameter with the [[threads_per_simdgroup]] attribute, as we do here.

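    The threadgroup size should be a multiple of the SIMD group size: s = lsize/simd_size rounds down, so the partial sum of a trailing, partially filled SIMD group would never be combined. The grid size, in turn, should match the input length, because with the fix each thread reads exactly one element (input[gid]) and there is no bounds check.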
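    A hypothetical C++ sizing helper (plan_dispatch is not a Metal API) that applies both rules might look like this. It assumes the input is padded with zeros up to a whole number of threadgroups, which is harmless for a sum because 0 is its identity; the chosen threadgroup size must also not exceed the pipeline's maxTotalThreadsPerThreadgroup.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical sizing helper for the corrected kernel, where each thread
// reads exactly one element (input[gid]) with no bounds check.
struct DispatchPlan {
    std::size_t threadgroups;  // number of threadgroups in the grid
    std::size_t grid_threads;  // threadgroups * threads_per_threadgroup
    std::size_t padding;       // zero elements to append to the input
};

DispatchPlan plan_dispatch(std::size_t count, std::size_t threads_per_threadgroup,
                           std::size_t simd_size) {
    assert(count > 0 && simd_size > 0 && threads_per_threadgroup > 0);
    // A trailing, partially filled SIMD group would be dropped, because the
    // kernel's s = lsize / simd_size rounds down.
    assert(threads_per_threadgroup % simd_size == 0);
    const std::size_t groups =
        (count + threads_per_threadgroup - 1) / threads_per_threadgroup;  // round up
    const std::size_t grid = groups * threads_per_threadgroup;
    return {groups, grid, grid - count};
}
```

    For 1000 elements and 256-thread threadgroups, this gives 4 threadgroups, 1024 threads, and 24 zeros of padding.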

    Reduction operations other than sum

    It should be possible to use roughly the same algorithm to perform other kinds of parallel reductions.

    The summation operation val += simd_shuffle_down(val, offset) generalizes to val = op(val, simd_shuffle_down(val, offset)).
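    Two conditions are implicit in that generalization: op should be associative and commutative, because the shuffle tree and the atomic step both reorder operands, and the 0 that pads unused lanes in val = (lid < s) ? ldata[lid] : 0 must become op's identity element. The C++ CPU model below (simd_reduce and kMaxIdentity are hypothetical names) illustrates both points with max.

```cpp
#include <cassert>
#include <climits>
#include <cstddef>
#include <vector>

// Identity element for max: the value a generalized kernel would load in
// place of the 0 in  val = (lid < s) ? ldata[lid] : 0;
constexpr int kMaxIdentity = INT_MIN;

// One SIMD group's shuffle-down reduction with an arbitrary operation:
//   val = op(val, simd_shuffle_down(val, offset));
// Out-of-range lanes combine with their own value (no wrap-around).
// Returns lane 0's result, the only one that covers every lane.
template <typename T, typename Op>
T simd_reduce(std::vector<T> val, Op op) {
    const std::size_t n = val.size();
    assert(n > 0 && (n & (n - 1)) == 0);  // power-of-two SIMD width
    for (std::size_t offset = n / 2; offset > 0; offset /= 2) {
        const std::vector<T> prev = val;
        for (std::size_t i = 0; i < n; ++i)
            val[i] = op(prev[i], i + offset < n ? prev[i + offset] : prev[i]);
    }
    return val[0];
}
```

    Padding two unused lanes with INT_MIN keeps the maximum of {-5, -3} at -3, whereas padding them with 0, as the sum kernel does, would report 0.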

    However, atomic_fetch_<op>_explicit() offers only a few options: add, and, max, min, or, sub, and xor. If your intended operation is not one of these, you may need to write the threadgroup partial results to an output buffer and let the CPU do the final reduction step. You would add a parameter for the threadgroup's position within the grid, such as uint tgid [[threadgroup_position_in_grid]], change output to a plain device int * with one element per threadgroup, and use tgid to store the partial result:

    if (lid == 0)
        output[tgid] = val;
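    The CPU side of that approach is a plain fold over the per-threadgroup results once the command buffer has completed. As a sketch in C++, assuming the partials have already been copied into a std::vector (finish_on_cpu is a hypothetical name), a product, which has no atomic_fetch_<op>_explicit() counterpart, could be finished like this:

```cpp
#include <cassert>
#include <cstddef>
#include <numeric>
#include <vector>

// Hypothetical CPU-side final step: fold the per-threadgroup partial results
// (one per output[tgid]) with the same operation the kernel used, starting
// from that operation's identity element.
template <typename T, typename Op>
T finish_on_cpu(const std::vector<T>& partials, T identity, Op op) {
    return std::accumulate(partials.begin(), partials.end(), identity, op);
}
```

    Folding {2, 3, 4} with identity 1 and multiplication gives 24.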