🐳

AIと読むDeepSeek-V3 Technical Report② - Architecture -

2025/01/30に公開

英語の論文を日本語にして読んでいきたいです。

2. アーキテクチャ

まず、効率的な推論のためのMulti-head Latent Attention(MLA)と経済的なトレーニングのためのDeepSeekMoEを特徴とするDeepSeek-V3の基本的なアーキテクチャを紹介します。次に、評価ベンチマークでの全体的なパフォーマンスを向上させることが観察されたMulti-Token Prediction(MTP)トレーニング目標を提示します。明示的に言及されていないその他の細部の設定については、DeepSeek-V3はDeepSeek-V2の設定に従います。

図2
図2: DeepSeek-V3の基本アーキテクチャの図。DeepSeek-V2に従い、効率的な推論と経済的なトレーニングのためにMLAとDeepSeekMoEを採用しています。

2.1 基本アーキテクチャ

DeepSeek-V3の基本アーキテクチャは、依然としてTransformerフレームワークの範囲内です。効率的な推論と経済的なトレーニングのために、DeepSeek-V3はDeepSeek-V2で徹底的に検証されたMLAとDeepSeekMoEも採用しています。DeepSeek-V2と比較して、例外は、ロードバランスを確保するための努力によって引き起こされるパフォーマンスの低下を緩和するために、DeepSeekMoEに補助ロスフリーのロードバランシング戦略を導入したことです。図2は、DeepSeek-V3の基本アーキテクチャを示しており、このセクションではMLAとDeepSeekMoEの詳細を簡単にレビューします。

2.1.1 Multi-Head Latent Attention

Attentionのために、DeepSeek-V3はMLAアーキテクチャを採用しています。
d を埋め込み次元、 n_h を注意ヘッドの数、 d_h をヘッドごとの次元、 𝐡_t ∈ ℝ^d を特定の注意層におけるt番目のトークンの注意入力とします。MLAの核となるのは、推論中のKey-Value(KV)キャッシュを削減するための注意キーと値の低ランク結合圧縮です。

\begin{align} 𝐜_t^{KV} = W^{DKV} 𝐡_t \end{align}
\begin{align} [𝐤_{t,1}^C; 𝐤_{t,2}^C; …; 𝐤_{t,n_h}^C] = 𝐤_t^C = W^{UK} 𝐜_t^{KV} \end{align}
\begin{align} 𝐤_t^R = RoPE(W^{KR} 𝐡_t) \end{align}
\begin{align} 𝐤_{t,i} = [𝐤_{t,i}^C; 𝐤_t^R] \end{align}
\begin{align} [𝐯_{t,1}^C; 𝐯_{t,2}^C; …; 𝐯_{t,n_h}^C] = 𝐯_t^C = W^{UV} 𝐜_t^{KV} \end{align}

ここで、 𝐜_t^{KV} ∈ ℝ^{d_c} はキーと値の圧縮された潜在ベクトルです。 d_c (≪ d_h n_h) はKV圧縮次元を示します。 W^{DKV} ∈ ℝ^{d_c × d} はダウンプロジェクション行列を示します。 W^{UK}, W^{UV} ∈ ℝ^{d_hn_h × d_c} は、それぞれキーと値のアッププロジェクション行列です。 W^{KR} ∈ ℝ^{d_h^R × d} はRotary Positional Embedding(RoPE)を運ぶ分離されたキーを生成するために使用される行列です。 RoPE(⋅) はRoPE行列を適用する操作を示します。 [⋅;⋅] は連結を示します。MLAでは、青いボックスで囲まれたベクトル(つまり、 𝐜_t^{KV}𝐤_t^R )のみを生成中にキャッシュする必要があり、これにより、標準のMulti-Head Attention(MHA)に匹敵するパフォーマンスを維持しながら、KVキャッシュが大幅に削減されます。

注意クエリについても、低ランク圧縮を実行し、トレーニング中のアクティベーションメモリを削減できます。

