WebGPUでライフゲーム

2024/06/29に公開

WebGPUのCompute Shaderを使ってライフゲームを実装します。

https://ja.wikipedia.org/wiki/ライフゲーム

仕組みとしては、WebGPUのStorage Bufferで各セルの状態を管理し、Compute Shaderを使ってGPU上でセルの状態を更新します。描画では正方形のメッシュをインスタンシングで画面に敷き詰め、セルの状態を管理するStorage Bufferをもとに色で生死を表現しています。

ソースコードはGitHubに置いてあるので、そちらを確認してください。
https://github.com/aadebdeb/webgpu-lifegame

実際に動くデモはこちらにあります。
https://aadebdeb.github.io/webgpu-lifegame/


まず初めに、WebGPUを利用するための初期処理をします。

const canvas = document.querySelector('#canvas') as HTMLCanvasElement
const adapter = await navigator.gpu.requestAdapter()
if (!adapter) {
  throw new Error('WebGPU is not supported')
}
const device = await adapter.requestDevice()
if (!device) {
  throw new Error('WebGPU is not supported')
}
const context = canvas.getContext('webgpu')
if (!context) {
  throw new Error('WebGPU is not supported')
}

const canvasFormat = navigator.gpu.getPreferredCanvasFormat()
context.configure({
  device,
  format: canvasFormat
})

Compute Shaderを実行するのに必要なオブジェクトの作成をします。

// 初期状態をランダムに作成
const initialStates = new Uint32Array(Array.from({ length: numAllCells }, () => Math.random() > 0.5 ? 1 : 0))
// Storage Bufferを作成し、状態を書き込む
let computeReadBuffer = device.createBuffer({
  size: initialStates.byteLength,
  usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST
})
device.queue.writeBuffer(computeReadBuffer, 0, initialStates)
// Compute Shaderで次状態を書き込むためのStorage Bufferを作成する
// 2つのStorage Bufferをスワップしながら更新処理を行う
let computeWriteBuffer = device.createBuffer({
  size: initialStates.byteLength,
  usage: GPUBufferUsage.STORAGE // writeBufferしないのでGPUBufferUsage.COPY_DSTは不要
})
// Compute Shader用のShader Moduleの作成
const computeShaderModule = device.createShaderModule({
  code: computeShaderCode
})
// Compute Pipelineの作成
const computePipeline = device.createComputePipeline({
  layout: 'auto',
  compute: {
    module: computeShaderModule,
    entryPoint: 'main',
    constants: {
      'numCells': numCells,
    }
  }
})

Compute ShaderのWGSLは以下のようになっています。ライフゲームが2次元空間に存在するので、Workgroupのサイズも[8, 8, 1]というようにして、Compute Shaderが2次元的に実行されるようにしています。ただし、Storage Bufferは1次元なので、各セルの状態を取得するために2次元座標からインデックスを計算する必要があります。

@group(0) @binding(0) var<storage, read> currentStates: array<u32>;
@group(0) @binding(1) var<storage, read_write> nextStates: array<u32>;

override numCells: u32 = 512;

// 上下と左右はそれぞれ繋がっているとしてStorage Bufferのインデックスを求める
fn calcIndex(center: vec2<i32>, offset: vec2<i32>) -> i32 {
  let iNumCells = i32(numCells);
  var x = center.x + offset.x;
  if (x == -1) {
    x = iNumCells - 1;
  } else if (x == iNumCells) {
    x = 0;
  }
  var y = center.y + offset.y;
  if (y == -1) {
    y = iNumCells - 1;
  } else if (y ==iNumCells) {
    y = 0;
  }
  return x + y * iNumCells;
}

@compute @workgroup_size(8, 8, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let stateIndex = gid.x + gid.y * u32(numCells);
  let currentState = currentStates[stateIndex]; // 現在の状態

  // 周囲8マスから生存しているセルの数を集計する
  let center = vec2i(gid.xy);
  var aliveNeighbors: u32 = 0;
  aliveNeighbors += currentStates[calcIndex(center, vec2(-1, -1))];
  aliveNeighbors += currentStates[calcIndex(center, vec2(0, -1))];
  aliveNeighbors += currentStates[calcIndex(center, vec2(1, -1))];
  aliveNeighbors += currentStates[calcIndex(center, vec2(-1, 0))];
  aliveNeighbors += currentStates[calcIndex(center, vec2(1, 0))];
  aliveNeighbors += currentStates[calcIndex(center, vec2(-1, 1))];
  aliveNeighbors += currentStates[calcIndex(center, vec2(0, 1))];
  aliveNeighbors += currentStates[calcIndex(center, vec2(1, 1))];

  if ((currentState == 0 && aliveNeighbors == 3) || (currentState == 1 && (aliveNeighbors == 2 || aliveNeighbors == 3))) {
    nextStates[stateIndex] = 1; // 生存
  } else {
    nextStates[stateIndex] = 0; // 死亡
  }
}

