Open4
LLMのアーキテクチャ、事前学習周りの論文メモ
Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity
問題意識
- MoEは翻訳タスクで注目すべき成功をおさめたが複雑さや通信コスト、学習の不安定さなどで広い適用が妨げられている。本論文では従来よりも簡素化されたSwitch Transformerアーキテクチャを提案。
本論文のアプローチ
- 各トークンをルーティングして割り当てるExpertの数をこれまでの先行研究の2以上から1(Single Expert)へ変更(Switch layer)。ルーティングの計算コストを削減し、各Expertのバッチサイズを少なくとも半分にできる(Expert Capacityが上がる。Capacityを超えるとオーバーフローであふれたトークンは計算がスキップされるため重要)。
- Expertの計算の精度にbfloat16を使うことで通信コストを削減。ただ、bfloat16のような低精度でのルーター内のsoftmax演算は学習の不安定さを引き起こすため、ルーター内はfloat32で計算。これによって学習が安定。
- 適切な初期化が学習の成功には重要。デフォルトのTransformerのscale hyper-parameterを10分の1にするとSwitch Transformerの学習が安定した。これはモデルサイズを変えても当てはまった。
- MoEは過学習しやすくFine-tuningが難しいという課題があるが、シンプルにドロップアウトで対応。ただ、単純に全てのレイヤのドロップアウトレートを大きくしてもパフォーマンスが下がったため、Expertレイヤ以外は小さなレート、Expertレイヤには大きなレートを適用することでパフォーマンスが向上した。
評価とAblation Study
- 固定ステップ数での評価では、Expert数を増やすほどTest lossは小さくなり、サンプル効率が良くなった(少ないステップ数で収束する)。T5-baseと64 expertのSwitch Transformer-baseでは後者の方が同じ精度を達成するのにステップ数単位で7.5倍、時間で7倍高速だった。T5-largeと比較しても2.5倍高速だった。
- ダウンストリームタスクでもSwitch Transformer-base, largeはT5のbase, largeにほとんどのタスクで精度が上回る結果となった。
- 蒸留も実験。Non-Expertレイヤの重みで初期化したStudent-modelとTeacherのSwitch Transformer-base + ダウンストリームタスクの損失で蒸留したらT5-baseよりも精度が良かった(純粋な初期化からスタートするとT5-baseとほとんど変わらない結果だった)。
- 多言語モデルのmT5をSwitch Transformerにして比較。101の言語全てでmT5と比較して改善が見られた。
- Switch Transformer-XXL(395B) vs T5-XXL(11B)の比較では、前者の250kステップでの精度が後者の500kステップの精度を上回っていた。
- モデルのスケーリングはExpertパラレル(Expertごとにデバイスを分ける)だけで1.6Tパラメータ(Expert数は2048)まで大きくできた。大規模なスパースモデルは学習が不安定になることがあるが、1.6Tでは学習の不安定さはなかった。一方で、Switch Transformer-XXLでは不安定になることがあった。
MoE-Mamba: Efficient Selective State Space Models with Mixture of Experts
問題意識
- SSM(State Space Model)ベースのアーキテクチャをスケーリングするためにMoEレイヤを導入。
本論文のアプローチ
- Mambaレイヤを一つおきにMoEレイヤ(Single ExpertにルーティングするSwitchレイヤ)へ置き換えた。
評価とAblation Study
- 素のTransformer, Transformer-MoE, Mamba, Mamba-MLP(N Expert=1), MoE-Mambaで比較実験。English C4データセット(6.5Bトークン)で100kステップ学習(ハイパラは素のMamba向けに使われたものを全てのモデルに適用)。MoE-Mambaが最もTest lossが下がり、素のMambaよりも同じ精度を達成するために必要なステップ数が2.2倍短縮され、TransformerやTransformer-MoEよりも精度が良かった。一方で、Mamba-MLPは素のMambaに比べわずかに精度が劣る結果になった。
- Expert数を1, 4, 8, 16, 32へ変更して実験するとExpert数が大きいほど精度が良くなった。
- Mambaレイヤ内のOutput projectionやConv ProjectionをMoEに置き換えたが素のMambaと精度は変わらなかった。MambaとMoEを並列にしたが、素のMambaよりも精度が悪化した。
- Future WorkとLimitationとして
- この論文ではパラメータ数が1Bに満たない小さなモデルで実験されていること
- Mamba vs Mamba-MLPではMambaの方が優れていたため、MambaレイヤをMoE化してみるのはおもしろそうとのこと
- MoEにSwitchレイヤを適用したが、Mambaに対して最良のデザインなのかはわからないので色々と試す必要がある。
Mixtral of Experts
問題意識
- Mistral 7BをSparse Mixture of Experts (SMoE) でスケーリングさせたMixtral 8×7Bを提案。
本論文のアプローチ
- MistralのFFNを8 ExpertsのMoEレイヤへ置き換え。 ルーターは各トークンを2つのExpertへルーティングし、2つのExpertの出力は加重合計される。モデルのパラメータ数は47Bだが、推論時にアクティブなパラメータは13Bとなり、パラメータ数の割に推論は高速になる(一方でGPUメモリ上には47Bのパラメータがロードされることには注意が必要)。
- MoEレイヤはMegablocksのカーネルで大規模な疎行列の掛け算へとキャストされるのでGPU上で効率的に計算できる。また、モデルパラレルとエキスパートパラレルで複数のGPU上で分散学習できる。
- Mistral 7Bと同じくコンテキスト長は32kだが、事前学習に使用する多言語データの割合を増やした。
- InstructバージョンはSFTとDPOを実行した。
評価とAblation Study
- Llama2 7B/13B/70Bとベンチマークで比較したところ、ほとんどのタスクでLlama2 70Bを上回っており、特に数学やコード生成では優れた結果となった。
- 多言語タスクでの比較でも、Llama1 33BとLlama2 70Bを上回る結果となった。
- Llama2 70BとGPT3.5との比較でも7つのタスクの内、4つで最も優れた結果となった。
- 長いコンテキストでの性能をpasskey検索タスクで評価したが、シーケンスの長さや場所に限らず100%の精度で検索できた。
- モデルのバイアスの評価でもLlama2 70Bよりもバイアスが小さい結果となった。
- Instructバージョンは、LMSysリーダーボードで、GPT3.5-TurboやGemini Pro, Claude-2.1といった商用モデルよりも優れたパフォーマンスであった。
Nemotron-4 340B Technical Report
問題意識
- 340Bのモデルファミリー(Base, Reward, Instruct)をオープンソースで公開。また、モデルの出力を使った合成データ生成なども可能なライセンス形態。
- 再現性を確保するため、学習、推論、合成データ生成のパイプラインを公開。
本論文のアプローチ
- Baseモデル
- データ
- 学習に使用したデータは9Tトークン
- 70%が英語、15%が多言語(53言語)、15%がプログラミングコード(43言語)
- 学習は2ステージ(これは効果的であったと言及されている)
- 事前学習に8Tトークン
- 継続学習に1Tトークン(高品質データのサンプリング重み付け、少数の質問応答スタイルサンプルの導入、モデルの精度が低いデータソースの重み付けなど)
- モデル
- Nemotron-4-15B-Baseに類似したアーキテクチャ
- Decoder-only Transformer アーキテクチャ 96レイヤ、Seq len=4096, Vocab size=256,000
- RoPE
- squared ReLU
- GQA
- 計算リソース
- DGX H100 768ノード、6,144 GPUを学習に使用
- TP=8, PP=12, DP=64
- 推論はBF16ではDGX H100 * 2ノード、FP8ではDGX H100 * 1ノードで実行可能
- DGX H100 768ノード、6,144 GPUを学習に使用
- データ
- Rawardモデル
- データ
- 10Kの人間の嗜好データで構成されるHelpsteer2データセットを使用
- 応答に関する複数の属性スコア(Helpfulness, Correctness, Coherence, Complexity, Verbosity)が付与されている
- モデル
- Baseモデルの最後のSoftmaxレイヤをlinearに置き換え属性スコアを出力できるように変更
- データ
- Instructモデル
- プロセス
- コードSFT
- コーディング能力を高めるにはかなりのデータが必要だった。合成データも組み合わせて、厳選された800kサンプルのデータを利用。
- 一般的なSFT
- 様々なタスクで混合された200kサンプルのデータを利用。忘却リスクを抑えるためにコードSFTのデータを2%利用。
- Direct Preference Optimization (DPO)
- プロンプト、選択された応答、拒否された応答の3つ組みで選択された応答を区別できるよう学習。
- さまざまなタスクを含む160kサンプルのデータセットを合成して利用。
- 選択された応答のSFT損失も追加することで過剰適合を抑制。
- Reward-aware Preference Optimization (RPO)
- DPOでは、2つ(選択、拒否)の応答の順序のみを使うが、実際は拒否されたものでもわずかに劣るサンプル、はるかに劣るサンプルなどがある。これがDPOの過剰適合につながる。RPOはRewardモデルの報酬を使った新しい損失関数を定義。
- それほど厳しくない品質フィルタリングを通過した300kサンプルの合成データセットを利用。
- DPO同様に選択された応答のSFT損失も追加。
- コードSFT
- プロセス
評価とAblation Study
- Baseモデル
- LM-Evaluation HarnessでMixtral 8×22B, Llama-3 70B, Qwen-2 72Bと比較を行ったところ、ほとんどの指標で最も結果が良かった。
- Rewardモデル
- RewardBenchで全体としてSOTA。商用モデル(GPT-4, GPT-40, Gemini 1.5 Proなど)とも遜色ない結果であった。
- Instructモデル
- 自動ベンチマーク
- 商用モデルとも遜色ない結果。コードSFTはコーディング能力の向上に効果的。
- 人間によるベンチマーク
- GPT-4-1106-previewと同等かそれ以上、マルチターンチャットで優れた結果。
- 安全性評価
- Llama-3-70B-Instructとの比較で、安全でない応答の割合が少なかった。
- 自動ベンチマーク