📝

New LLM Pre-training and Post-training Paradigms

2024/08/18に公開

A Look at How Modern LLMs Are Trained

https://magazine.sebastianraschka.com/p/new-llm-pre-training-and-post-training

要約

この記事は 4 つの最新の大規模言語モデル (LLM) の事前学習と事後学習のパラダイムについて詳しく説明している。主なポイントは以下の通り :

  1. Alibaba's Qwen 2 :

    • 多段階の事前学習プロセスを採用
    • 事後学習で DPO と RLHF を組み合わせて使用
  2. Apple's Foundation Language Models (AFM):

    • 3 段階の事前学習プロセス
    • 事後学習で複数の選好チューニングアルゴリズムを組み合わせて使用
  3. Google's Gemma 2 :

    • 知識蒸留を事前学習と事後学習の両方で重視
    • 事後学習で RLHF とモデル平均化を使用
  4. Meta AI's Llama 3.1 :

    • 大規模な 15.6兆トークンのデータセットで訓練
    • 事後学習で SFT、DPO、リジェクションサンプリングを反復的に適用

全体的な傾向として :
- 多段階の事前学習プロセスが一般的
- データ品質の重視
- リジェクションサンプリングが事後学習で一般的
- DPO と RLHF の使用に関してはまだコンセンサスがない

これらのモデルは異なるアプローチを取っているが MMUL ベンチマークなどでは似たようなパフォーマンスを示している

Alibaba's Qwen 2

  1. モデルバリエーション:

    • 4 つの通常(密な)LLM モデル:0.5B、1.5B、7B、72B パラメーター
    • 1 つの Mixture-of-Experts モデル:57B パラメータ(14B パラメーターが同時にアクティブ)
  2. 特徴:

    • 30 言語での優れた多言語能力
    • 151,642 トークンの大規模な語彙サイズ(他のモデルと比較して非常に大きい)
  3. 事前学習:

    • 1.5B、7B、72B モデルは 7兆トークンで訓練
    • 0.5B モデルは 12兆トークンで訓練(ただし他のモデルでは追加コストに見合う改善が見られなかった)
    • データフィルタリングパイプライン改善とデータ混合強化
    • 以前の Qwen モデルを使って追加の事前学習データを合成
    • マルチタスク指示データを事前学習に統合
  4. 事前学習の 2 段階プロセス:

    • 通常の事前学習
    • 長文脈学習(コンテキスト長を 4,096 から 32,768 トークンに拡張)
  5. 事後学習:

    • 教師あり指示微調整 (SFT) : 500,000 例で 2 エポック
    • 直接選好最適化 (DPO) を使用した人間の選好との整合
    • 2 段階のアラインメントフェーズ:
      a) 既存データセットでの DPO(オフラインステージ)
      b) 報酬モデルを使用した選好ペアの形成(オンラインステージ)
  6. データセット構築:

    • 既存のコーパスと人間のラベリングを組み合わせ
    • 人工的にアノテーションされたデータ合成
    • LLM を使って高品質な文学データ向けの指示-応答ペアを生成
  7. パフォーマンス:

    • MMUL ベンチマークで競争力のあるスコアを達成

Qwen 2 の特筆すべき点は事前学習と事後学習の両方で合成データを活用していること、データセットのフィルタリングに重点を置いていること、そして多段階の学習プロセスを採用していること。これらのアプローチにより効率的かつ高性能なモデルの開発を実現している

Apple's Foundation Language Models (AFM)

  1. モデルバリエーション:

    • 3B パラメーターのオンデバイスモデル(電話、タブレット、ラップトップ向け)
    • 3B パラメーターのより高性能なサーバーモデル
  2. 用途:

    • チャット、数学、コーディングタスク向けに開発(ただしコーディング特有の訓練については詳細な記述が無い)
  3. 事前学習の特徴:

    • 公開データと出版社からライセンス供与されたデータを使用
    • ウェブサイトの robots.txt ファイルを尊重
    • ベンチマークデータとの汚染除去を実施
    • 品質を量よりも重視
  4. 語彙サイズ:

    • デバイスモデル : 49k トークン
    • サーバーモデル : 100k トークン
  5. 3 段階の事前学習プロセス:

    • コア(通常)事前学習:

      • サーバーモデルは 6.3兆トークンで訓練
      • オンデバイスモデルは 6.4B パラメーターモデルから蒸留・剪定
    • 継続事前学習:

      • ウェブクロール(低品質)データの重みを下げ数学とコードの重みを上げる
      • 1 兆トークンのデータセットで実施
      • コンテキスト長を 4,096 から 8,192 トークンに拡張
    • コンテキスト拡張:

      • 100 億トークン(第 2 段階の 10%)で実施
      • コンテキスト長を 32,768 トークンまで拡張
      • 合成長文脈 Q&A データで強化
  6. 事後学習:

    • 人間がアノテーションしたデータと合成データを使用
    • 2 段階プロセス : 教師あり指示微調整 (SFT) と複数ラウンドの人間のフィードバックによる強化学習 (RLHF)
    • 2 つの新しい RLHF アルゴリズムを導入:
      • 教師委員会によるリジェクションサンプリング微調整 (iTeC)
      • ミラー降下方策最適化による RLHF
  7. 特筆すべき点:

    • 複雑な委員会ベースのアプローチ(SFT、DPO、IPO、オンライン RL を組み合わせ)
    • 複数の選好チューニングアルゴリズムを使用
    • 3B という比較的小さなモデルサイズを活かした多様な技術の適用
  8. データ品質とミックス:

    • データの品質を量より重視
    • 事前に決められたデータ比率に頼らず実験を通じて最適なデータミックスを調整

