🧮

WebGPUでCPUより高速に累積和を取りたいな

に公開

概要

WebGPUのコンピュートシェーダーを使って累積和を計算します。CPUによるシーケンシャルな累積和よりも高速に動作することを目標とします。実装にはRustとwgpuを使います。WebGPUなんじゃそらという方はWebGPU fundamentalsのコンピュートシェーダー周りを覗いてからだと読みやすくなると思います。subgroup以外のセクションは主にChapter 39. Parallel Prefix Sum (Scan) with CUDA | NVIDIA Developerを参考にしています。

実装全体はこちらからどうぞ。
https://github.com/YohYamasaki/wgpu-prefix-sum-demo

累積和とは

数列Aがあるとき、S[i] = S[i-1] + A[i]のようにi部分までの数列の総和を取る事によって得られる数列Sのことを累積和と呼びます。基数ソートや四分木など、様々なアルゴリズムやデータ構造の土台として使われています。CPUでin placeに実装すると以下の擬似コードのように書けますね。

function prefixSum(A[0..n-1])
	for i:= 1,..,n-1 do
		A[i] := A[i-1] + A[i]
	end for
end function

累積和を実直に実装する場合、今見ている要素と一つ前の要素を足します。つまり全部の要素がそれよりも前の要素の結果に依存するため、そのままだと並列処理でのメリットを享受することができません。そのため並列処理に向いた累積和アルゴリズムが開発されています。

Hillis-Steele scan

二分木的に各ステップで依存距離が{2^k}に増える反復アルゴリズムです。S[i]を求めたいとき、A[0..i]が葉でS[i]を根とした二分木を考えます。この二分木を葉から順番に足し合わせて親に保存していくと、最終的にS[i]\Sigma_{k}^{i}{A[k]} が入ることがわかると思います。そして嬉しいことに各高さにおいての局所依存は2要素になることから、並列処理に向いています。説明下手で何を言っているのか分からないかもしれませんが、図を見てもらうと分かりやすいかなと思います。説明しやすさのためにInclusive scanにしています。

Hillis-Steele scan

実装

早速wgpuで実装していきましょう。最初なのでじっくり目に見ていきます。

Hillis-Steele scanは基本的にin place前提のアルゴリズムですが、一本のバッファを使い回すことは実際にはできません。前段の値ではなく更新中の値を読んでしまう競合が発生するためです。そのためダブルバッファで実装していきます。ダブルバッファのバインドグループは以下の様な感じですね。

hillis_steele_scan.rs
let byte_len = (n * size_of::<u32>()) as u64;  
  
let data0 = device.create_buffer(&wgpu::BufferDescriptor {  
    label: Some("data0"),  
    size: byte_len,  
    usage: wgpu::BufferUsages::STORAGE  
        | wgpu::BufferUsages::COPY_DST  
        | wgpu::BufferUsages::COPY_SRC,  
    mapped_at_creation: false,  
});  
let data1 = device.create_buffer(&wgpu::BufferDescriptor {  
    label: Some("data1"),  
    size: byte_len,  
    usage: wgpu::BufferUsages::STORAGE  
        | wgpu::BufferUsages::COPY_DST  
        | wgpu::BufferUsages::COPY_SRC,  
    mapped_at_creation: false,  
});

次に今のステップ数を渡すためのUniform Bufferを作ります。スキャンが進む事に、2のべき乗ずつ増えていくステップ数の分だけ前の要素を読み込んで足すためです。で、シェーダー側にこうした固定の変数的なデータを渡すために、Uniform bufferが用意されていますので使いましょう。

シェーダーを呼び出すたびにステップを計算して同じUniformバッファに詰め直してsubmitを繰り返せばいいんじゃないのと思いますが、submit実行時にはコマンド列の検証などコストゼロの処理ではないので出来るだけ少ない回数にまとめたほうが良いです。このステップ数も事前に計算しUniformを複数作って配列に詰めておくことで、copy_buffer_to_bufferからのsubmitを行う必要がなくなり複数まとめてdispatchができるため、submitが一回で済みます。

hillis_steele_scan.rs
// Calculate stride between uniforms in the aggregate buffer  
let align = device.limits().min_uniform_buffer_offset_alignment as usize;  
let uni_size = size_of::<Uniforms>();  
let stride = align_up(uni_size, align);  
let uniform_stride = stride as u32;  
  
