I'm implementing a simple Hillis and Steele prefix-sum algorithm in WebGPU.
It deliberately doesn't do anything fancy. I'm working with small arrays so that I can dispatch as many workgroups as there are entries (just for simplicity). For each iteration of the prefix-sum algorithm, I'm ping-ponging between two bind groups, so that the output of one iteration ends up being the input to the next.
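For reference, this is the scheme I'm describing, run on the CPU (the helper name prefixSumCPU is just for illustration):

function prefixSumCPU(input: number[]): number[] {
  // Hillis and Steele inclusive scan with two ping-pong'd arrays.
  let src = input.slice();
  let dst = new Array<number>(input.length);
  for (let i = 0; 2 ** i < input.length; i++) {
    const p = 2 ** i;
    for (let j = 0; j < input.length; j++) {
      dst[j] = j < p ? src[j] : src[j] + src[j - p];
    }
    // The output of this iteration becomes the input of the next.
    [src, dst] = [dst, src];
  }
  return src;
}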
export async function prefix_sum(input: number[]) {
  const adapter = await navigator.gpu?.requestAdapter();
  const device = await adapter?.requestDevice();
  if (!device) throw new Error(`No support for WebGPU`);

  /********************************************************************
   * Shader
   ********************************************************************/
  const shader = /*wgsl*/ `
    @group(0) @binding(0) var<storage, read> input: array<f32>;
    @group(0) @binding(1) var<storage, read_write> output: array<f32>;
    @group(0) @binding(2) var<uniform> iteration: u32;

    @compute @workgroup_size(1) fn main(
      @builtin(workgroup_id) workgroup_id : vec3<u32>,
      @builtin(global_invocation_id) global_invocation_id: vec3<u32>,
      @builtin(local_invocation_id) local_invocation_id: vec3<u32>
    ) {
      let i = iteration;
      let j = global_invocation_id.x;
      let p = u32(pow(2.0, f32(i)));
      if (j < p) {
        output[j] = input[j];
      } else {
        output[j] = input[j] + input[j - p];
      }
    }
  `;

  /********************************************************************
   * Data
   ********************************************************************/
  const l = input.length;
  const iterations = Math.floor(Math.log2(l));
  const iterationsData = new Uint32Array([iterations]);
  const iterationBuffer = device.createBuffer({
    label: `prefix sum iterations count buffer`,
    size: iterationsData.byteLength,
    usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST,
  });
  device.queue.writeBuffer(iterationBuffer, 0, iterationsData);

  const inputData = new Float32Array(input);
  const inputBuffer = device.createBuffer({
    label: `prefix sum input buffer`,
    size: inputData.byteLength,
    usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST | GPUBufferUsage.COPY_SRC,
  });
  device.queue.writeBuffer(inputBuffer, 0, inputData);

  const outputBuffer = device.createBuffer({
    label: `prefix sum output buffer`,
    size: inputData.byteLength,
    usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST | GPUBufferUsage.COPY_SRC,
  });

  const readBuffer = device.createBuffer({
    label: 'prefix sum read buffer',
    size: inputData.byteLength,
    usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST,
  });

  const module = device.createShaderModule({
    code: shader,
  });

  const pipeline = device.createComputePipeline({
    layout: 'auto',
    compute: {
      module,
    },
  });

  const bindGroup1 = device.createBindGroup({
    label: `prefix sum bind group 1`,
    layout: pipeline.getBindGroupLayout(0),
    entries: [
      { binding: 0, resource: { buffer: inputBuffer } },
      { binding: 1, resource: { buffer: outputBuffer } },
      { binding: 2, resource: { buffer: iterationBuffer } },
    ],
  });

  const bindGroup2 = device.createBindGroup({
    label: `prefix sum bind group 2`,
    layout: pipeline.getBindGroupLayout(0),
    entries: [
      { binding: 0, resource: { buffer: outputBuffer } },
      { binding: 1, resource: { buffer: inputBuffer } },
      { binding: 2, resource: { buffer: iterationBuffer } },
    ],
  });

  /********************************************************************
   * Execute
   ********************************************************************/
  const encoder = device.createCommandEncoder({ label: `prefix sum encoder` });
  for (let iteration = 0; iteration <= iterations; iteration++) {
    device.queue.writeBuffer(iterationBuffer, 0, new Uint32Array([iteration]));
    const pass = encoder.beginComputePass({ label: `prefix sum pass iteration ${iteration}` });
    pass.setPipeline(pipeline);
    if (iteration % 2 === 0) {
      pass.setBindGroup(0, bindGroup1);
    } else {
      pass.setBindGroup(0, bindGroup2);
    }
    pass.dispatchWorkgroups(input.length);
    pass.end();
  }

  /********************************************************************
   * Output
   ********************************************************************/
  const lastOutputBuffer = iterations % 2 === 0 ? inputBuffer : outputBuffer;
  encoder.copyBufferToBuffer(lastOutputBuffer, 0, readBuffer, 0, readBuffer.size);
  const commandBuffer = encoder.finish();
  device.queue.submit([commandBuffer]);

  await readBuffer.mapAsync(GPUMapMode.READ);
  const result = new Float32Array(readBuffer.getMappedRange());
  // readBuffer.unmap();
  return structuredClone(result);
}
I'm observing some weird behaviour, though. For the input [1, 2, 3, 4, 5, 6, 7, 8], I'm getting:

- [1, 2, 3, 4, 5, 6, 7, 8] as expected,
- [1, 3, 5, 7, 9, 11, 13, 15] as expected,
- [1, 2, 4, 6, 8, 10, 12, 14] as if the 1st iteration (iteration = 1) had been applied with the input-buffer from the 0th iteration as input and the input-buffer to the 1st iteration as output,
- [1, 2, 5, 8, 12, 16, 20, 24] as expected given the false input,
- [1, 2, 3, 4, 7, 10, 13, 16] as if the 2nd iteration (iteration = 2) had been applied with the input-buffer from the 0th iteration as input and the input-buffer to the 2nd iteration as output,
- [1, 2, 3, 4, 8, 12, 16, 20] again as expected given the false input.

I tried a few things that I thought might have caused my issue.
I don't know if this is the only issue, but one issue I see is that you need an iterationBuffer per iteration. As it is, this code
for (let iteration = 0; iteration < iterations; iteration++) {
  device.queue.writeBuffer(iterationBuffer, 0, new Uint32Array([iteration]));
  const pass = encoder.beginComputePass({ label: `prefix sum pass iteration ${iteration}` });
  pass.setPipeline(pipeline);
  if (iteration % 2 === 0) {
    pass.setBindGroup(0, bindGroup1);
  } else {
    pass.setBindGroup(0, bindGroup2);
  }
  pass.dispatchWorkgroups(input.length);
  pass.end();
}
...
device.queue.submit(...)
is effectively doing this:

const iterationBuffer = [];
for (let iteration = 0; iteration < iterations; iteration++) {
  iterationBuffer[0] = iteration;
  // encode commands
}
// execute (submit) commands
At the "execute (submit) commands" point above, iterationBuffer only contains the last value you wrote into it. If you want a different value for iteration each time you call dispatchWorkgroups in the same submit, you need a different buffer to hold each value.
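One way to do that (a sketch against the code above, untested; the array names are mine): create one small uniform buffer and one bind group per iteration up front, write each iteration's value into its own buffer, and then pick the matching bind group inside the encode loop.

const iterationBuffers: GPUBuffer[] = [];
const bindGroups: GPUBindGroup[] = [];
for (let iteration = 0; iteration <= iterations; iteration++) {
  // One uniform buffer per iteration, each holding a single u32.
  const buf = device.createBuffer({
    label: `prefix sum iteration ${iteration} buffer`,
    size: 4,
    usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST,
  });
  device.queue.writeBuffer(buf, 0, new Uint32Array([iteration]));
  iterationBuffers.push(buf);
  // Ping-pong the storage buffers exactly like bindGroup1/bindGroup2 above,
  // but bind this iteration's own uniform buffer.
  bindGroups.push(device.createBindGroup({
    label: `prefix sum bind group iteration ${iteration}`,
    layout: pipeline.getBindGroupLayout(0),
    entries: [
      { binding: 0, resource: { buffer: iteration % 2 === 0 ? inputBuffer : outputBuffer } },
      { binding: 1, resource: { buffer: iteration % 2 === 0 ? outputBuffer : inputBuffer } },
      { binding: 2, resource: { buffer: buf } },
    ],
  }));
}
// ...and in the encode loop, replace the bindGroup1/bindGroup2 choice with:
//   pass.setBindGroup(0, bindGroups[iteration]);

Other approaches work too, for example one larger uniform buffer with one properly aligned slot per iteration bound at different offsets, but separate buffers are the smallest change to the code above.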