\begin{align} 𝐜_t^Q = W^{DQ} 𝐡_t \end{align}
\begin{align} [𝐪_{t,1}^C; 𝐪_{t,2}^C; …; 𝐪_{t,n_h}^C] = 𝐪_t^C = W^{UQ} 𝐜_t^Q \end{align}
\begin{align} [𝐪_{t,1}^R; 𝐪_{t,2}^R; …; 𝐪_{t,n_h}^R] = 𝐪_t^R = RoPE(W^{QR} 𝐜_t^Q) \end{align}
\begin{align} 𝐪_{t,i} = [𝐪_{t,i}^C; 𝐪_{t,i}^R] \end{align}

ここで、 𝐜_t^Q ∈ ℝ^{d_c'} はクエリの圧縮された潜在ベクトルです。 d_c' (≪ d_h n_h) はクエリ圧縮次元を示します。 W^{DQ} ∈ ℝ^{d_c' × d}W^{UQ} ∈ ℝ^{d_h n_h × d_c'} は、それぞれクエリのダウンプロジェクション行列とアッププロジェクション行列です。 W^{QR} ∈ ℝ^{d_h^R n_h × d_c'} はRoPEを運ぶ分離されたクエリを生成するための行列です。

最終的に、注意クエリ (𝐪_{t,i})、キー (𝐤_{j,i})、および値 (𝐯_{j,i}^C) が組み合わされて、最終的な注意出力 𝐮_t が得られます。

\begin{align} 𝐨_{t,i} = ∑_{j=1}^t Softmax_j(\frac{𝐪_{t,i}^T 𝐤_{j,i}}{\sqrt{d_h + d_h^R}}) 𝐯_{j,i}^C \end{align}
\begin{align} 𝐮_t = W^O [𝐨_{t,1}; 𝐨_{t,2}; …; 𝐨_{t,n_h}] \end{align}

ここで、 W^O ∈ ℝ^{d × d_h n_h} は出力プロジェクション行列を示します。

2.1.2 補助ロスフリーなロードバランシングを備えたDeepSeekMoE

DeepSeekMoEの基本アーキテクチャ。
Feed-Forward Network(FFN)の場合、DeepSeek-V3はDeepSeekMoEアーキテクチャ(Dai et al., 2024)を採用しています。GShard(Lepikhin et al., 2021)のような従来のMoEアーキテクチャと比較して、DeepSeekMoEはより細粒度の専門家を使用し、一部の専門家を共有専門家として分離します。 𝐮_t をt番目のトークンのFFN入力とすると、FFN出力 𝐡_t' は次のように計算されます。

\begin{align} 𝐡_t' = 𝐮_t + ∑_{i=1}^{N_s} FFN_i^{(s)}(𝐮_t) + ∑_{i=1}^{N_r} g_{i,t} FFN_i^{(r)}(𝐮_t) \end{align}
\begin{align} g_{i,t} = \frac{g_{i,t}'}{∑_{j=1}^{N_r} g_{j,t}'} \end{align}
\begin{align} g_{i,t}' = \begin{dcases} s_{i,t}, s_{i,t} ∈ Topk(\{ s_{j,t} | 1 ⩽ j ⩽ N_r \}, K_r) \\ 0, otherwise \end{dcases} \end{align}
\begin{align} s_{i,t} = Sigmoid(𝐮_t^T 𝐞_i) \end{align}

ここで、 N_sN_r は、それぞれ共有エキスパートとルーティングされたエキスパートの数を示します。 FFN_i^{(s)}(⋅)FFN_i^{(r)}(⋅) は、それぞれi番目の共有エキスパートとi番目のルーティングされたエキスパートを示します。 K_r はアクティブなルーティングされたエキスパートの数を示します。 g_{i,t} はi番目のエキスパートのゲーティング値です。 s_{i,t} はトークンとエキスパートのアフィニティです。 𝐞_i はi番目のルーティングされたエキスパートの重心ベクトルです。 Topk(⋅, K) は、t番目のトークンとすべてのルーティングされたエキスパートに対して計算されたアフィニティスコアの中で、K個の最高のスコアを含むセットを示します。DeepSeek-V2とは少し異なり、DeepSeek-V3はアフィニティスコアを計算するためにシグモイド関数を使用し、すべての選択されたアフィニティスコア間で正規化を適用してゲーティング値を生成します。

補助ロスフリーなロードバランシング。
MoEモデルの場合、専門家の負荷がアンバランスになると、ルーティングの崩壊につながり、専門家の並列処理があるシナリオで計算効率が低下します。従来のソリューションでは、通常、アンバランスな負荷を回避するために補助ロスに依存しています。ただし、補助ロスが大きすぎると、モデルのパフォーマンスが損なわれる可能性があります。負荷バランスとモデルパフォーマンスの間でより良いトレードオフを達成するために、負荷バランスを確保するための補助ロスフリーの負荷分散戦略を先駆的に採用します。具体的には、各エキスパートにバイアス項 b_i を導入し、対応するアフィニティスコア s_{i,t} に追加して、上位Kのルーティングを決定します。

\begin{align} g_{i,t}' = \begin{dcases} s_{i,t}, s_{i,t} + b_i ∈ Topk(\{s_{j,t} + b_j | 1 ⩽ j ⩽ N_r\}, K_r) \\ 0, otherwise \end{dcases} \end{align}

バイアス項はルーティングにのみ使用されることに注意してください。FFN出力と掛けられるゲーティング値は、元の親和性スコア s_{i,t} から引き続き導出されます。トレーニング中、各トレーニングステップのバッチ全体の専門家の負荷を監視し続けます。各ステップの終わりに、対応するエキスパートが過負荷になっている場合は、バイアス項を γ だけ減らし、対応するエキスパートが負荷不足になっている場合は、 γ だけ増やします。ここで、 γ はバイアス更新速度と呼ばれるハイパーパラメータです。動的な調整を通じて、DeepSeek-V3はトレーニング中にバランスの取れた専門家の負荷を維持し、純粋な補助ロスを通じて負荷バランスを促進するモデルよりも優れたパフォーマンスを達成します。

補完的なシーケンス単位の補助ロス。
DeepSeek-V3は主に負荷バランスのために補助ロスフリー戦略に依存していますが、単一のシーケンス内の極端なアンバランスを防ぐために、補完的なシーケンス単位のバランスロスも採用しています。

\begin{align} \mathcal{L}_{Bal} = α ∑_{i=1}^{N_r} f_i P_i \end{align}
\begin{align} f_i = \frac{N_r}{(K_r T)} ∑_{t=1}^T 𝟙(s_{i,t} ∈ Topk(\{s_{j,t} | 1 ⩽ j ⩽ N_r\}, K_r)) \end{align}
\begin{align} s_{i,t}' = \frac{s_{i,t}}{∑_{j=1}^{N_r} s_{j,t}} \end{align}
\begin{align} P_i = \frac{1}{T} ∑_{t=1}^T s_{i,t}' \end{align}

ここで、バランスファクター α はハイパーパラメータであり、DeepSeek-V3には非常に小さい値が割り当てられます。 𝟙(⋅) はインジケーター関数を示します。 T はシーケンス内のトークンの数を示します。シーケンス単位のバランスロスは、各シーケンスでの専門家の負荷がバランスされるように促します。

ノード制限ルーティング。
DeepSeek-V2で使用されているデバイス制限ルーティングと同様に、DeepSeek-V3もトレーニング中の通信コストを制限するために制限されたルーティングメカニズムを使用しています。要するに、各トークンは、各ノードに分散された専門家の上位 \frac{K_r}{M} のアフィニティスコアの合計に従って選択された、最大 M 個のノードに送信されることを保証します。この制約の下で、当社のMoEトレーニングフレームワークは、ほぼ完全な計算と通信の重複を達成できます。

トークンドロップなし。
効果的な負荷分散戦略により、DeepSeek-V3はその完全なトレーニング中に良好な負荷分散を維持します。したがって、DeepSeek-V3はトレーニング中にトークンをドロップしません。さらに、推論の負荷分散を確保するために特定の展開戦略も実装しているため、DeepSeek-V3は推論中にもトークンをドロップしません。

図3
図3: Multi-Token Prediction(MTP)実装の図。各深さで各トークンの予測に対して完全な因果チェーンを維持しています。

2.2 Multi-Token Prediction

Gloeckle et al. (2024) に触発されて、DeepSeek-V3 の Multi-Token Prediction(MTP)目標を調査および設定しました。これにより、予測範囲を各位置で複数の将来のトークンに拡張します。一方、MTP 目標はトレーニング信号を密にし、データ効率を向上させる可能性があります。他方、MTP により、モデルが将来のトークンのより良い予測のためにその表現を事前に計画できるようになる可能性があります。図 3 は、MTP の実装を示しています。独立した出力ヘッドを使用して D 個の追加トークンを並行して予測する Gloeckle et al. (2024) とは異なり、追加のトークンを順次予測し、各予測深度で完全な因果チェーンを維持します。このセクションでは、MTP 実装の詳細を紹介します。

MTPモジュール

具体的には、MTP 実装では、D 個の追加トークンを予測するために D 個のシーケンシャルモジュールを使用します。k 番目の MTP モジュールは、共有埋め込み層 Emb(⋅)、共有出力ヘッド OutHead(⋅)、Transformer ブロック TRM_k(⋅)、およびプロジェクション行列 M_k ∈ ℝ^{d × 2d} で構成されます。i 番目の入力トークン t_i の場合、k 番目の予測深度で、最初に (k-1) 番目の深度での i 番目のトークンの表現 𝐡_i^{k-1} ∈ ℝ^d と (i+k) 番目のトークンの埋め込み Emb(t_{i+k}) ∈ ℝ^d を線形プロジェクションと組み合わせます。

\begin{align} 𝐡_i'^k = M_k [RMSNorm(𝐡_i^{k-1}); RMSNorm(Emb(t_{i+k}))] \end{align}

ここで、[⋅;⋅] は連結を示します。特に、k=1 の場合、𝐡_i^{k-1} はメインモデルによって与えられた表現を指します。各 MTP モジュールについて、その埋め込み層はメインモデルと共有されることに注意してください。結合された 𝐡_i'^k は、k 番目の深度で Transformer ブロックの入力として機能し、現在の深度 𝐡_i^k での出力表現を生成します。

\begin{align} 𝐡_{1:T-k}^k = TRM_k(𝐡_{1:T-k}'^k) \end{align}

ここで、T は入力シーケンス長を表し、i:j はスライス操作(左端と右端の両方を含む)を示します。最後に、𝐡_i^k を入力として使用すると、共有出力ヘッドは、k 番目の追加予測トークン P_{i+1+k}^k ∈ ℝ^V の確率分布を計算します。ここで、V は語彙サイズです。

\begin{align} P_{i+k+1}^k = OutHead(𝐡_i^k) \end{align}

出力ヘッド OutHead(⋅) は、表現をロジットに線形にマッピングし、その後 Softmax(⋅) 関数を適用して、k 番目の追加トークンの予測確率を計算します。また、各 MTP モジュールについて、その出力ヘッドはメインモデルと共有されます。予測の因果関係を維持するという私たちの原則は、EAGLEの原則と似ていますが、その主な目的は投機的デコードであるのに対し、私たちは MTP を使用してトレーニングを改善します。

MTPトレーニング目標

各予測深度について、クロスエントロピー損失 \mathcal{L}_{MTP}^k を計算します。

\begin{align} \mathcal{L}_{MTP}^k = CrossEntropy(P_{2+k:T+1}^k, t_{2+k:T+1}) = -\frac{1}{T} ∑_{i=2+k}^{T+1} log P_i^k[t_i] \end{align}

ここで、T は入力シーケンス長を示し、t_i は i 番目の位置での正解トークンを示し、P_i^k[t_i] は k 番目の MTP モジュールによって与えられた t_i の対応する予測確率を示します。最後に、すべての深度での MTP 損失の平均を計算し、重み付け係数 λ を掛けて、DeepSeek-V3 の追加のトレーニング目標となる全体的な MTP 損失 \mathcal{L}_{MTP} を取得します。

\begin{align} \mathcal{L}_{MTP} = \frac{λ}{D} ∑_{k=1}^D \mathcal{L}_{MTP}^k \end{align}

推論におけるMTP

MTP 戦略は主にメインモデルのパフォーマンスを向上させることを目的としているため、推論中には、MTP モジュールを直接破棄でき、メインモデルは独立して正常に機能します。さらに、これらの MTP モジュールを投機的デコードに再利用して、生成レイテンシをさらに向上させることもできます。


Infrastructuresへ続く
https://zenn.dev/tanegoma/articles/6f027ffe73c9b8

Discussion