// Create a byte array of the uniforms with the stride  
let mut blob = vec![0u8; stride * (max_steps as usize)];  
for i in 0..max_steps {  
    let u = Uniforms {  
        step: 1u32 << i,  
        _pad: [0; 3],  
    };  
    let bytes = bytemuck::bytes_of(&u);  
    let offset = (i as usize) * stride;  
    blob[offset..offset + bytes.len()].copy_from_slice(bytes);  
}  
  
let uniform = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {  
    label: Some("uniform"),  
    contents: &blob,  
    usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,  
});

そうすると、各ステップでUniform配列を読み始めるオフセットを変更する必要があります。webGPUではDynamic offset機能があるので、それをオンにします。そのためにはバインドグループレイアウトとパイプラインレイアウトを明示してあげる必要があります。

hillis_steele_scan.rs
let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {  
    label: Some("prefix-sum bgl"),  
    entries: &[  
        // src: storage read  
        wgpu::BindGroupLayoutEntry {  
            binding: 0,  
            visibility: wgpu::ShaderStages::COMPUTE,  
            ty: wgpu::BindingType::Buffer {  
                ty: wgpu::BufferBindingType::Storage { read_only: true },  
                has_dynamic_offset: false,  
                min_binding_size: None,  
            },  
            count: None,  
        },  
        // dst: storage read_write  
        wgpu::BindGroupLayoutEntry {  
            binding: 1,  
            visibility: wgpu::ShaderStages::COMPUTE,  
            ty: wgpu::BindingType::Buffer {  
                ty: wgpu::BufferBindingType::Storage { read_only: false },  
                has_dynamic_offset: false,  
                min_binding_size: None,  
            },  
            count: None,  
        },  
        // uni: uniform (dynamic offset!)  
        wgpu::BindGroupLayoutEntry {  
            binding: 2,  
            visibility: wgpu::ShaderStages::COMPUTE,  
            ty: wgpu::BindingType::Buffer {  
                ty: wgpu::BufferBindingType::Uniform,  
                // We will store all the steps into one uniform, that requires to have dynamic offset  
                has_dynamic_offset: true,  
                min_binding_size: NonZeroU64::new(size_of::<Uniforms>() as u64),  
            },  
            count: None,  
        },  
    ],  
});  
  
let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {  
    label: Some("prefix-sum pipeline layout"),  
    bind_group_layouts: &[&bind_group_layout],  
    immediate_size: 0,  
});

ここまでが初期化でした。次にシェーダー本体と計算実行用の関数を用意します。

初期化まわりのごっちゃり感から比べるとスッキリですね。今のスレッドの要素とステップ分前の要素との加算結果をもう片方のバッファへコピーしています。

hillis_steele_scan.wgsl
struct Uniforms {  
  step: u32,  
};  
  
@group(0) @binding(0) var<storage, read> src: array<u32>;  
@group(0) @binding(1) var<storage, read_write> dst: array<u32>;  
@group(0) @binding(2) var<uniform> uni: Uniforms;  
  
@compute  
@workgroup_size(64)  
fn main(  
  @builtin(global_invocation_id) gid: vec3<u32>,  
  @builtin(num_workgroups) nwg: vec3<u32>,  
) {  
    let total = arrayLength(&src);  
  
    let width = nwg.x * 64u;  
    let i = gid.x + gid.y * width;  
  
    if (i >= total) {  
        return;  
    }  
  
    if (i < uni.step) {  
        dst[i] = src[i];  
    } else {  
        dst[i] = src[i] + src[i - uni.step];  
    }  
}

実行するには必要なステップの総数分dispatchし、まとめてsubmitすればOKです!配列サイズがワークグループの1次元分のサイズに収まらないときのため、yも使うように実装しています。

hillis_steele_scan.rs
  
pub fn run_prefix_scan(&self) {  
    const WG_SIZE: u32 = 64;  
    let workgroups_needed = self.n.div_ceil(WG_SIZE as usize) as u32;  
  
    let max_dim = self.device.limits().max_compute_workgroups_per_dimension;  
    let x = workgroups_needed.min(max_dim);  
    let y = (workgroups_needed + x - 1) / x;  
    let mut encoder = self.device.create_command_encoder(&Default::default());  
    {  
        let mut pass = encoder.begin_compute_pass(&Default::default());  
        pass.set_pipeline(&self.pipeline);  
        for i in 0..self.max_steps {  
            let offset_bytes = i * self.uniform_stride;  
            let bg = if i % 2 == 0 {  
                &self.bind_group_0  
            } else {  
                &self.bind_group_1  
            };  
            pass.set_bind_group(0, bg, &[offset_bytes]);  
            pass.dispatch_workgroups(x, y, 1);  
        }  
    }  
    self.queue.submit([encoder.finish()]);  
}

