GPU-Accelerated Matrix Operations with WebGPU
Matrix multiplication is the most fundamental operation in deep learning — every linear layer, attention head, and convolution boils down to matmul. GPUs excel at this because matrix multiplication is embarrassingly parallel: each output element can be computed independently.
Matrix Multiplication Recap
For matrices A of shape M×K and B of shape K×N, the product C = A × B has shape M×N, where each element is:

C[i][j] = Σ (k = 0 … K−1) A[i][k] · B[k][j]

This requires M·N·K multiply-add operations in total. On a CPU, a naive implementation has poor cache locality. On a GPU, we can compute all output elements in parallel.
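For reference, the formula translates directly into a naive CPU implementation. This is a sketch in TypeScript (`matmulNaive` is a name introduced here, not from the article); it makes the M·N·K cost explicit:

```typescript
// Naive O(M*N*K) matrix multiplication over 2-D arrays.
// Each output element C[i][j] is an independent dot product --
// exactly the independence the GPU exploits.
function matmulNaive(a: number[][], b: number[][]): number[][] {
  const m = a.length;     // rows of A
  const k = a[0].length;  // cols of A == rows of B
  const n = b[0].length;  // cols of B
  const c: number[][] = [];
  for (let i = 0; i < m; i++) {
    c.push(new Array(n).fill(0));
    for (let j = 0; j < n; j++) {
      for (let p = 0; p < k; p++) {
        c[i][j] += a[i][p] * b[p][j];
      }
    }
  }
  return c;
}
```

The three nested loops run M·N·K iterations; the GPU version below keeps only the innermost loop and assigns the outer two to threads.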
Memory Layout
GPUs work with flat buffers rather than nested arrays, so we store matrices in row-major order: all of row 0 first, then row 1, and so on. This means element (i, j) of a matrix lives at buffer offset i * cols + j.
Understanding memory layout is critical for performance — accessing contiguous memory (coalesced access) is dramatically faster than strided access on GPUs.
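The offset formula can be captured in a tiny helper (a sketch; `idx` is a name introduced here):

```typescript
// Row-major offset of element (row, col) in a flat buffer
// belonging to a matrix with `cols` columns.
function idx(row: number, col: number, cols: number): number {
  return row * cols + col;
}

// A 2x3 matrix stored as a flat Float32Array:
// [ a00, a01, a02,
//   a10, a11, a12 ]
const m = new Float32Array([1, 2, 3, 4, 5, 6]);
const a12 = m[idx(1, 2, 3)]; // row 1, col 2 of a 2x3 matrix -> 6
```

The WGSL shader below uses this same expression inline when indexing its storage buffers.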
WGSL Compute Shader
Here’s a basic matrix multiplication compute shader in WGSL (WebGPU Shading Language):
struct Matrix {
    rows: u32,
    cols: u32,
    data: array<f32>,
}

@group(0) @binding(0) var<storage, read> a: Matrix;
@group(0) @binding(1) var<storage, read> b: Matrix;
@group(0) @binding(2) var<storage, read_write> result: Matrix;

@compute @workgroup_size(8, 8)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let row = gid.x;
    let col = gid.y;

    // Guard: the dispatch is rounded up to whole workgroups,
    // so some threads fall outside the output matrix.
    if (row >= a.rows || col >= b.cols) {
        return;
    }

    // Dot product of row `row` of A with column `col` of B.
    var sum: f32 = 0.0;
    for (var k: u32 = 0u; k < a.cols; k = k + 1u) {
        let a_val = a.data[row * a.cols + k];
        let b_val = b.data[k * b.cols + col];
        sum = sum + a_val * b_val;
    }

    result.data[row * b.cols + col] = sum;
}
Each thread computes a single element of the output matrix. The @workgroup_size(8, 8) directive creates 8×8 thread groups — 64 threads that execute together and can share workgroup memory.
WebGPU Pipeline
Setting up the compute pipeline in TypeScript involves:
// 1. Get the GPU device
const adapter = await navigator.gpu.requestAdapter();
const device = await adapter!.requestDevice();

// 2. Create buffers and upload the matrix data
const bufferA = device.createBuffer({
  size: matrixA.byteLength,
  usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST,
});
device.queue.writeBuffer(bufferA, 0, matrixA);
// (bufferB and bufferResult are created the same way; bufferResult
// also needs GPUBufferUsage.COPY_SRC if you read it back.)

// 3. Create compute pipeline
const pipeline = device.createComputePipeline({
  layout: 'auto',
  compute: {
    module: device.createShaderModule({ code: shaderCode }),
    entryPoint: 'main',
  },
});

// 4. Bind the buffers to the shader's @binding slots
const bindGroup = device.createBindGroup({
  layout: pipeline.getBindGroupLayout(0),
  entries: [
    { binding: 0, resource: { buffer: bufferA } },
    { binding: 1, resource: { buffer: bufferB } },
    { binding: 2, resource: { buffer: bufferResult } },
  ],
});

// 5. Dispatch one workgroup per 8×8 tile of the output
const encoder = device.createCommandEncoder();
const pass = encoder.beginComputePass();
pass.setPipeline(pipeline);
pass.setBindGroup(0, bindGroup);
pass.dispatchWorkgroups(
  Math.ceil(M / 8),
  Math.ceil(N / 8)
);
pass.end();
device.queue.submit([encoder.finish()]);
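One detail worth making explicit: the WGSL Matrix struct begins with two u32 fields, so the buffer passed to writeBuffer must contain that 8-byte header before the f32 data. A sketch of a packing helper (`packMatrix` is a name introduced here, not part of the WebGPU API):

```typescript
// Pack (rows, cols, values) into an ArrayBuffer matching the
// WGSL struct layout: two u32s at offsets 0 and 4, then f32 data
// starting at byte offset 8.
function packMatrix(rows: number, cols: number, values: Float32Array): ArrayBuffer {
  const buf = new ArrayBuffer(8 + values.byteLength);
  const header = new Uint32Array(buf, 0, 2);
  header[0] = rows;
  header[1] = cols;
  new Float32Array(buf, 8).set(values); // data after the 8-byte header
  return buf;
}
```

With this, `matrixA` above would be `packMatrix(M, K, aValues)` rather than the raw Float32Array.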
Worked Example
A concrete 3×3 case — multiply A by B and check a few entries against the formula above:
A
| 10 | 5 | 2 |
| 6 | 1 | 4 |
| 7 | 10 | 5 |
B
| 7 | 0 | 3 |
| 5 | 9 | 8 |
| 8 | 2 | 3 |
A × B
| 111 | 49 | 76 |
| 79 | 17 | 38 |
| 139 | 100 | 116 |
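This result can be reproduced with a flat, row-major version of the same computation (a sketch; `matmulFlat` is a name introduced here), mirroring exactly how the shader indexes its buffers:

```typescript
// Flat row-major matmul: mirrors the shader's indexing,
// a[row * k + p] and b[p * n + col].
function matmulFlat(
  a: Float32Array, b: Float32Array,
  m: number, k: number, n: number
): Float32Array {
  const c = new Float32Array(m * n);
  for (let row = 0; row < m; row++) {
    for (let col = 0; col < n; col++) {
      let sum = 0;
      for (let p = 0; p < k; p++) {
        sum += a[row * k + p] * b[p * n + col];
      }
      c[row * n + col] = sum;
    }
  }
  return c;
}

// The matrices from the example, flattened row by row:
const A = new Float32Array([10, 5, 2, 6, 1, 4, 7, 10, 5]);
const B = new Float32Array([7, 0, 3, 5, 9, 8, 8, 2, 3]);
const C = matmulFlat(A, B, 3, 3, 3);
// First entry: 10*7 + 5*5 + 2*8 = 111, matching the table.
```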
Performance Considerations
Tiling
The naive shader above has poor memory access patterns: every thread re-reads full rows and columns from global memory. Tiled matmul instead loads T×T sub-matrices into workgroup shared memory, where all threads in the workgroup reuse them. For N×N matrices with tile size T, this reduces global memory reads from O(N³) to O(N³ / T).
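The blocking idea can be illustrated on the CPU (a TypeScript sketch, not the WGSL shader itself; on a GPU, the tile loads marked below would go into var&lt;workgroup&gt; shared memory):

```typescript
// Blocked (tiled) matmul over flat row-major N×N matrices.
// Each T×T tile, once loaded, is reused across T output elements,
// which is what cuts global memory traffic by a factor of T.
function matmulTiled(a: Float32Array, b: Float32Array, n: number, t: number): Float32Array {
  const c = new Float32Array(n * n);
  for (let i0 = 0; i0 < n; i0 += t) {
    for (let j0 = 0; j0 < n; j0 += t) {
      for (let k0 = 0; k0 < n; k0 += t) {
        // On a GPU, this is where the T×T tiles of A and B would be
        // loaded cooperatively into workgroup shared memory.
        for (let i = i0; i < Math.min(i0 + t, n); i++) {
          for (let j = j0; j < Math.min(j0 + t, n); j++) {
            let sum = 0;
            for (let k = k0; k < Math.min(k0 + t, n); k++) {
              sum += a[i * n + k] * b[k * n + j];
            }
            c[i * n + j] += sum;
          }
        }
      }
    }
  }
  return c;
}
```

The arithmetic is identical to the naive version; only the loop order changes, so results match exactly while memory reuse improves.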
Workgroup Size Tuning
The optimal @workgroup_size depends on the GPU. Common choices:
- 64 threads (8×8) — good default for most GPUs
- 256 threads (16×16) — better occupancy on discrete GPUs
- 32 threads — matches the warp size on NVIDIA GPUs (AMD wavefronts are 32 or 64), the minimum for full SIMD utilization
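Whatever workgroup size is chosen, the dispatch must round up to cover the whole output, which is why the shader needs its bounds check. A quick sketch of the ceil-division involved (`dispatchSize` is a name introduced here):

```typescript
// Workgroups needed to cover an M×N output with a w×w workgroup.
// The last workgroup may spill past the matrix edge -- exactly what
// the `if (row >= ...)` guard in the shader handles.
function dispatchSize(m: number, n: number, w: number): [number, number] {
  return [Math.ceil(m / w), Math.ceil(n / w)];
}

dispatchSize(100, 100, 8); // 13 × 13 workgroups → 104 × 104 threads
```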
Key Takeaways
- GPU matmul maps each output element to a thread — massive parallelism
- Row-major memory layout and coalesced access patterns are critical for performance
- WGSL provides a portable shader language that runs on all WebGPU-capable browsers
- Tiling and shared memory optimization can improve throughput by 10-50×
- TSTorch uses these building blocks to implement tensor operations entirely in the browser