AFM の開発アプローチは非常に包括的で多段階の学習プロセスと複数のアルゴリズムの組み合わせが特徴。これはモデルが何百万もの、場合によっては何十億ものデバイスにデプロイされるという高い要求に応えるためのものと考えられる

Google's Gemma 2

  1. モデルサイズ

    • 2B、9B、27B パラメーターの 3 種類を提供
    • 大規模なデータセットではなく比較的小規模で効率的な LLM 開発に焦点
  2. 特徴

    • 256k トークンの大規模な語彙サイズ(Llama 2 の 32k、Llama 3 の 128k と較べて大きい)
    • Mistral の初期モデルと同様なスライディングウィンドウ注意機構を採用(メモリーコスト削減のため)
  3. 事前学習

    • 27B モデル : 13 兆トークンで訓練
    • 9B モデル : 8 兆トークンで訓練
    • 2B モデル : 2 兆トークンで訓練
    • データ品質の維持に注力し知識蒸留などの代替手法で改善を達成
    • 27B モデルはゼロから訓練、小規模モデルは知識蒸留を使用
    • Apple の AFM と同様にデータミックスを最適化
  4. 知識蒸留を活用

    • 事前学習と事後学習の両方で知識蒸留を重視
    • 小規模モデルの性能向上に貢献
  5. 事後学習プロセス

    • 教師あり微調整 (SFT) と人間のフィードバックによる強化学習 (RLHF) を実施
    • 英語のみのプロンプトペアを使用 (人間が生成したものと合成生成したものを混合)
    • SFT フェースでも知識蒸留を適用
    • RLHF では報酬モデルがポリシー (ターゲット) モデルの 10 倍大きいのが特徴
  6. RLHF の特徴

    • WARP (Weight-Averaged Reward Policy) 法を使用してポリシーモデルを平均化
    • WARM (Weight-Averaged Reward Models) の後継手法
  7. 開発アプローチ

    • 多段階の学習プロセスは詳細に記述されていない (または使用されていない可能性がある)
    • 知識蒸留に重点を置いた比較的シンプルなアプローチ

Gemma 2 は大規模なデータセットに依存せずに効率的な LLM を開発する方法を探求している点が特徴的。知識蒸留を中心とした技術の活用により比較的小規模なモデルでも高いパフォーマンスを実現している

Meta AI's Llama 3.1

  1. モデルサイズ

    • 405B パラメーターの大規模モデルを新たにリリース
    • 既存の 8B と 70B パラメーターモデルを更新し MMLU パフォーマンスを向上
  2. アーキテクチャー

    • グループクエリー注意を採用
    • スライディングウィンドウ注意や Mixture-of-Experts アプローチは採用せず
    • 従来的なアーキテクチャーを維持し事前学習と事後学習に注力
  3. ライセンス

    • オープンウェイトで提供
    • 合成データ生成や知識蒸留による他モデルの改善が可能になるようライセンスを更新
  4. 事前学習

    • 15.6 兆トークンの巨大データセットで訓練
    • 少なくとも 8 言語をサポート
    • 128,000 トークンの語彙サイズ(OpenAI の tiktoken トークナイザーを使用)
  5. データ品質管理

    • ヒューリスティックベースのフィルタリングとモデルベースの品質フィルタリングを併用
    • Meta AI の fastText や RoBERTa ベースの分類器を活用
  6. 3 段階の事前学習プロセス

    • 標準(初期)事前学習
      • 15.6 兆トークンを使用、8k コンテキストウィンドウ
      • バッチサイズと系列長を段階的に増加
    • コンテキスト拡張のための継続事前学習
      • コンテキスト長を 8,000 から 128,000 トークンまで 6 段階で増加
      • 全データセットの約 5%(8,000 億トークン)を使用
    • 高品質データでのアニーリング
      • 小規模だが高品質なデータミックスで訓練
      • ベンチマークデータセットのパフォーマンス向上に寄与
  7. 事後学習

    • 教師あり微調整 (SFT)、リジェクションサンプリング、直接選好最適化 (DPO) を使用
    • SFT と DPO を複数ラウンド反復
    • 人間が生成したデータと合成データの両方を活用
    • 報酬モデルをリジェクションサンプリングに使用
    • SFT、DPO、報酬モデルすべてにモデル平均化技術を適用
  8. 特徴

    • 大規模データセットでの訓練
    • 3 段階の事前学習プロセス
    • 知識蒸留を使用せず、より直接的なモデル開発アプローチを採用
    • DPO を使用し複雑な強化学習戦略ではなくシンプルな方法を選択

