WebGPU prefix-sum: issue with bind-group ping-pong


I'm implementing a simple Hillis and Steele prefix-sum algorithm in WebGPU.

It deliberately doesn't do anything fancy. I'm working with small arrays so that I can dispatch as many work-groups as there are entries (just for simplicity). For each iteration of the prefix-sum algorithm, I'm ping-pong'ing between two bind-groups, so that the output of one iteration ends up being the input to the next.

export async function prefix_sum(input: number[]) {
  const adapter = await navigator.gpu?.requestAdapter();
  const device = await adapter?.requestDevice();
  if (!device) throw new Error(`No support for WebGPU`);

  /********************************************************************
   * Shader
   ********************************************************************/

  const shader = /*wgsl*/ `
        @group(0) @binding(0) var<storage, read> input: array<f32>;
        @group(0) @binding(1) var<storage, read_write> output: array<f32>;
        @group(0) @binding(2) var<uniform> iteration: u32;

        @compute @workgroup_size(1) fn main(
            @builtin(workgroup_id) workgroup_id : vec3<u32>,
            @builtin(global_invocation_id) global_invocation_id: vec3<u32>,
            @builtin(local_invocation_id) local_invocation_id: vec3<u32>
        ) {
            let i = iteration;
            let j = global_invocation_id.x;

            let p = u32(pow(2.0, f32(i)));
            if (j < p) {
                output[j] = input[j];
            } else {
                output[j] = input[j] + input[j - p];
            }
        }
    `;

  /********************************************************************
   * Data
   ********************************************************************/

  const l = input.length;
  const iterations = Math.floor(Math.log2(l));

  const iterationsData = new Uint32Array([iterations]);
  const iterationBuffer = device.createBuffer({
    label: `prefix sum iterations count buffer`,
    size: iterationsData.byteLength,
    usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST,
  });
  device.queue.writeBuffer(iterationBuffer, 0, iterationsData);

  const inputData = new Float32Array(input);
  const inputBuffer = device.createBuffer({
    label: `prefix sum input buffer`,
    size: inputData.byteLength,
    usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST | GPUBufferUsage.COPY_SRC,
  });
  device.queue.writeBuffer(inputBuffer, 0, inputData);

  const outputBuffer = device.createBuffer({
    label: `prefix sum output buffer`,
    size: inputData.byteLength,
    usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST | GPUBufferUsage.COPY_SRC,
  });

  const readBuffer = device.createBuffer({
    label: 'prefix sum read buffer',
    size: inputData.byteLength,
    usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST,
  });

  const module = device.createShaderModule({
    code: shader,
  });

  const pipeline = device.createComputePipeline({
    layout: 'auto',
    compute: {
      module,
    },
  });

  const bindGroup1 = device.createBindGroup({
    label: `prefix sum bind group 1`,
    layout: pipeline.getBindGroupLayout(0),
    entries: [
      { binding: 0, resource: { buffer: inputBuffer } },
      { binding: 1, resource: { buffer: outputBuffer } },
      { binding: 2, resource: { buffer: iterationBuffer } },
    ],
  });

  const bindGroup2 = device.createBindGroup({
    label: `prefix sum bind group 2`,
    layout: pipeline.getBindGroupLayout(0),
    entries: [
      { binding: 0, resource: { buffer: outputBuffer } },
      { binding: 1, resource: { buffer: inputBuffer } },
      { binding: 2, resource: { buffer: iterationBuffer } },
    ],
  });

  /********************************************************************
   * Execute
   ********************************************************************/

  const encoder = device.createCommandEncoder({ label: `prefix sum encoder` });

  for (let iteration = 0; iteration <= iterations; iteration++) {
    device.queue.writeBuffer(iterationBuffer, 0, new Uint32Array([iteration]));
    const pass = encoder.beginComputePass({ label: `prefix sum pass iteration ${iteration}` });
    pass.setPipeline(pipeline);
    if (iteration % 2 === 0) {
      pass.setBindGroup(0, bindGroup1);
    } else {
      pass.setBindGroup(0, bindGroup2);
    }
    pass.dispatchWorkgroups(input.length);
    pass.end();
  }

  /********************************************************************
   * Output
   ********************************************************************/

  const lastOutputBuffer = iterations % 2 === 0 ? inputBuffer : outputBuffer;
  encoder.copyBufferToBuffer(lastOutputBuffer, 0, readBuffer, 0, readBuffer.size);
  const commandBuffer = encoder.finish();
  device.queue.submit([commandBuffer]);

  await readBuffer.mapAsync(GPUMapMode.READ);
  const result = new Float32Array(readBuffer.getMappedRange());
  //   readBuffer.unmap();

  return structuredClone(result);
}
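
For checking the GPU output, here is a minimal CPU sketch of the same ping-pong scheme. It is not part of the WebGPU code above: `prefixSumReference` is a hypothetical helper name, and it assumes an inclusive scan, which is what the shader body computes.

```typescript
// CPU reference for an inclusive Hillis-Steele scan using two ping-ponged arrays.
// `prefixSumReference` is a hypothetical helper, introduced only for comparison.
function prefixSumReference(input: number[]): number[] {
  let src = Float32Array.from(input);
  let dst = new Float32Array(src.length);

  // One pass per power-of-two offset below the array length: 1, 2, 4, ...
  for (let offset = 1; offset < src.length; offset *= 2) {
    for (let j = 0; j < src.length; j++) {
      // Same rule as the shader body: copy the first `offset` entries,
      // otherwise add the element `offset` positions to the left.
      dst[j] = j < offset ? src[j] : src[j] + src[j - offset];
    }
    // Ping-pong: this pass's output becomes the next pass's input.
    [src, dst] = [dst, src];
  }

  return Array.from(src);
}

console.log(prefixSumReference([1, 2, 3, 4, 5, 6, 7, 8]));
// → [1, 3, 6, 10, 15, 21, 28, 36]
```

For eight entries this runs three passes (offsets 1, 2 and 4), so a correct GPU run should return the same eight values.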

I'm observing some weird behaviour, though.

For the input [1, 2, 3, 4, 5, 6, 7, 8], I'm getting:

I tried a few things that I thought might have caused my issue.


Solution

  • I don't know if this is the only issue, but one issue I see is that you need an iterationBuffer per iteration.

    As it is, this code

      for (let iteration = 0; iteration < iterations; iteration++) {
        device.queue.writeBuffer(iterationBuffer, 0, new Uint32Array([iteration]));
        const pass = encoder.beginComputePass({ label: `prefix sum pass iteration ${iteration}` });
        pass.setPipeline(pipeline);
        if (iteration % 2 === 0) {
          pass.setBindGroup(0, bindGroup1);
        } else {
          pass.setBindGroup(0, bindGroup2);
        }
        pass.dispatchWorkgroups(input.length);
        pass.end();
      }
    
      ...
      device.queue.submit(...)
    

    is effectively doing this:

      const iterationBuffer = [];
      for (let iteration = 0; iteration < iterations; iteration++) {
        iterationBuffer[0] = iteration;  // takes effect immediately
        // encode commands (nothing executes yet)
      }
    
      // execute (submit) commands
    

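    At the "execute commands" point above, iterationBuffer contains only the last value you wrote into it. Each device.queue.writeBuffer call is queued as soon as it is made, while the encoded passes do not run until the command buffer is submitted, so every pass reads the final value. If you want a different value for iteration each time you call dispatchWorkgroups in the same submit, you need a different buffer to hold each value.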
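
The ordering problem can be reproduced without a GPU. The sketch below is a toy model, not WebGPU code: every name in it (`UniformSlot`, `runWithSharedBuffer`, `runWithBufferPerPass`) is hypothetical. A plain object stands in for the uniform buffer, and a list of closures stands in for the encoded passes that only run at submit time.

```typescript
// Toy model of the timeline: writes are applied right away,
// encoded passes only run when the command list is "submitted".
// All names here are hypothetical stand-ins, not WebGPU API.
type UniformSlot = { value: number };
type EncodedPass = () => void;

// One shared slot, overwritten before each pass is encoded (the problem case).
function runWithSharedBuffer(passCount: number): number[] {
  const seen: number[] = [];
  const iterationSlot: UniformSlot = { value: 0 };
  const passes: EncodedPass[] = [];

  for (let i = 0; i < passCount; i++) {
    iterationSlot.value = i; // stands in for queue.writeBuffer: applied now
    passes.push(() => seen.push(iterationSlot.value)); // encoded: runs later
  }

  passes.forEach((run) => run()); // stands in for queue.submit
  return seen;
}

// One slot per pass, so each pass keeps its own iteration value (the fix).
function runWithBufferPerPass(passCount: number): number[] {
  const seen: number[] = [];
  const passes: EncodedPass[] = [];

  for (let i = 0; i < passCount; i++) {
    const iterationSlot: UniformSlot = { value: i }; // separate buffer per pass
    passes.push(() => seen.push(iterationSlot.value));
  }

  passes.forEach((run) => run());
  return seen;
}

console.log(runWithSharedBuffer(3)); // → [2, 2, 2]
console.log(runWithBufferPerPass(3)); // → [0, 1, 2]
```

Applied to the code in the question, the second variant would mean creating one small uniform buffer per pass, writing each pass's iteration value into its own buffer, and building one bind group per pass that pairs that buffer with the appropriate input and output storage buffers.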