🐥

HWの制約を考慮したモデル設計によるパフォーマンス改善

2025/02/14に公開

この記事は以下の論文のまとめ記事です。ModernBERTで採用していたアーキテクチャのパラメータ設計が気になったので参考文献を紐解きました。

paper: The Case for Co-Designing Model Architectures with Hardware, Jan 2024

https://arxiv.org/abs/2401.14489

この記事で引用されている図の著作権は全て著者らに帰属します。

Abstract

  • 機械学習モデルにおける計算のスループットはHWの制約をかなり受ける
  • Transformerモデルの処理の大部分はGEMMに割かれているので、GEMMの行列のサイズを最適に設計することで効果的にスループットを向上できる
  • 本論文ではNVIDIAのGPUの機能に特化したモデルのアーキテクチャの設計指針と具体的な設計ルールが示されている

用語

  • GEneral Matrix Multiplication (GEMM): 行列の乗算

背景

課題: HWの制約を考慮して決めないとモデルの設計は最適にならない

  • これまでHWの働きを意識せずに設計されたモデルがコピーされ続けてきたことで、「最適でない設計」がデファクトになっている(例: 研究者は先行研究との比較のためなるべく抜本的にパラメータを変えない傾向がある)
  • 例えば、GPT-Neo, OPT, RedPajama, PythiaなどはどれもGPT3のアーキテクチャを継承しているが、MHAのheadの次元を64の倍数にするだけで、同規模パラメータサイズでのスループットが20%も向上する
  • →HWの制約を考慮することで既存のモデル設計には改善の余地がある。本論文ではその設計のレシピが示されている。

どこに計算の大部分が割かれているか?

  • TransformerではGEMMに大部分の計算が割かれている(small: 68.3%, large: 94.9%)
  • →GEMMの最適化に注力すればモデル全体のスループットを効果的に改善できる

NVIDIAのGPUにおけるGEMMの仕組み

  • 大きな行列をタイルに分割し、それぞれのタイル同士の行列積を並列計算している

特徴

  • 行列の大きさはタイルサイズの倍数に設計する方が効率的に処理できる
  • 並列処理するタスクは、一度にStreaming Manager(SM)の最大数までしか計算できない。 -> 並列計算するタイルの数をSM数の倍数にする方がidleが発生しにくい

手法

設計方針

  1. Tensor Core Requirement: GEMMのinput/output dimを128の倍数にする(FP16の場合、64の倍数)
  2. Tile Quantization: 重み行列のサイズを 128x256 の倍数にする
  3. Wave Quantization: ブロック数をGPUのSM数で割り切れるようにする(典型的なGPUのSM数についてはAppendix Aを参照)。

アーキテクチャの設計ルール

  • vacabulary sizeは64で割り切れるようにする
  • microbatch size bは可能な限り大きくとる
  • b \cdot s, h / a, h / t は64、難しければそれ以下のなるべく大きな2の冪乗数で割り切れるようにする
  • (b\cdot a)/t は整数になるようにする
  • t は可能な限り小さくとる

記法

  • a: #attention heads
  • b: microbatch size
  • h: hidden dim
  • L: #transformer layers
  • s: sequence length
  • t: tensor-parallel size
  • v: vocaburary size

所感

  • 今までなんとなく32とか64の倍数でパラメータを決めてたが、背景にある仕組みを理解できたのは良かった
  • 小規模モデルならここまでガチガチに設計しなくても良いかもしれないが、学習コストのかかる大規模モデルを設計する際には頭に入れておいても良い知見かも
  • 論文でFlashAttention v2の使用を勧めているように、エコシステムの進化によってこういったことを意識しなくてもまずまずのパフォーマンスが出るような環境に徐々になってきているのかもしれない

Reference

Appendix

A. Kaggleや市販のグラフィックカードで使われているGPUのSM数

GPU #SMs
H100 [1] 144
A100 [1] 108
V100 [1] 80
RTX4090 [2a] 128
RTX3090 Ti [2b] 84
NVIDIA L4 [3a] 58
Tesla P100 PCle 16GB [3b] 56
Tesla T4 [3c] 40
GitHubで編集を提案

Discussion