ベンチマーク

使用マシンはM4 Pro Mac mini 24GBです。計測にはCriterionを使用しました。本来シェーダーの実行時間を取りたければtimestamp-queryが使えればベストですが、Metalバックエンドのバグなのかうまく取得できず、一旦submitから計算が終わってデバイスがまた有効になるまでの時間を計測します。ポーリング分で最大体感500μs~1ms程度オーバーヘッドがあるっぽいですが、今回は仕方ないですしそこまで厳密に計測せずで良いでしょう。デカいNで比較すれば良い話です。

CPU Sequential vs GPU Hillis-Steele

おっっっっっそ!!!!!

Hillis-Steeleは理想並列計算量がO(\log{N})と、CPUよりも早くなっていても良さそうですが、実際にはメモリ帯域幅と仕事量がボトルネックになっているみたいです。

具体的にメモリ転送にかかる時間の理論値を計算してみます。M4 Proはユニファイドメモリの帯域幅が273GB/sです。例えばN = 10^8だと[u32; n]の配列を扱う場合、配列サイズは約400MBになります。そして \lceil \log_{2}{N} \rceil = 27ステップになりますが、各ステップでreadとwriteの合計二回アクセスするので、合計すると54回転送することになります。つまり総転送量は400\text{MB} \times 54 = 21.6\text{GB}、よって転送にかかる理論値は\frac{21.6\text{GB}}{273\text{GB/s}} \approx 0.0791\text{s} = 79.1\text{ms}となります。実際にN = 10^8で計算すると104.36ms掛かったので、理論値で見ても転送時間が多くを占めることがわかりました。

加えて、このアルゴリズムは仕事効率が最適ではなく、一つのスレッドで動かしたと想定するとO(N \log_2{N})が時間計算量です。つまり本来O(N)で回せる累積和計算において、理想と比べると余計な計算が必要ということになります。並列計算によってNが取れるため、使えるスレッドが無限個ある場合の理論計算量はO(\log{N})にはなりますが、実際のところ有限個なので要素数が増えるごとに徐々に遅くなっていくということですね。

Blelloch scan

Hillis-Steeleでボトルネックになっていた2つの問題を改善できるアルゴリズムがBlelloch scanです。仕事量はO(N)、並列計算時の時間計算量はO(\frac{N}{P} + \log{N})になります。(P = スレッド数)

Blelloch scanはup-sweepとdown-sweepの2段階に処理が別れます。up-sweepでは2のべき乗のステップを作り、ステップ数で割り切れるインデックスの要素に一つ前のステップ数分前の要素を足すというのを繰り返します。これも図が分かりやすいですね。Blellochでは簡単のためにExclusive scanにします。

Blelloch up-sweep

down-sweep前には配列の一番最後の要素を0にしておく必要があります。down-sweepもup-sweepと同様に2のべき乗のステップにて似た処理を行いますが、今度は一つ前のインデックスをL、今見ているインデックスをRとしたとき、S[L] = S[R], S[R] = S[R] + S[L]のように更新します。

Blelloch down-sweep.png

何がどうなったらこんなの思いつくんだ、という感じですが、確かに一つの要素を逆順に追っていくとちゃんと累積和になっています。ありがたく使わせていただきましょう。我々は天才達による庇護の元で漸くプログラムを書くことができるのです。

実装

基本的な構成はHillis-Steele用のものと変わりません。ダブルバッファは必要ないのと、up-sweep, 最後の要素をゼロに変更, down-sweepの3つのシェーダーを実効するためのパイプラインとバインドグループを作ればOKです。ステップ数もほぼほぼ同様のアプローチでUniform配列を作っておきます。

そして以下が本丸、up-sweep, down-sweep用のシェーダーたちです。上の図での説明どおりの実装になっていることが確認できると思います。

global_blelloch_scan_up_sweep.wgsl
struct Uniform {  
  step: u32, // this has to be a power of 2  
};  
  
@group(0) @binding(0) var<storage, read_write> data: array<u32>;  
@group(0) @binding(1) var<uniform> uni: Uniform;  
  
