◾
WebGPUでライフゲーム
WebGPUのCompute Shaderを使ってライフゲームを実装します。
仕組みとしては、WebGPUのStorage Bufferで各セルの状態を管理し、Compute Shaderを使ってGPU上でセルの状態を更新します。描画では正方形のメッシュをインスタンシングで画面に敷き詰め、セルの状態を管理するStorage Bufferをもとに色で生死を表現しています。
ソースコードはGitHubに置いてあるので、そちらを確認してください。
実際に動くデモはこちらにあります。
まず初めに、WebGPUを利用するための初期処理をします。
const canvas = document.querySelector('#canvas') as HTMLCanvasElement
const adapter = await navigator.gpu.requestAdapter()
if (!adapter) {
throw new Error('WebGPU is not supported')
}
const device = await adapter.requestDevice()
if (!device) {
throw new Error('WebGPU is not supported')
}
const context = canvas.getContext('webgpu')
if (!context) {
throw new Error('WebGPU is not supported')
}
const canvasFormat = navigator.gpu.getPreferredCanvasFormat()
context.configure({
device,
format: canvasFormat
})
Compute Shaderを実行するのに必要なオブジェクトの作成をします。
// 初期状態をランダムに作成
const initialStates = new Uint32Array(Array.from({ length: numAllCells }, () => Math.random() > 0.5 ? 1 : 0))
// Storage Bufferを作成し、状態を書き込む
let computeReadBuffer = device.createBuffer({
size: initialStates.byteLength,
usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST
})
device.queue.writeBuffer(computeReadBuffer, 0, initialStates)
// Compute Shaderで次状態を書き込むためのStorage Bufferを作成する
// 2つのStorage Bufferをスワップしながら更新処理を行う
let computeWriteBuffer = device.createBuffer({
size: initialStates.byteLength,
usage: GPUBufferUsage.STORAGE // writeBufferしないのでGPUBufferUsage.COPY_DSTは不要
})
// Compute Shader用のShader Moduleの作成
const computeShaderModule = device.createShaderModule({
code: computeShaderCode
})
// Compute Pipelineの作成
const computePipeline = device.createComputePipeline({
layout: 'auto',
compute: {
module: computeShaderModule,
entryPoint: 'main',
constants: {
'numCells': numCells,
}
}
})
Compute ShaderのWGSLは以下のようになっています。ライフゲームが2次元空間に存在するので、Workgroupのサイズも[8, 8, 1]
というようにして、Compute Shaderが2次元的に実行されるようにしています。ただし、Storage Bufferは1次元なので、各セルの状態を取得するために2次元座標からインデックスを計算する必要があります。
@group(0) @binding(0) var<storage, read> currentStates: array<u32>;
@group(0) @binding(1) var<storage, read_write> nextStates: array<u32>;
override numCells: u32 = 512;
// 上下と左右はそれぞれ繋がっているとしてStorage Bufferのインデックスを求める
fn calcIndex(center: vec2<i32>, offset: vec2<i32>) -> i32 {
let iNumCells = i32(numCells);
var x = center.x + offset.x;
if (x == -1) {
x = iNumCells - 1;
} else if (x == iNumCells) {
x = 0;
}
var y = center.y + offset.y;
if (y == -1) {
y = iNumCells - 1;
} else if (y ==iNumCells) {
y = 0;
}
return x + y * iNumCells;
}
@compute @workgroup_size(8, 8, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let stateIndex = gid.x + gid.y * u32(numCells);
let currentState = currentStates[stateIndex]; // 現在の状態
// 周囲8マスから生存しているセルの数を集計する
let center = vec2i(gid.xy);
var aliveNeighbors: u32 = 0;
aliveNeighbors += currentStates[calcIndex(center, vec2(-1, -1))];
aliveNeighbors += currentStates[calcIndex(center, vec2(0, -1))];
aliveNeighbors += currentStates[calcIndex(center, vec2(1, -1))];
aliveNeighbors += currentStates[calcIndex(center, vec2(-1, 0))];
aliveNeighbors += currentStates[calcIndex(center, vec2(1, 0))];
aliveNeighbors += currentStates[calcIndex(center, vec2(-1, 1))];
aliveNeighbors += currentStates[calcIndex(center, vec2(0, 1))];
aliveNeighbors += currentStates[calcIndex(center, vec2(1, 1))];
if ((currentState == 0 && aliveNeighbors == 3) || (currentState == 1 && (aliveNeighbors == 2 || aliveNeighbors == 3))) {
nextStates[stateIndex] = 1; // 生存
} else {
nextStates[stateIndex] = 0; // 死亡
}
}
Compute Shaderの準備ができたので、次にライフゲームの描画に必要なオブジェクトの作成します。基本的にはWebGPUで通常の描画処理を行うときと同じです。
// 正方形メッシュの頂点
const squareVertices = new Float32Array([
-1, 1,
-1, -1,
1, 1,
1, -1
])
// 正方形メッシュのインデックス
const squareIndices = new Uint16Array([
0, 1, 2,
2, 1, 3
])
// 描画用のShader Moduleの作成
const renderShaderModule = device.createShaderModule({
code: renderShaderCode
})
// Vertex Bufferの作成
const vertexBuffer = device.createBuffer({
size: squareVertices.byteLength,
usage: GPUBufferUsage.VERTEX | GPUBufferUsage.COPY_DST,
})
device.queue.writeBuffer(vertexBuffer, 0, squareVertices)
// Index Bufferの作成
const indexBuffer = device.createBuffer({
size: squareIndices.byteLength,
usage: GPUBufferUsage.INDEX | GPUBufferUsage.COPY_DST,
})
device.queue.writeBuffer(indexBuffer, 0, squareIndices)
// Render Pipelineの作成
const renderPipeline = device.createRenderPipeline({
layout: 'auto',
vertex: {
module: renderShaderModule,
entryPoint: 'vs',
buffers: [{
attributes: [{
shaderLocation: 0,
offset: 0,
format: 'float32x2'
}],
arrayStride: 8
}],
constants: {
'numCells': numCells
}
},
fragment: {
module: renderShaderModule,
entryPoint: 'fs',
targets: [{
format: canvasFormat
}]
}
})
// Bind Groupの作成
const renderBindGroup = device.createBindGroup({
layout: renderPipeline.getBindGroupLayout(0),
entries: [
{
binding: 0,
resource: { buffer: computeReadBuffer }
}
]
})
描画に利用するWGSLコードは以下のようになっています。instance_index
をもとに位置を決定し、Compute Shaderで更新するStorage Bufferを参照して色を決定しています。
override numCells: u32 = 512;
@group(0) @binding(0) var<storage, read> states: array<u32>;
struct VertexOutput {
@builtin(position) position: vec4f,
@location(0) color: vec3f
}
@vertex
fn vs(@location(0) position: vec2f, @builtin(instance_index) instanceIndex: u32) -> VertexOutput {
let cellSize = 2.0 / f32(numCells);
let vertexPos = position * cellSize * 0.5;
let x = instanceIndex % numCells;
let y = instanceIndex / numCells;
let cellPos = vec2f(-1, -1) + vec2f(f32(x) + 0.5, f32(y) + 0.5) * cellSize;
var output: VertexOutput;
output.position = vec4f(vertexPos + cellPos, 0, 1);
output.color = select(vec3f(0, 0, 0), vec3f(1, 1, 1), states[instanceIndex] == 1);
return output;
}
@fragment
fn fs(vo: VertexOutput) -> @location(0) vec4f {
return vec4f(vo.color.x, vo.color.x, vo.color.x, 1);
}
更新処理は以下のようになります。このupdate
メソッドをrequestAnimationFrame
の中で呼び出してアニメーションするようにします。
const workgroupSize = [8, 8]
const update = () => {
// Compute Shaderによる更新処理
const computeCommandEncoder = device.createCommandEncoder()
const computePassEncoder = computeCommandEncoder.beginComputePass()
computePassEncoder.setPipeline(computePipeline)
const computeBindGroup = device.createBindGroup({
layout: computePipeline.getBindGroupLayout(0),
entries: [
{binding:0, resource: { buffer: computeReadBuffer }},
{binding:1, resource: { buffer: computeWriteBuffer }}
]
})
computePassEncoder.setBindGroup(0, computeBindGroup)
computePassEncoder.dispatchWorkgroups(Math.ceil(numCells / workgroupSize[0]), Math.ceil(numCells / workgroupSize[1]), 1)
computePassEncoder.end()
device.queue.submit([computeCommandEncoder.finish()])
// Storage Bufferをスワップする
;[computeReadBuffer, computeWriteBuffer] = [computeWriteBuffer, computeReadBuffer]
// 描画処理
const encoder = device.createCommandEncoder()
const renderPassEncoder = encoder.beginRenderPass({
colorAttachments: [{
view: context.getCurrentTexture().createView(),
loadOp: 'clear',
storeOp: 'store',
}]
})
renderPassEncoder.setPipeline(renderPipeline)
renderPassEncoder.setBindGroup(0, renderBindGroup)
renderPassEncoder.setVertexBuffer(0, vertexBuffer)
renderPassEncoder.setIndexBuffer(indexBuffer, 'uint16')
renderPassEncoder.drawIndexed(squareIndices.length, numAllCells)
renderPassEncoder.end()
device.queue.submit([encoder.finish()])
}
Discussion