GPU-Accelerated Matrix Operations with WebGPU

Matrix multiplication is the most fundamental operation in deep learning — every linear layer, attention head, and convolution boils down to matmul. GPUs excel at this because matrix multiplication is embarrassingly parallel: each output element can be computed independently.

Matrix Multiplication Recap

For an M×K matrix A and a K×N matrix B, the product C = AB is the M×N matrix where:

C[i][j] = A[i][0] · B[0][j] + A[i][1] · B[1][j] + … + A[i][K−1] · B[K−1][j]

This requires M · N · K multiply-add operations. On a CPU, a naive implementation has poor cache locality. On a GPU, we can compute all output elements in parallel.
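
For reference, the same computation on the CPU, as a minimal TypeScript sketch (the function name `matmul` and flat `Float32Array` inputs are illustrative; row-major layout matches what the GPU version uses below):

```typescript
// Naive CPU matmul over flat row-major arrays: C (m×n) = A (m×k) · B (k×n).
function matmul(
  a: Float32Array, b: Float32Array,
  m: number, k: number, n: number,
): Float32Array {
  const c = new Float32Array(m * n);
  for (let i = 0; i < m; i++) {
    for (let j = 0; j < n; j++) {
      let sum = 0;
      for (let p = 0; p < k; p++) {
        sum += a[i * k + p] * b[p * n + j]; // A[i][p] * B[p][j]
      }
      c[i * n + j] = sum;
    }
  }
  return c;
}

// 2×2 example: [[1,2],[3,4]] · [[5,6],[7,8]] = [[19,22],[43,50]]
const c = matmul(new Float32Array([1, 2, 3, 4]), new Float32Array([5, 6, 7, 8]), 2, 2, 2);
console.log(Array.from(c)); // [19, 22, 43, 50]
```

The three nested loops make the M · N · K operation count explicit; the GPU version simply assigns each (i, j) pair to its own thread.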

Memory Layout

GPUs work with flat buffers, so we store matrices in row-major order: all of row 0 first, then all of row 1, and so on.

This means element (i, j) of an R×C matrix lives at buffer offset i * C + j.

Understanding memory layout is critical for performance — accessing contiguous memory (coalesced access) is dramatically faster than strided access on GPUs.
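
The offset rule is easy to check with a tiny sketch (the `offset` helper is hypothetical, just for illustration):

```typescript
// Row-major offset: element (i, j) of a matrix with `cols` columns sits at i * cols + j.
function offset(i: number, j: number, cols: number): number {
  return i * cols + j;
}

// A 2×3 matrix [[1, 2, 3], [4, 5, 6]] stored as a flat buffer:
const m = new Float32Array([1, 2, 3, 4, 5, 6]);
console.log(m[offset(1, 2, 3)]); // element (1, 2) → 6
```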

WGSL Compute Shader

Here’s a basic matrix multiplication compute shader in WGSL (WebGPU Shading Language):

struct Matrix {
  rows: u32,
  cols: u32,
  data: array<f32>,
}

@group(0) @binding(0) var<storage, read> a: Matrix;
@group(0) @binding(1) var<storage, read> b: Matrix;
@group(0) @binding(2) var<storage, read_write> result: Matrix;

@compute @workgroup_size(8, 8)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let row = gid.x;
  let col = gid.y;

  if (row >= a.rows || col >= b.cols) {
    return;
  }

  var sum: f32 = 0.0;
  for (var k: u32 = 0u; k < a.cols; k = k + 1u) {
    let a_val = a.data[row * a.cols + k];
    let b_val = b.data[k * b.cols + col];
    sum = sum + a_val * b_val;
  }

  result.data[row * b.cols + col] = sum;
}

Each thread computes a single element of the output matrix. The @workgroup_size(8, 8) attribute creates 8×8 thread groups — 64 threads that execute together and can share local memory.

WebGPU Pipeline

Setting up the compute pipeline in TypeScript involves:

// 1. Get the GPU device
const adapter = await navigator.gpu.requestAdapter();
if (!adapter) {
  throw new Error('WebGPU is not supported in this browser');
}
const device = await adapter.requestDevice();

// 2. Create buffers (bufferB and bufferResult are created the same way)
const bufferA = device.createBuffer({
  size: matrixA.byteLength,
  usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST,
});
device.queue.writeBuffer(bufferA, 0, matrixA);

// 3. Create compute pipeline
const pipeline = device.createComputePipeline({
  layout: 'auto',
  compute: {
    module: device.createShaderModule({ code: shaderCode }),
    entryPoint: 'main',
  },
});

// Bind the three buffers to the shader's @group(0) @binding slots
const bindGroup = device.createBindGroup({
  layout: pipeline.getBindGroupLayout(0),
  entries: [
    { binding: 0, resource: { buffer: bufferA } },
    { binding: 1, resource: { buffer: bufferB } },
    { binding: 2, resource: { buffer: bufferResult } },
  ],
});

// 4. Dispatch one workgroup per 8×8 tile of the output (M = a.rows, N = b.cols)
const encoder = device.createCommandEncoder();
const pass = encoder.beginComputePass();
pass.setPipeline(pipeline);
pass.setBindGroup(0, bindGroup);
pass.dispatchWorkgroups(
  Math.ceil(M / 8),
  Math.ceil(N / 8)
);
pass.end();
device.queue.submit([encoder.finish()]);
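
The pipeline above submits the work but never reads the result back to the CPU. A common pattern, sketched here with the standard WebGPU buffer-mapping API, is to copy the result into a mappable staging buffer (`bufferResult` and its byte size `resultSize` are assumed from step 2, and `bufferResult` needs `GPUBufferUsage.COPY_SRC` added to its usage flags):

```typescript
// 5. Read the result back via a MAP_READ staging buffer
const staging = device.createBuffer({
  size: resultSize,
  usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST,
});

const copyEncoder = device.createCommandEncoder();
copyEncoder.copyBufferToBuffer(bufferResult, 0, staging, 0, resultSize);
device.queue.submit([copyEncoder.finish()]);

// mapAsync resolves once the queued GPU work has completed
await staging.mapAsync(GPUMapMode.READ);
const output = new Float32Array(staging.getMappedRange().slice(0)); // copy before unmap
staging.unmap();
```

Storage buffers cannot be mapped directly, which is why the extra copy into a `MAP_READ` buffer is needed.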

Interactive Demo

Explore matrix multiplication with different sizes and see how the dimensions affect the result. (Interactive demo; requires a WebGPU-capable browser.)

Performance Considerations

Tiling

The naive shader above has poor memory access patterns: every thread streams its row of A and column of B straight from global memory, so each input element is re-read many times. A tiled matmul instead loads small sub-matrices (tiles) into workgroup shared memory, where all 64 threads of the workgroup can reuse them.

This reduces global memory reads from O(N³) to O(N³/T), where T is the tile width.
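
As a sketch, here is a tiled variant of the earlier shader with an 8×8 tile (the same `Matrix` struct and bindings are assumed; `workgroupBarrier()` keeps all threads in lockstep around the shared tiles):

```wgsl
const TILE: u32 = 8u;

// Shared staging area for one 8×8 tile of A and one of B.
var<workgroup> tile_a: array<f32, 64>;
var<workgroup> tile_b: array<f32, 64>;

@compute @workgroup_size(8, 8)
fn main(
  @builtin(global_invocation_id) gid: vec3<u32>,
  @builtin(local_invocation_id) lid: vec3<u32>,
) {
  let row = gid.x;
  let col = gid.y;
  var sum: f32 = 0.0;

  let num_tiles = (a.cols + TILE - 1u) / TILE;
  for (var t: u32 = 0u; t < num_tiles; t = t + 1u) {
    // Cooperatively load one tile of A and one tile of B,
    // zero-padding loads that fall outside the matrices.
    let a_col = t * TILE + lid.y;
    var a_elem: f32 = 0.0;
    if (row < a.rows && a_col < a.cols) {
      a_elem = a.data[row * a.cols + a_col];
    }
    tile_a[lid.x * TILE + lid.y] = a_elem;

    let b_row = t * TILE + lid.x;
    var b_elem: f32 = 0.0;
    if (b_row < b.rows && col < b.cols) {
      b_elem = b.data[b_row * b.cols + col];
    }
    tile_b[lid.x * TILE + lid.y] = b_elem;

    workgroupBarrier(); // wait until both tiles are fully loaded

    for (var k: u32 = 0u; k < TILE; k = k + 1u) {
      sum = sum + tile_a[lid.x * TILE + k] * tile_b[k * TILE + lid.y];
    }

    workgroupBarrier(); // wait before the next iteration overwrites the tiles
  }

  if (row < a.rows && col < b.cols) {
    result.data[row * b.cols + col] = sum;
  }
}
```

Note that no thread returns early: barriers must be reached by every thread in the workgroup, so the bounds check moves to the loads and the final write instead.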

Workgroup Size Tuning

The optimal @workgroup_size depends on the GPU. Common choices:

  • 64 threads (8×8) — good default for most GPUs
  • 256 threads (16×16) — better occupancy on discrete GPUs
  • 32 threads (warp/wavefront size) — minimum for full SIMD utilization

Key Takeaways

  • GPU matmul maps each output element to a thread — massive parallelism
  • Row-major memory layout and coalesced access patterns are critical for performance
  • WGSL provides a portable shader language that runs on all WebGPU-capable browsers
  • Tiling and shared memory optimization can improve throughput by 10-50×
  • TSTorch uses these building blocks to implement tensor operations entirely in the browser