@compute  
@workgroup_size(64)  
fn main(  
  @builtin(global_invocation_id) gid: vec3<u32>,  
  @builtin(num_workgroups) nwg: vec3<u32>,  
) {  
    let n = arrayLength(&data);  
    let step = uni.step;  
    let half = step >> 1u;  
    let width = nwg.x * 64u;  
    let plane = width * nwg.y;  
    let t = gid.x + gid.y * width + gid.z * plane;  
  
    let active_idx = n / step;  
    if (t >= active_idx) { return; }  
  
    // We need (step - 1u) to target the last element of the current block  
    let i = (step - 1u) + t * step;  
    data[i] += data[i - half];  
}
global_blelloch_scan_down_sweep.wgsl
struct Uniform {  
  step: u32, // this has to be a power of 2  
};  
  
@group(0) @binding(0) var<storage, read_write> data: array<u32>;  
@group(0) @binding(1) var<uniform> uni: Uniform;  
  
@compute  
@workgroup_size(64)  
fn main(  
  @builtin(global_invocation_id) gid: vec3<u32>,  
  @builtin(num_workgroups) nwg: vec3<u32>,  
) {  
    let n = arrayLength(&data);  
        let step = uni.step;  
        let half = step >> 1u;  
  
        let width = nwg.x * 64u;  
        let plane = width * nwg.y;  
        let t = gid.x + gid.y * width + gid.z * plane;  
  
        let active_idx = n / step;  
        if (t >= active_idx) { return; }  
  
        // We need (step - 1u) to target the last element of the current block 
        let i = (step - 1u) + t * step;  
        let prev = i - half;  
  
        let left = data[i];  
        data[i] = data[i] + data[prev];  
        data[prev] = left;  
}

ベンチマーク

CPU vs Hillis-Steele vs Blelloch

Hillis-Steeleと比べると大分早くなりました。それでもまだ最適とはいえず、グローバルメモリへのアクセスがこの実装ではまだ7N回程度発生していて、改善の余地がありそうです。

Blocked Blelloch scan

グローバルメモリへのアクセスを減らすためにsharedメモリを使いましょう。WebGPUではvar<workgroup>ですね。workgroup単位での同期ができるので、dispatchの回数も減らすことができます。

ブロックサイズはワークグループサイズと同じになるよう設定すればよいでしょう。まず最初にブロックごとに累積和を取り、次にメインのデータとは別に用意したバッファにブロック全要素を合計した値を入れていきます。そしてそのブロック合計用バッファで更に累積和を取るとオフセットになるため、その値をブロック単位で既に累積和が取ってあるメインのデータに加えれば完成という流れです。

Block scan

問題はどうやってブロック合計用バッファの累積和を取るかですが、これもBlelloch scanで取ります。つまりブロック合計用バッファのサイズが一つのブロックに収まるようになるまで、再帰的に累積和を取っていくということになります。

実装

コア部分のアルゴリズムは元のBlelloch scanと同等ですが、workgroup内のsharedメモリをブロック単位のBlelloch scanの頭から終わりまで使う必要があるので、up-sweep / down-sweepのステップ単位のループも含めて一つのシェーダーに移してしまいます。これは再帰的に累積和を適用させるのにも都合がよいです。また、累積和を取ったあとのブロック合計用バッファの値を元のデータに加算するためのシェーダーも作っておきましょう。

blelloch_block_scan.wgsl
const WG_SIZE: u32 = 64u;  
  
@group(0) @binding(0) var<storage, read_write> global_data: array<u32>;  
@group(0) @binding(1) var<storage, read_write> block_sum: array<u32>;  
  
var<workgroup> local_data: array<u32, 64u>;  
  
fn linearize_workgroup_id(wid: vec3<u32>, num_wg: vec3<u32>) -> u32 {  
    // linear = x + y*X + z*(X*Y)  
    return wid.x + wid.y * num_wg.x + wid.z * (num_wg.x * num_wg.y);  
}  
  
/**  
 * Get local and global index. 
 */
fn get_indices(lid: vec3<u32>, wid: vec3<u32>, num_wg: vec3<u32>) -> array<u32, 2> {  
    let local_idx = lid.x;  
    let wg_linear = linearize_workgroup_id(wid, num_wg);  
    let block_base = wg_linear * WG_SIZE;  
    let global_idx = block_base + local_idx;  
    return array<u32, 2>(local_idx, global_idx);  
}  
  
/**  
 * Load data from the storage to the workgroup variable. 
 */
fn copy_global_data_to_local(n: u32, local_idx: u32, global_idx: u32) {  
    var global_val = 0u;  
    if (global_idx < n) {  
        global_val = global_data[global_idx];  
    }  
    local_data[local_idx] = global_val;  
    workgroupBarrier();  
}  
  
