【AWS Trainium 50本ノック #5】分散学習の基礎知識編
第 5 章 分散学習の基礎知識編
本章では以下を仮定します。
- AWS Trainium の基礎知識(第 2 章の内容)
- ニューラルネットワークの基礎的理解
問題 (33-37)
複数のメモリデバイスを並列的に使用して学習を行うことを「分散学習」と呼びます。
(※デバイス: ここでは「学習を実行する際の最小単位」のことを指します。Trainiumの場合は、Neuronコアを指します。またTrainium2の場合は物理Neuronコアと論理Neuronコアがありますが、論理Neuronコアを指します。)
分散学習が必要となるのは以下の時です。
- 単一のデバイスでの学習では遅すぎる場合
- モデルが大きすぎて、モデル全体が単一のデバイスに乗り切らない場合
| 大分類 | 小分類 | 概要 | aを解決できる? | bを解決できる? |
|---|---|---|---|---|
| データ並列 | データ並列(DP) | ミニバッチ内のデータを複数デバイスで分担 | Yes | No |
| モデル並列 | テンソル並列(TP) | パラメータテンソルを縦または横に分割して複数デバイスで分担 | Yes | Yes |
| パイプライン並列(PP) | モデルの第●層〜▲層をデバイス1に、第▲層〜◾️層をデバイス2に……という感じで分担 | Yes | Yes |
これら3つが主な並列手法です。これらは併用可能です。以下、順に解説します。
※ 重要なこととして、単一のデバイスに乗り切らないような大きなモデルの学習を行うためにはモデル並列(TPあるいはPP)が必須となります。Trnの場合、1コア=16GBで学習できないモデル(大体のLLMはこれに該当します)ではモデル並列が必須です。
※ 前章で行った学習では、DP=1, TP=32, PP=1 でした。(このように、DP, TP, PP という表記で、それぞれの並列デバイス数自体を表す場合があります。)
-
DP=2, TP=32, PP=8の設定で学習したいです。いくつの Neuron コアが必要になりますか?また、trn1.32xlargeインスタンスを使用する場合、何ノード必要になりますか?解答
2 × 32 × 8 = 512 コア必要です。また、
trn1.32xlarge1ノードには Neuron コアが 32 コア含まれているため、512 ÷ 32 = 16 ノード必要となります。
データ並列 (Data Parallelism; DP)
ミニバッチ内のデータを、デバイス間で分担します。
- 概要
- データ並列は、モデルの重みをすべてのデバイスで共有しつつ、異なる入力データをそれぞれのデバイスに渡して計算を並列化する手法です。各デバイスは自分に割り当てられたマイクロバッチに対して順伝播・逆伝播を行い、その後、各デバイス間で勾配を同期することで、一貫性のある重み更新を実現します。大規模なバッチサイズで学習したいとき、もっとも基本的で導入しやすい並列化戦略です。
- マイクロバッチサイズ・グローバルバッチサイズについて
- グローバルバッチサイズ(global batch size)は、すべてのデバイスを通じて「勾配法の1ステップ」あたりに処理されるサンプルの総数です(つまり、ハードウェアのことを忘れてアルゴリズムにのみ着目した時のミニバッチサイズのことです)。そして、これを各デバイスに分配したものがマイクロバッチサイズ(micro batch size)となります。
- モデル並列は用いずにデータ並列だけを用いて、グローバルバッチサイズ 256 で学習をしたいです。いま、
trn1.32xlargeが2ノード利用可能です。この時、DP(=分担するデバイス数)はいくつで、マイクロバッチサイズはいくつとなりますか?解答
trn1.32xlargeには、32個のNeuronCoreが搭載されており、1デバイス = 1NeuronCoreとするのが基本です。2ノードあるため、利用可能なデバイス数(DPサイズ)は 32 x 2 = 64、よって、マイクロバッチサイズは 256 ÷ 64 = 4 となります。
なお、マイクロバッチサイズが大きすぎて単一デバイスのメモリに乗り切らない場合に、さらに勾配蓄積 (gradient accumulation) と呼ばれる方法でマイクロバッチサイズを小さくすることができます。勾配蓄積は、複数ステップ分の勾配を蓄積したのちに一度だけ重み更新を行う方法です。これにより、メモリ消費を抑えつつ、効果的なグローバルバッチサイズを維持することができます。
勾配蓄積を行う場合、forward-backward 1回ごとのステップを「マイクロステップ」、重み更新ごとのステップを「グローバルステップ」と呼び分ける場合があります。
- 上記の問題で求めたマイクロバッチサイズで学習を試みたところ、OOMとなってしまいました。そこで、勾配蓄積を使用して、マイクロバッチサイズを1に変更して学習しようと思います。この場合、勾配蓄積ステップ数はいくつにする必要がありますか?
解答
もとのグローバルバッチサイズは 256 で、DPサイズは 64 です。マイクロバッチサイズを1にすると、各デバイスが1ステップで処理するデータ数は1になります。よって、必要な勾配蓄積ステップ数は、
勾配蓄積ステップ数 = グローバルバッチサイズ ÷ DPサイズ ÷ 新しいマイクロバッチサイズ = 256 ÷ 64 ÷ 1 = **4**となります。すなわち、各デバイスで4ステップ分の勾配を蓄積するたびに1回の重み更新を行う設定にすれば、もとのグローバルバッチサイズを保つことができます。
これらの関係をまとめると、以下のようになります:
グローバルバッチサイズ = マイクロバッチサイズ × 勾配蓄積ステップ数 × DP
テンソル並列 (Tensor Parallelism; TP)
-
概要
- テンソル並列は、モデル中の単一のパラメータテンソルを複数のデバイスに分割して処理する手法です。たとえば、大きな行列の掛け算において、行や列ごとに分割して各デバイスが部分計算を担当します。これにより、1台のデバイスにすべての重みを載せる必要がなくなり、大規模モデルの学習が可能になります。
- 一般に
という行列積を計算したい状況を考えます。ただしY = XA の行数と列数が大きく、単一のデバイスでは処理できないとしましょう。計算を並列化するために、行列A をA 個に分割することを考えます。p -
と横並びに分割する場合:各デバイスでA = [A_1, A_2, \ldots, A_p] を計算し、最終的にY_i = X A_i と結合することで、Y = [Y_1, Y_2, \ldots, Y_p] を計算できます。この計算方法を列並列 (Column Parallel) と呼びます。Y -
と縦並びに分割する場合:A = [A_1, A_2, \ldots, A_p]^\top もX と分割した上で、各デバイスでX = [X_1, X_2, \ldots, X_p] を計算し、最終的にY_i = X_i A_i と総和をとることで、Y = Y_1 + Y_2 + \cdots + Y_p を計算できます。この計算方法を行並列 (Row Parallel) と呼びます。Y
-
-
例えば、2層の全結合ニューラルネットワークでは、第1層の重みを「列並列」で保持し、第2層の重みを「行並列」で保持することで、第1層の行列積計算後の「各デバイスの計算結果の結合」と第2層の行列積計算開始前の「各デバイスへの入力分割」をスキップすることができ、効率的に計算を進めることができます。
2層の全結合ニューラルネットワーク:詳細
入力
(\mathbf{X} \in \mathbb{R}^{B \times d} : バッチサイズ)に対して、以下のような計算を行うモデルを考えます:B \mathbf{H} = \phi(\mathbf{X} \mathbf{W}_1^\top), \quad \mathbf{Y} = \mathbf{H} \mathbf{W}_2^\top ただし
-
は第1層の重み(出力次元\mathbf{W}_1 \in \mathbb{R}^{h \times d} 、入力次元h )d -
はReLUなどの活性化関数(バッチごとに要素単位で作用)\phi -
は第2層の重み(出力次元\mathbf{W}_2 \in \mathbb{R}^{o \times h} )o -
、\mathbf{H} \in \mathbb{R}^{B \times h} \mathbf{Y} \in \mathbb{R}^{B \times o}
とします。
-
第1層(Column Parallel)
- 重み
を**出力次元方向(行方向)**に2分割:\mathbf{W}_1
\mathbf{W}_1 = \begin{bmatrix} \mathbf{W}_1^{(0)} \\ \mathbf{W}_1^{(1)} \end{bmatrix}, \quad \mathbf{W}_1^{(i)} \in \mathbb{R}^{(h/2) \times d} - 各デバイスは
全体を共有し、それぞれ次の出力を計算:\mathbf{X}
\mathbf{H}^{(i)} = \phi(\mathbf{X} (\mathbf{W}_1^{(i)})^\top), \quad \mathbf{H}^{(i)} \in \mathbb{R}^{B \times (h/2)} - 全デバイスで
を横方向(列方向)に連結して\mathbf{H}^{(0)}, \mathbf{H}^{(1)} を得る(この連結は、第2層での\mathbf{H} の分割処理(※)とともに省略可能):\mathbf{H}
\mathbf{H} = \left[ \mathbf{H}^{(0)} \;\; \mathbf{H}^{(1)} \right] - 重み
-
第2層(Row Parallel)
- 重み
を**入力次元方向(列方向)**に2分割:\mathbf{W}_2
\mathbf{W}_2 = \left[ \mathbf{W}_2^{(0)} \;\; \mathbf{W}_2^{(1)} \right], \quad \mathbf{W}_2^{(i)} \in \mathbb{R}^{o \times (h/2)} -
も対応して2分割し(※)、各デバイスが:\mathbf{H}
\mathbf{Y}^{(i)} = \mathbf{H}^{(i)} (\mathbf{W}_2^{(i)})^\top, \quad \mathbf{Y}^{(i)} \in \mathbb{R}^{B \times o} - 最後に、**各デバイスの出力を加算(AllReduce)**して最終出力
を得る:\mathbf{Y}
\mathbf{Y} = \mathbf{Y}^{(0)} + \mathbf{Y}^{(1)} - 重み