Compute Shaderの準備ができたので、次にライフゲームの描画に必要なオブジェクトの作成します。基本的にはWebGPUで通常の描画処理を行うときと同じです。

// 正方形メッシュの頂点
const squareVertices = new Float32Array([
  -1, 1,
  -1, -1,
  1, 1,
  1, -1
])
// 正方形メッシュのインデックス
const squareIndices = new Uint16Array([
  0, 1, 2,
  2, 1, 3
])
// 描画用のShader Moduleの作成
const renderShaderModule = device.createShaderModule({
  code: renderShaderCode
})
// Vertex Bufferの作成
const vertexBuffer = device.createBuffer({
  size: squareVertices.byteLength,
  usage: GPUBufferUsage.VERTEX | GPUBufferUsage.COPY_DST,
})
device.queue.writeBuffer(vertexBuffer, 0, squareVertices)
// Index Bufferの作成
const indexBuffer = device.createBuffer({
  size: squareIndices.byteLength,
  usage: GPUBufferUsage.INDEX | GPUBufferUsage.COPY_DST,
})
device.queue.writeBuffer(indexBuffer, 0, squareIndices)
// Render Pipelineの作成
const renderPipeline = device.createRenderPipeline({
  layout: 'auto',
  vertex: {
    module: renderShaderModule,
    entryPoint: 'vs',
    buffers: [{
      attributes: [{
        shaderLocation: 0,
        offset: 0,
        format: 'float32x2'
      }],
      arrayStride: 8
    }],
    constants: {
      'numCells': numCells
    }
  },
  fragment: {
    module: renderShaderModule,
    entryPoint: 'fs',
    targets: [{
      format: canvasFormat
    }]
  }
})
// Bind Groupの作成
const renderBindGroup = device.createBindGroup({
  layout: renderPipeline.getBindGroupLayout(0),
  entries: [
    {
      binding: 0,
      resource: { buffer: computeReadBuffer }
    }
  ]
})

描画に利用するWGSLコードは以下のようになっています。instance_indexをもとに位置を決定し、Compute Shaderで更新するStorage Bufferを参照して色を決定しています。

override numCells: u32 = 512;

@group(0) @binding(0) var<storage, read> states: array<u32>;

struct VertexOutput {
  @builtin(position) position: vec4f,
  @location(0) color: vec3f
}

@vertex
fn vs(@location(0) position: vec2f, @builtin(instance_index) instanceIndex: u32) -> VertexOutput {
  let cellSize = 2.0 / f32(numCells);
  let vertexPos = position * cellSize * 0.5;

  let x = instanceIndex % numCells;
  let y = instanceIndex / numCells;
  let cellPos = vec2f(-1, -1) + vec2f(f32(x) + 0.5, f32(y) + 0.5) * cellSize;

  var output: VertexOutput;
  output.position = vec4f(vertexPos + cellPos, 0, 1);
  output.color = select(vec3f(0, 0, 0), vec3f(1, 1, 1), states[instanceIndex] == 1);
  return output;
}

@fragment
fn fs(vo: VertexOutput) -> @location(0) vec4f {
  return vec4f(vo.color.x, vo.color.x, vo.color.x, 1);
}

更新処理は以下のようになります。このupdateメソッドをrequestAnimationFrameの中で呼び出してアニメーションするようにします。

const workgroupSize = [8, 8]
const update = () => {
  // Compute Shaderによる更新処理
  const computeCommandEncoder = device.createCommandEncoder()
  const computePassEncoder = computeCommandEncoder.beginComputePass()
  computePassEncoder.setPipeline(computePipeline)
  const computeBindGroup = device.createBindGroup({
    layout: computePipeline.getBindGroupLayout(0),
    entries: [
      {binding:0, resource: { buffer: computeReadBuffer }},
      {binding:1, resource: { buffer: computeWriteBuffer }}
    ]
  })
  computePassEncoder.setBindGroup(0, computeBindGroup)
  computePassEncoder.dispatchWorkgroups(Math.ceil(numCells  / workgroupSize[0]), Math.ceil(numCells / workgroupSize[1]), 1)
  computePassEncoder.end()
  device.queue.submit([computeCommandEncoder.finish()])

  // Storage Bufferをスワップする
  ;[computeReadBuffer, computeWriteBuffer] = [computeWriteBuffer, computeReadBuffer]

  // 描画処理
  const encoder = device.createCommandEncoder()
  const renderPassEncoder = encoder.beginRenderPass({
    colorAttachments: [{
      view: context.getCurrentTexture().createView(),
      loadOp: 'clear',
      storeOp: 'store',
    }]
  })
  renderPassEncoder.setPipeline(renderPipeline)
  renderPassEncoder.setBindGroup(0, renderBindGroup)
  renderPassEncoder.setVertexBuffer(0, vertexBuffer)
  renderPassEncoder.setIndexBuffer(indexBuffer, 'uint16')
  renderPassEncoder.drawIndexed(squareIndices.length, numAllCells)
  renderPassEncoder.end()
  device.queue.submit([encoder.finish()])
}

Discussion