/**  
 * Execute up-sweep step of the Blelloch scan. Returns sum of the local block. 
 */
fn up_sweep(local_idx: u32) {  
    var step = 2u;  
    while (step <= WG_SIZE) {  
        let num_targets = WG_SIZE / step;  
        if (local_idx < num_targets) {  
            // Map each participating thread t to the rightmost element of its span.  
            // This avoids an expensive modulo/division check +            
            // makes active lanes contiguous (t < num_targets), which (probably) reduces            
            // intra-warp branch divergence compared to a strided predicate.            
            let target_idx = (local_idx + 1u) * step - 1u;  
            // target_idx - (step >> 1u) -> index of the sum target (step/2 back)  
            local_data[target_idx] += local_data[target_idx - (step >> 1u)];  
        }  
        workgroupBarrier();  
        step = step << 1u;  
    }  
}  
  
/**  
 * Execute down-sweep step of the Blelloch scan. 
 */
fn down_sweep(local_idx: u32) {  
    var step = WG_SIZE;  
    while (step >= 2u) {  
     let num_targets = WG_SIZE / step;  
     if (local_idx < num_targets) {  
         let target_idx = (local_idx + 1u) * step - 1u;  
         let prev_idx = target_idx - (step >> 1u);  
         let prev_val = local_data[prev_idx];  
         local_data[prev_idx] = local_data[target_idx];  
         local_data[target_idx] += prev_val;  
     }  
     workgroupBarrier();  
     step = step >> 1u;  
    }  
}  
  
@compute @workgroup_size(WG_SIZE)  
fn block_scan_write_sum(  
    @builtin(local_invocation_id) lid: vec3<u32>,  
    @builtin(workgroup_id) wid: vec3<u32>,  
    @builtin(num_workgroups) num_wg: vec3<u32>  
) {  
    let n = arrayLength(&global_data);  
    let indices = get_indices(lid, wid, num_wg);  
    let local_idx = indices[0];  
    let global_idx = indices[1];  
    copy_global_data_to_local(n, local_idx, global_idx);  
  
    up_sweep(local_idx);  
  
    // write out the block sum here before overwriting with 0  
    let wg_linear = linearize_workgroup_id(wid, num_wg);  
    let n_blocks = arrayLength(&block_sum);  
    if (local_idx == 0u) {  
        if (wg_linear < n_blocks) {  
            block_sum[wg_linear] = local_data[WG_SIZE - 1u];  
        }  
        local_data[WG_SIZE - 1u] = 0u;  
    }  
    workgroupBarrier();  
  
    down_sweep(local_idx);  
  
    // write out the local scan result to the global storage  
    if (global_idx < n) {  
        global_data[global_idx] = local_data[local_idx];  
    }  
}  
  
@compute @workgroup_size(WG_SIZE)  
fn block_scan_no_sum(  
    @builtin(local_invocation_id) lid: vec3<u32>,  
    @builtin(workgroup_id) wid: vec3<u32>,  
    @builtin(num_workgroups) num_wg: vec3<u32>  
) {  
     let n = arrayLength(&global_data);  
     let indices = get_indices(lid, wid, num_wg);  
     let local_idx = indices[0];  
     let global_idx = indices[1];  
     copy_global_data_to_local(n, local_idx, global_idx);  
  
     up_sweep(local_idx);  
  
     if (local_idx == 0u) {  
         local_data[WG_SIZE - 1u] = 0u;  
     }  
     workgroupBarrier();  
  
     down_sweep(local_idx);  
  
     // write out the local scan result to the global storage  
     if (global_idx < n) {  
         global_data[global_idx] = local_data[local_idx];  
     }  
}
blelloch_add_carry.wgsl
const WG_SIZE: u32 = 64u;

@group(0) @binding(0) var<storage, read_write> global_data: array<u32>;
@group(0) @binding(1) var<storage, read_write> block_sum: array<u32>;

fn linearize_workgroup_id(wid: vec3<u32>, num_wg: vec3<u32>) -> u32 {
    // linear = x + y*X + z*(X*Y)
    return wid.x + wid.y * num_wg.x + wid.z * (num_wg.x * num_wg.y);
}