Llama 3.1 は大規模なデータセットと多段階の学習プロセスを特徴とし、シンプルかつ効果的な手法を組み合わせて高性能を実現している

リジェクションサンプリング

  1. 概要

    • 直接サンプリングが困難な確率分布からサンプルを生成する手法
    • 提案分布からサンプルを生成し一部を「拒否」することで目標分布を近似
  2. プロセス

    • 提案分布選択 : 目標分布を包含しサンプリングが容易な分布を選ぶ
    • サンプル生成 : 提案分布からサンプルを生成
    • 受理確率の計算 : 目標分布と提案分布の比率に基づいて決定
    • サンプルの受理または拒否 : 計算した確率に基づいて判断
  3. 数学的基礎

    • 受理確率 = min (1, (目標分布の確率密度) / (提案分布の確率密度) * 定数)
    • 定数は提案分布が目標分布を常に上回るように選択
  4. LLM における応用

    • 生成された複数の応答候補から最適なものを選択
    • 報酬モデルを使用して応答の品質を評価し受理確率を決定
  5. 利点

    • 複雑な分布からの直接サンプリングを回避
    • 品質の高いサンプルを選択的に生成可能
    • モデルの出力を制御し望ましい特性を持つ応答を促進
  6. 課題

    • 効率的な提案分布の設計
    • 計算コストと生成品質のトレードオフ
    • 高次元空間で受理率低下
  7. LLM での具体的な実装例

    • 複数の候補応答を生成(例 : ビーム探索を使用)
    • 各候補に対して報酬モデルでスコアを計算
    • スコアに基づいて受理確率を決定し最終的な応答を選択
  8. 他の技術との組み合わせ

    • DPO (Direct Preference Optimization) との統合
    • SFT (Supervised Fine-Tuning) 後の品質向上手段として使用
    • RLHF (Reinforcement Learning from Human Feedback) パイプラインの一部として活用

リジェクションサンプリングは LLM の応答生成プロセスを制御し、出力の品質を向上させるための効果的な手法として広く採用されている

モデル平均化

  1. 概要

    • 複数のモデルのパラメーターを平均化して新しいモデルを生成する手法
    • アンサンブル学習の一種だが単一のモデルを生成する点が特徴的
  2. 目的

    • モデルの汎化性能向上
    • 過学習軽減
    • 予測安定性向上
  3. 手法

    • 重み平均化 (Weight Averaging)

      • 複数モデルの対応する重みを直接平均化
      • 例 : モデル A と B の重みを 0.5 ずつで平均化
    • 指数移動平均 (Exponential Moving Average, EMA)

      • 新しい重みに対して徐々に減衰する重みを適用
      • 最近のモデルにより高い重みを与える
    • 確率的重み平均化 (Stochastic Weight Averaging, SWA)

      • 学習の後半で複数のチェックポイントの重みを平均化
      • 最適解の周りでのより良い探索を可能にする
  4. 利点

    • 計算コストの増加を抑えつつ性能を向上
    • アンサンブル学習と比較してメモリー効率が良い
    • 異なる局所最適解の特徴を組み合わせることで汎化性能を向上
  5. 応用例

    • Llama 3.1 で SFT、DPO、報酬モデルを平均化
    • WARP (Weight-Averaged Reward Policy) による RLHF 改善
    • 画像認識や自然言語処理タスクでの性能向上
  6. 課題

    • 適切な平均化のタイミングとモデル選択
    • 非線形性の強いモデルでの効果の限界
    • ドメイン特化型の平均化戦略の設計
  7. 実装上の考慮事項

    • バッチ正規化層の統計情報の適切な処理
    • 学習率スケジューリングとの統合
    • モデルアーキテクチャの互換性の確保

モデル平均化技術は単一のモデルでありながらアンサンブル学習の利点を部分的に享受できる効果的な手法として様々な深層学習タスクで活用されている

Discussion