-
-
💡テンソル並列について、さらに詳細を知りたい方は Turing 様のこちらの記事がおすすめです。
-
注意点
- 分割数(=TPサイズ)は、対象となるテンソルの分割次元(通常は「隠れ次元」)に沿ったサイズの約数である必要があります。たとえば、サイズが4096の次元に沿ってテンソルを分割したい場合、TPサイズとして可能なのは 4096 の約数のみとなります。
- 自己注意機構を持つモデルの場合、各アテンションヘッドごとの計算は並列的ですが、各アテンションヘッドの内部で行われる計算をさらにテンソル並列化することは困難です。そのため、TPはアテンションヘッドの個数の約数である必要があります。
- GQA(Grouped Query Attention)を採用しているモデルの場合、アテンションヘッド数(=Qヘッド数)とKVヘッド数が異なります(後者は前者の約数となっています)。この場合、TPは「KVヘッド数」の約数でもある必要があります。しかし、これではTPとして選べる値が大幅に制限されてしまいます。そこで NxD では「KVレプリケータ(
KV_REPLICATOR)」という仕組みが用意されています。これが設定されている場合、KVヘッドの重みはKV_REPLICATOR倍に複製されてメモリ上に保持されます。これにより、TPは「KVヘッド数」の約数ではなく「(KVヘッド数) *KV_REPLICATOR」の約数であれば良くなります。
- GQA(Grouped Query Attention)を採用しているモデルの場合、アテンションヘッド数(=Qヘッド数)とKVヘッド数が異なります(後者は前者の約数となっています)。この場合、TPは「KVヘッド数」の約数でもある必要があります。しかし、これではTPとして選べる値が大幅に制限されてしまいます。そこで NxD では「KVレプリケータ(
-
Qwen/Qwen3-8B の config.json には
"vocab_size": 151936, "hidden_size": 4096, "intermediate_size": 12288, "num_attention_heads": 32, "num_key_value_heads": 8と記載があります。これらはそれぞれ「語彙サイズ」「隠れ次元サイズ」「MLP層の中間層サイズ」「アテンションヘッド数」「KVヘッド数」を表します。TPをなるべく大きな値に設定したい場合、どのように設定すべきですか?解答
TP は「32」と「8*KV_REPLICATOR」の約数である必要があります(それ以外の数は全て32の倍数です)。なるべく TP を大きく設定したい場合、KV_REPLICATOR=4 と設定して TP=32 とするのが良いです。
パイプライン並列 (Pipeline Parallelism; PP)
-
概要
- パイプライン並列は、モデル全体のレイヤーを分割して、それぞれ異なるデバイスで担当させる並列手法です。各マイクロバッチがパイプライン状に複数ステージを通過していくことで、同時に複数のマイクロバッチを並列実行できます。
-
注意点
- ステージ数とマイクロバッチ数の組み合わせによっては、「bubble(泡)」と呼ばれるアイドル時間が多くなり、計算資源が無駄になります。

-
Qwen2.5-VL-72B-Instruct は「視覚エンコーダ」と「テキストデコーダ」が直列に繋がった構成となっており、前者にはTransformer層が32層(直列的に)含まれており、後者にはTransformer層が64層(直列的に)含まれています。このLLMを、DP=4のデータ並列、およびモデル並列として全体のTPを8、「視覚エンコーダ」のPPを4、「テキストデコーダ」のPPを16として学習したいです。
trn1.32xlargeが何ノード必要になりますか?解答
-
4 * 8 * (4 + 16) = 640コア必要になるため、必要なノード数は640 ÷ 32 = 20となります。
-
DP, TP, PP を組み合わせた分散学習により、学習を高速化したり、1デバイスに乗り切らないような巨大なモデルを学習したりすることが可能となります。
- ⚠️ 注意点
- Neuron チップにおいては、コア間の通信が完全な全結合(all-to-all)ではなく、特定のパターン(詳細)に制限されているため、任意の DP, TP, PP の組み合わせが可能とは限りません。
- 例:
DP=1, TP=1, PP=32の設定でtrn1.32xlarge1ノード(=32コア)で学習することはできません。コア0→コア1→コア2→……→コア31と通信する必要がありますが、Trn1 チップには「コア7→コア8」の結合が存在しないためです。
- 例:
- また、GQA 採用モデルで
KV_REPLICATOR > 1を設定する場合、この KV_REPLICATOR に応じたコア間通信が生じます。これもまた、コア間の通信が全結合ではないことが理由で、一部の KV_REPLICATOR 値が使用できない場合があります。
- Neuron チップにおいては、コア間の通信が完全な全結合(all-to-all)ではなく、特定のパターン(詳細)に制限されているため、任意の DP, TP, PP の組み合わせが可能とは限りません。
補足:「シーケンス並列」について
- 「テンソル並列」の補助として「シーケンス並列」と呼ばれる技法が使用される場合があります。テンソル並列では パラメータ をTP個のデバイスに分割して保持しますが、それに加えて 中間状態(アクティベーション) もTP個のデバイスに分割して保持するのが「シーケンス並列」です。これにより、各デバイスのメモリ使用量をさらに削減することができます。デバイス間での通信は複雑になりますが、後述の通りNxDのライブラリによりラップされているため簡単に有効化可能です。
- NxD ライブラリにおいて「シーケンス並列」を有効化する場合、アクティベーション(i.e. 各層への入力テンソル・各層からの出力テンソル)は、通常とは異なる以下のような方式でメモリに保持されます(以降ではこれを「シーケンス並列モード」と呼びます):
- オリジナルのアクティベーションテンソルではなく、それのシーケンス方向の次元をTP個に分割したもの(のうちの1断片)が、各デバイスに保持されます(それらは一般に異なります)。
- 通常、アクティベーションのテンソルは
(バッチサイズ, シーケンス長, …)という軸の順番ですが、シーケンス並列モードでは、0-dim と 1-dim を転置した(シーケンス長, バッチサイズ, …)という軸順で保持されます。 - 両者をまとめると、各デバイスに保持されるテンソルの shape は(通常だと
[bsz, q_len, …]であるところが)[q_len//TP, bsz, …]となります。
-
本50本ノックの内容を監修くださった AWS の常世様に、感謝を申し上げます。 ↩︎
Discussion