@compute @workgroup_size(WG_SIZE)
fn add_carry(
    @builtin(local_invocation_id) lid: vec3<u32>,
    @builtin(workgroup_id) wid: vec3<u32>,
    @builtin(num_workgroups) num_wg: vec3<u32>,
) {
    let n_data = arrayLength(&global_data);
    let n_blocks = arrayLength(&block_sum);

    // Linear workgroup index is same as the index of the block sum
    let wg_linear = linearize_workgroup_id(wid, num_wg);
    if (wg_linear >= n_blocks) {
        return;
    }

    let global_idx = wg_linear * WG_SIZE + lid.x;
    if (global_idx >= n_data) {
        return;
    }

    let carry = block_sum[wg_linear];
    global_data[global_idx] += carry;
}

そして再帰的に累積和を当てるため、途中の計算結果を保存するバッファが必要になります。またそれに伴って、どのバッファをどの段階で操作するかを予めはめ込んだバインドグループの配列も作っておくことにします。

block_blelloch_scan.rs
// Build all required buffers + block scan bind groups for each level  
const TILE_SIZE: usize = 64;  
let mut data_buffers: Vec<wgpu::Buffer> = vec![];  
let mut bind_groups_write_sum: Vec<wgpu::BindGroup> = vec![];  
let mut elms_per_level: Vec<u32> = vec![];  
// For original data  
data_buffers.push(device.create_buffer(&wgpu::BufferDescriptor {  
    label: Some("block-sum"),  
    size: (n * size_of::<u32>()).max(4) as u64,  
    usage: wgpu::BufferUsages::STORAGE  
        | wgpu::BufferUsages::COPY_SRC  
        | wgpu::BufferUsages::COPY_DST,  
    mapped_at_creation: false,  
}));  
// Create buffers for blocks  
let mut level_elms = n;  
let mut i = 1;  
while level_elms > TILE_SIZE {  
    elms_per_level.push(level_elms as u32);  
    let num_blocks = level_elms.div_ceil(TILE_SIZE).max(1);  
    let sum_bytes = (num_blocks * size_of::<u32>()) as u64;  
    data_buffers.push(device.create_buffer(&wgpu::BufferDescriptor {  
        label: Some("block-sum"),  
        size: sum_bytes.max(4),  
        usage: wgpu::BufferUsages::STORAGE  
            | wgpu::BufferUsages::COPY_SRC  
            | wgpu::BufferUsages::COPY_DST,  
        mapped_at_creation: false,  
    }));  
  
    // bind group: (prev_level -> this_level)  
    let src = &data_buffers[i - 1];  
    let dst = &data_buffers[i];  
    bind_groups_write_sum.push(device.create_bind_group(&wgpu::BindGroupDescriptor {  
        label: Some("block-scan bind group"),  
        layout: &pipeline_write_sum.get_bind_group_layout(0),  
        entries: &[  
            wgpu::BindGroupEntry {  
                binding: 0,  
                resource: src.as_entire_binding(),  
            },  
            wgpu::BindGroupEntry {  
                binding: 1,  
                resource: dst.as_entire_binding(),  
            },  
        ],  
    }));  
  
    level_elms = num_blocks;  
    i += 1;  
}  
// The last buffer's elements number is for `block_scan_no_sum`  
elms_per_level.push(level_elms as u32);  
  
let last_buffer = &data_buffers[data_buffers.len() - 1];  
let bind_group_no_sum = device.create_bind_group(&wgpu::BindGroupDescriptor {  
    label: Some("block-scan bind group"),  
    layout: &pipeline_no_sum.get_bind_group_layout(0),  
    entries: &[wgpu::BindGroupEntry {  
        binding: 0,  
        resource: last_buffer.as_entire_binding(),  
    }],  
});  
  
// Build Add-carry bind groups  
let mut bind_groups_add_carry: Vec<wgpu::BindGroup> = vec![];  
for i in (1..data_buffers.len()).rev() {  
    bind_groups_add_carry.push(device.create_bind_group(&wgpu::BindGroupDescriptor {  
        label: Some("add-carry bind group"),  
        layout: &pipeline_add_carry.get_bind_group_layout(0),  
        entries: &[  
            wgpu::BindGroupEntry {  
                binding: 0,  
                resource: data_buffers[i - 1].as_entire_binding(),  
            },  
            wgpu::BindGroupEntry {  
                binding: 1,  
                resource: data_buffers[i].as_entire_binding(),  
            },  
        ],  
    }));  
}

シェーダー側にup-sweep/down-sweep全体が移されたことで、コマンドエンコーダ側にも変更が入ります。これまではUniformとDynamic offsetを使って管理していたステップ数がCPU側では不要になるので、事前に作っておいたバインドグループの数だけdispatchすれば完了です。大分サッパリしましたね!

