🧮

[JavaScript] WebGPUで半精度浮動小数点 (16bit float) を使う

2024/02/17に公開

先日から開発しているWebGPUを利用したNDArrayで利用するために調べた内容です。

https://zenn.dev/ymd_h/articles/0ed6f287eb3699
https://github.com/ymd-h/gpu-array-js

GPUが "shader-f16" をサポートしているかチェックする

const adapter = await navigator.gpu.requestAdapter();
adapter.features.has("shader-f16"); // サポートしていれば true を返す

"shader-f16" 有効化して (論理) GPUデバイスを作成する

const device = await adapter.requestDevice({ requiredFeatures: ["shader-f16"] });

WebGPU Shading Language (WGSL) で f16 拡張を有効化する

GPUデバイスで機能を有効化するだけでなく、シェーダー側でも拡張機能を有効化してコンパイル (GPUDevice.createShaderModule()) する必要があります。
(私はこの仕様に気が付かなくてちょっと躓きました。)

シェーダーの冒頭に enable f16; と記載してf16拡張機能を有効化します。(参照)

enable f16;

@group(0) @binding(0)
var<storage, read> a: array<f16>;

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) id: vec3<u32>){
  (省略)
}

CPU側で 16bit float を取り扱う

残念ながら、半精度浮動小数点を保管する Float16Array は (まだ) 標準にはありません。
ありがたいことに ponyfill を公開してくださっている方がいるので、私はこちらを利用することにしました。

https://github.com/petamoriken/float16
https://inside.pixiv.blog/2023/10/19/130000
https://github.com/tc39/proposal-float16array

注意点として、標準のTypedArrayではないので、GPUに転送するときには内部の ArrayBuffer を渡しましょう。

import { Float16Array } from "https://cdn.jsdelivr.net/npm/@petamoriken/float16/+esm";

const f16 = new Float16Array([1.5, 2.3]);
const buffer = device.createBuffer({ size: 2 * 2, usage: GPUBufferUsage.Storage | GPUBufferUsage.COPY_DST });

device.queue.writeBuffer(buffer, 0, f16.buffer);

Discussion