block_blelloch_scan.rs
pub fn encode_scan(&self, encoder: &mut wgpu::CommandEncoder) {  
    const WG_SIZE: u32 = 64;  
    let max_dim = self.device.limits().max_compute_workgroups_per_dimension;  
  
    let mut pass = encoder.begin_compute_pass(&Default::default());  
    pass.set_pipeline(&self.pipeline_write_sum);  
  
    // apply the scan for block sums recursively until the size of the block sums array becomes smaller than one block size  
    self.bind_groups_write_sum  
        .iter()  
        .enumerate()  
        .for_each(|(i, bind_group)| {  
            let workgroups_needed = self.elms_per_level[i].div_ceil(WG_SIZE).max(1);  
            pass.set_bind_group(0, bind_group, &[]);  
            let [x, y, z] = split_dispatch_3d(workgroups_needed, max_dim);  
            pass.dispatch_workgroups(x, y, z);  
        });  
  
    // The last sums also requires scan but no need to write the new block sums since it is already fitting in one block  
    let last_idx = self.elms_per_level.len() - 1;  
    let workgroups_needed = self.elms_per_level[last_idx].div_ceil(WG_SIZE).max(1);  
    pass.set_pipeline(&self.pipeline_no_sum);  
    pass.set_bind_group(0, &self.bind_group_no_sum, &[]);  
    let [x, y, z] = split_dispatch_3d(workgroups_needed, max_dim);  
    pass.dispatch_workgroups(x, y, z);  
  
    // add carry to the previous data  
    pass.set_pipeline(&self.pipeline_add_carry);  
    for level in (1..self.data_buffers.len()).rev() {  
        let bind_group = &self.bind_groups_add_carry[self.data_buffers.len() - 1 - level];  
        let block_len = self.elms_per_level[level - 1];  
        let workgroups_needed = block_len.div_ceil(WG_SIZE).max(1);  
  
        pass.set_bind_group(0, bind_group, &[]);  
        let [x, y, z] = split_dispatch_3d(workgroups_needed, max_dim);  
        pass.dispatch_workgroups(x, y, z);  
    }  
}

ベンチマーク

グローバルメモリへのアクセスは4N程度に減っているはずですが、どうかな!

CPU vs Blelloch global vs Blelloch block

グローバルメモリverと比べて若干早くなりCPUとほぼ同程度、という結果になりました。

ちなみにGPU Gems 3で触れられている1スレッドでの2もしくは4要素の同時処理も実装してみたのですが、寧ろ遅くなる結果になりました。またワークグループサイズは32~256を試しましたが、64が最速でした。これらおそらくworkgroupBarrierの数が多い分、ワークグループサイズが大きくなったり一つのスレッドで処理する要素数が増えると、スレッド間での処理完了までの時間差が大きくなり、バリアまでの待ちが長くなってしまっているのかなと思われます。(ここは想像で書いていて不確かです)
同様にバンクコンフリクト回避のためにパディングを入れても見ましたが、これでもほとんど速度に変化はありませんでした。これも正直謎ですが、多分バックエンドによって挙動が結構変わると思われます。Apple Siliconではバンクコンフリクトがそもそもあまり影響しないのかもしれません。

兎も角、最適と思われたBlocked Blelloch scanでもCPUに勝利とは言えない結果になりました。

Scan by subgroups

さてWebGPUにはsubgroupというfeatureがあり、環境によっては使うことができます。例えばwgpu+MetalやChromiumでも一部機能がサポートされています。これは各バックエンドのwarpやsimdgroup系のAPIにマッピングされます。これによって同じsubgroupの要素に対してより高速な演算を行うことが可能になります。GPUではある程度の数まとまったスレッドを同時に動かしますが、そのスレッド達はシャッフルによってデータ交換をレジスタ直読みでできるみたいです。

実装

そしてsubgroupにはsubgroupExclusiveAddというメソッドがあります!!!Blelloch scanまで実装してから気づきました。その名の通りsubgroup内で累積和を取ってくれるメソッドです。ほな最初からこれで良かったのでは…と思いましたが、概ねBlocked Blelloch scan用に実装したブロック化はそのまま使えるので学習コストも含めて良しとします。ブロック化の図を見返してもらうと、ブロック内で累積和を取ってブロック合計を別のバッファに詰めて、という処理の繰り返しです。つまりシェーダーのBlelloch scanを使っている箇所をsubgroupExclusiveAddに置き換えればよいわけですね。

ただしsubgroupのサイズはバックエンドによりますが32か64程度で、これはworkgroupのサイズよりも大抵小さくなります。よってsubgroupExclusiveAddの結果に対しても同様にブロック合計->累積和->足し戻しの処理が必要になります。基本的なデータの流れは最初のデータからsubgroupサイズに分解されるまで変わらないので追いやすいと思います。

subgroup_block_scan.wgsl
@compute @workgroup_size(WG_SIZE)  
fn block_scan_write_sum(  
    @builtin(local_invocation_id) lid: vec3<u32>,  
    @builtin(workgroup_id) wid: vec3<u32>,  
    @builtin(num_workgroups) num_wg: vec3<u32>,  
    @builtin(subgroup_size) sg_size: u32, // maybe 32 or 64, depends on the GPU  
    @builtin(subgroup_invocation_id) sg_lane: u32, // 0..sg_size, most probably 0..32 on Metal
    @builtin(subgroup_id) sg_id: u32, // 0..workgroup_size/subgroup_size, most probably 0..4 on Metal
) {  
    let n = arrayLength(&global_data);  
  
    let wg_linear = linearize_workgroup_id(wid, num_wg);  
    let global_idx = wg_linear * WG_SIZE + lid.x;  
    let in_range = global_idx < n;  
    var v = 0u;  
    if (in_range) {  
        v = global_data[global_idx];  
    }  
  
    // exclusive scan result in the same subgroup until this element  
    let sg_prefix = subgroupExclusiveAdd(v);  
    // calculate the sum of all elements in the subgroup.  
    // The same result will be returned for the same subgroup, no matter which lane we are in.    let sg_sum = subgroupAdd(v);  
    // Store the sum from each subgroup into workgroup shared  
    if (sg_lane == 0u) {  
        local_data[sg_id] = sg_sum;  
    }  
    workgroupBarrier();  
  
    // Build offsets to collect the each subgroup's scan result  
    let num_sg = (WG_SIZE + sg_size - 1u) / sg_size;  
    if (lid.x == 0u) {  
        // run exclusive scan on the subgroup sum results array  
        var sg_sum_total = 0u;  
        for (var i = 0u; i < num_sg; i = i + 1u) {  
            let tmp = local_data[i];  
            local_data[i] = sg_sum_total;  
            sg_sum_total = sg_sum_total + tmp;  
        }  
        // store the block sum for the next block scan  
        let n_blocks = arrayLength(&block_sum);  
        if (wg_linear < n_blocks) {  
            block_sum[wg_linear] = sg_sum_total;  
        }  
    }  
    workgroupBarrier();  
  
    // Add carries from each subgroup to the subgroup prefix  
    if (in_range) {  
        global_data[global_idx] = local_data[sg_id] + sg_prefix;  
    }  
}

ベンチマーク

CPU vs Blelloch block vs subgroup
大満足!!

終わりに諸々

  • 累積和の速度だけ見るとsubgroupを使えば速いですが、累積和と同時に他にも何か計算もしたいなどのシチュエーションだと、Blellochをベースに改造すると良いようなシーンもあるかもしれません。具体例は思いつきませんが
  • Blellochよりも高速なDecoupled look-backというアルゴリズムもあるみたいです。
  • 今回のベンチマークはsubmit〜デバイスが再度空き状態になるまでの時間を取ったので、パイプラインやバインドグループを作成するオーバーヘッドは含んでいません。なので単純に累積和を取るだけが目的なら余程要素数が大きくないとGPUを使うメリットは薄いかもです。反面GPUで他の処理と組み合わせて使う場面などは大いにあると思います。
  • ブロックに分けるスキャンはパフォーマンス向上と対応できる配列サイズの調整のため実質必須になると思われますが、中間計算結果を保存する分メモリの使用量が大きくなります。なにかうまい方法がありそうな気がしますが、そこまでは調べられていません。
  • WebGPUおもしろーい!!

参考文献

https://ja.wikipedia.org/wiki/累積和
https://developer.nvidia.com/gpugems/gpugems3/part-vi-gpu-computing/chapter-39-parallel-prefix-sum-scan-cuda
https://webgpufundamentals.org/webgpu/lessons/ja/webgpu-compute-shaders-histogram.html
https://github.com/gfx-rs/wgpu/issues/5555
https://www.w3.org/TR/WGSL/#subgroup
https://www.youtube.com/watch?v=lavZl_wEbPE

Discussion