😽

蒸留手法(Distillation)詳解 — DeepSeek のスペシャリスト→ジェネラリスト蒸留を中心に

に公開

1. 蒸留の位置付けと目的(Why)

蒸留(Knowledge Distillation, KD)は、**「複数(あるいは大きな)教師モデルが持つ知識を,より小さな/単一の生徒モデルに移す」**手法群です。DeepSeekの文脈では次が主目的になります。

  • スペシャリスト群が個別に学んだ**専門能力(数学的推論・ツール操作・CoT生成など)**を単一モデルへ統合する
  • 推論コスト・運用コストを抑えつつ、スペシャリストの性能を保持する
  • RLで得た方策(policy)や報酬最適化の成果を、標準的なLM損失で再現させる(安全に一般化させる)
  • thinking / non-thinking の二重モードを同一モデルで扱えるようにする

重要な点:DeepSeek のスペシャリストはRLで強化された教師であり、単純な教師出力の模倣だけでは足りない(報酬の構成要素・生成過程・方策的特徴が重要)ため、蒸留設計はより複雑になります。


2. 蒸留の主要パラダイム(What) — Taxonomy

蒸留は「何を模倣するか」「どのレベルで蒸留するか」で分類できます。ここでは代表的手法と DeepSeek での適用ニュアンスを列挙します。

  1. Logit / Soft-target Distillation(Hinton ら)

    • 教師の出力分布(softmax後の確率分布)を生徒が模倣。温度(temperature)を用いる。
    • DeepSeek: thinking 出力(長いCoT)と non-thinking 出力の両方で有効。
  2. Feature / Representation Distillation

    • 中間特徴(hidden states, attention maps, KV表現など)を一致させる。
    • DeepSeek: Indexer や attention の振る舞いを部分的に模倣させるのに有効。
  3. Sequence-level / CRF / Policy Distillation

    • 生成されるシーケンス全体を教師の方策に合わせる(例:最尤シーケンス、RL方策の行動確率)。
    • DeepSeek: RLで学んだ方策を模倣するため必須。
  4. Response-ranking / Contrastive Distillation

    • 教師が出した複数候補の中での相対順位を学習。
    • DeepSeek: 複数スペシャリストの提示する多様解を統合するときに有用。
  5. Multi-teacher / Ensemble Distillation

    • 複数の教師(スペシャリスト)を同時に蒸留。教師ごとに重み付けや専門性メタデータを用いる。
    • DeepSeekの主軸。
  6. Iterative / Progressive Distillation(Teacher Assistant等)

    • 教師から直接小さな生徒へ一気に落とすのではなく、中間的な教師(assistant)を置いて漸進的に蒸留。
    • DeepSeek: 複数段階での品質保全に有効、特に非常に大きなスペシャリスト群を統合するとき。
  7. Data-free / Synthetic Distillation

    • 実データが不十分な場合、教師から生成した疑似データで蒸留。
    • DeepSeek: スペシャリストが生成する thinking / non-thinking ペアをそのまま使う点で近い。

3. 蒸留で扱う情報の種類(何を写すか)

蒸留対象は大きく次の4つに分かれます。どれをどの程度取り入れるかが設計の鍵です。

  1. 出力分布(Logits / Probabilities)

    • 最も基本。温度を上げて教示(低温だとone-hotに近くなり情報が減る)。
  2. 中間表現(Hidden States)

    • 例:各レイヤのトークン埋め込み、attention weights、KV表現。
    • 理由:より深い抽象的知識(reasoning pathの痕跡)を移す。
  3. Policy情報(行動確率、価値推定)

    • RLで学んだ教師は方策(π(a|s))を持つ。これを模倣するには policy-distillation 用の損失を使う。
  4. シーケンス評価 / 質的情報

    • 教師が生成した CoT の品質、step-level の中間スコア、自己評価値など。
    • 例:教師が内部で計算した reward-to-go を生徒に教示する。

4. 蒸留に使う損失関数(How:数式と直感)

蒸留は通常、基礎的なLM損失(cross-entropy)蒸留損失の重み和で訓練します。ここでは使用される代表的な損失と式を示します。

4.1 Logit-level Distillation(温度付きKL)

教師の温度付き分布 ( p_T ) を生徒 ( q_\theta ) が模倣する:

[
p_T(y|x) = \text{softmax}\left(\frac{z_T(x)}{\tau}\right),\quad
q_\theta(y|x) = \text{softmax}\left(\frac{z_\theta(x)}{\tau}\right)
]

蒸留損失(KL):

[
\mathcal{L}{KD} = D{KL}\big(p_T | q_\theta\big) = \sum_y p_T(y|x) \log\frac{p_T(y|x)}{q_\theta(y|x)}
]

総合損失(LMクロスエントロピー (\mathcal{L}_{LM}) と合わせて):

[
\mathcal{L} = \alpha , \mathcal{L}{LM} + \beta , \tau^2 , \mathcal{L}{KD}
]

  • (\tau):温度(通常1〜8)
  • (\tau^2)は勾配尺度補正(Hintonらの元論文)

直感:温度を上げると教師分布が平滑化され、確率差の微妙な情報(類似性)が伝わりやすくなる。


4.2 Feature Distillation(中間表現整合)

中間層表現 ( h_T^{(l)} ) を生徒 ( h_\theta^{(l')} ) と合わせる。一般的に L2 損失や cosine 損失を使う。

例:

[
\mathcal{L}{feat} = \sum{(l,l')} \lambda_{l,l'} , | , \text{proj}( h_\theta^{(l')} ) - h_T^{(l)} , |_2^2
]

  • ( \text{proj}(\cdot) ) は次元整合のための線形射影
  • attention map を合わせる場合はクロスエントロピーやKLを使うこともある

直感:出力だけでなく「内部計算のやり方」も写すことで、より教師の思考過程を再現しやすくなる。


4.3 Sequence-level Distillation / Policy Distillation(RL 系)

教師の方策 (\pi_T)(確率分布)を生徒 (\pi_\theta) に模倣させる。方策レベルでの KL、あるいは行動価値(Q値)を用いた損失がある。

[
\mathcal{L}{policy} = \mathbb{E}{s \sim D}\big[ D_{KL}( \pi_T(\cdot|s) | \pi_\theta(\cdot|s) ) \big]
]

あるいは走査的に最尤生成列 (y^*) を教師の確率で強化し、生徒の生成確率を最大化させる形もある。


4.4 Contrastive / Ranking Loss

教師が複数候補を評価する場合、教師の評価スコアをランキング損失(margin-based)で再現:

[
\mathcal{L}{rank} = \sum{i,j} \max(0, m - s_T(i) + s_T(j)) \cdot \ell\big(s_\theta(i), s_\theta(j)\big)
]

ここで (s_T(i)) は教師のスコア、(\ell) は生徒の予測スコア差を反映する項。


4.5 総合損失の形

DeepSeek のシナリオでは複数の損失を組み合わせる必要がある。代表的な形:

[
\mathcal{L} = \alpha \mathcal{L}{LM} + \beta \mathcal{L}{KD_logit} + \gamma \mathcal{L}{feat} + \delta \mathcal{L}{policy} + \epsilon \mathcal{L}_{rank}
]

各係数はタスクや教師の性質により調整する。


5. スペシャリスト→ジェネラリスト蒸留の設計(DeepSeek向け実践設計)

DeepSeek の蒸留は単純な KD を上回る設計が必要です。ここでは実務的な設計手順を提示します。

5.1 入力データの準備

  1. スペシャリストが生成したデータを保存する(thinking / non-thinking 両モード)。各サンプルは次を含む:

    • プロンプト (x)
    • スペシャリスト出力シーケンス (y_T)
    • 各トークンの教師ロジット (z_T)(可能なら)
    • 内部スコア(reward-to-go, confidence)
    • attention maps / hidden states(必要かつ保存可能なら)
  2. 教師メタデータ

    • どのスペシャリスト(数学・agent等)か
    • モード(thinking/non-thinking)
    • 教師の報酬スコアや品質スコア

理由:後から教師ごとの重み付けやサンプリング戦略を変えるためにメタデータは必須。


5.2 サンプリング戦略(Multi-teacher の扱い)

複数教師がある場合、単純に混ぜると偏りが出る。次の戦略を推奨:

  • 比率サンプリング:教師ごとに重み (w_i) を与え、ミニバッチに混ぜる。
  • カリキュラムサンプリング:初期は簡単(non-thinking)のデータ重視、漸進的にthinkingデータ比率を増やす。
  • 重み適応:生徒の現在の弱点(validationでの低性能領域)に基づいて教師重みを動的に再配分。

5.3 損失の構成(DeepSeek 推奨)

  • 基本 LM 損失((\mathcal{L}_{LM})):生徒が正解を出すための主損失(クロスエントロピー)
  • Logit KD((\mathcal{L}_{KD})):温度付き KL、teacher logits を可能なら利用
  • Policy Distill((\mathcal{L}_{policy})):RL教師の方策を模倣(特に agent 専門教師に対して)
  • Feature Distill((\mathcal{L}_{feat})):重要レイヤの hidden states や attention maps を合わせる(計算許すなら)
  • Mode Consistency Loss:thinking と non-thinking を切り替えても出力が一貫することを促す損失(後述)

損失重みの初期値(例):

alpha (LM) = 1.0
beta  (KD_logit) = 0.5
gamma (feat) = 0.1
delta (policy) = 0.3

(タスク依存でチューニング)


5.4 Temperature とスケジューリング

  • 温度 τ:初期は高め(τ=4〜8)で教師分布を平滑化し、徐々に下げる(Curriculum KD)。
  • 蒸留比重スケジュール:初期は KD 成分を強めにして教師の振る舞いを素早く取り込む→中盤で LM と組み合わせ→終盤は LM を重視して微調整する。

5.5 中間表現(Feature)蒸留の実装方針

  • 選ぶべきレイヤ:出力近傍の数層(例:最終3層)+ attention-head-levelの集約マップ
  • 次元合わせ:教師と生徒のレイヤ深さや幅が異なる場合は線形射影(learned projection)で合わせる
  • 正則化:feature distill は過学習しやすいので dropout と L2正則を用いる

6. マルチモード(thinking / non-thinking)蒸留の工夫

DeepSeek の大きな挑戦は 同じモデルが「詳述するモード」と「即答するモード」 を両立することです。ここでは代表的手法を示します。

6.1 明示的モードラベル

  • プロンプトに special token を付与(例:<MODE=THINK> / <MODE=FAST>)して学習。
  • 生徒はモード条件付きモデルとして学習されるので、推論時にモード切替が可能。

6.2 モード専用の蒸留損失重み

  • thinking データに対しては feature distill(hidden state の一致)を強めに、non-thinking には logit KD を強めにする、など。

6.3 Dual-head / Mixture-of-Strategies 出力

  • モードに応じて 出力ヘッドを切り替える(lightweight head)。内部は共通だが最終層が分岐している。
  • 蒸留段階でスペシャリストの head を順に模倣する。

6.4 一貫性損失(Mode Consistency)

  • 同一プロンプトでモードを切り替えたとき、重要な事実(ファクトや結論)が矛盾しないようにする損失。
  • 例:teacher の thinking と non-thinking の出力を共に参照し、生徒が両方を整合的に再現するよう学習。

7. RLで学んだ教師(policy)からの蒸留(Policy Distillation)

RL教師の方策は行動選択の確率分布報酬構造を持つ。ここでの蒸留は単なるソフトターゲット模倣を超えます。

7.1 直接的な方策模倣

教師方策 (\pi_T) を模倣する KL. シンプルだが有効。

[
\mathcal{L}{policy} = \mathbb{E}{s \sim D}\big[ D_{KL}(\pi_T(\cdot|s) | \pi_\theta(\cdot|s)) \big]
]

7.2 Q-value / reward-to-go の蒸留

教師が計算した Q 値や reward-to-go を生徒が予測するように学習させる:

[
\mathcal{L}{Q} = \mathbb{E}{(s,a)} | Q_T(s,a) - Q_\theta(s,a) |_2^2
]

これにより生徒は「どの行動が長期的に良いか」を内的に学ぶ。

7.3 Offline RL と KD の統合

教師方策は通常オンラインRL中に得られるが、蒸留は オフラインデータ(教師が生成したトレース)で行うため、Offline RL の安全手法(importance sampling correction、behavior cloning with constraints)を導入することが安全。

7.4 安全性と方策崩壊の防止

方策の直接模倣だけだと「過学習」や「報酬ハック」が発生する。対策:

  • 教師方策のエントロピーを一定以上保つ(entropy regularization)
  • 複数の教師方策を混ぜる(ensemble distillation)
  • 生徒方策に対する trust region 制約(KL bound)

8. 実装の詳細・擬似コード

以下は DeepSeek スタイルの多教師・マルチモード蒸留の簡易擬似コードです(概念実装)。

# Pseudocode for multi-teacher distillation (DeepSeek-style)

for epoch in range(num_epochs):
    for batch in sampler(multi_teacher_dataset, batch_size):
        # batch contains samples from various teachers with metadata
        x = batch['prompt']
        teacher_logits = batch.get('teacher_logits')    # may be None
        teacher_hidden = batch.get('teacher_hidden')    # may be None
        teacher_mode = batch['mode']                    # THINK / FAST
        teacher_id = batch['teacher_id']

        # student forward (mode-conditioned)
        student_output = student_model.forward(x, mode=teacher_mode)
        student_logits = student_output.logits
        student_hidden = student_output.hidden_states

        # LM loss (cross-entropy) on teacher's sequence (or gold)
        lm_loss = cross_entropy(student_logits, batch['target_tokens'])

        # KD loss (logit-level), use temperature tau depending on mode
        tau = tau_for_mode(teacher_mode)
        if teacher_logits is not None:
            p_T = softmax(teacher_logits / tau)
            q_S = softmax(student_logits / tau)
            kd_loss = kl_div(p_T, q_S) * (tau**2)
        else:
            kd_loss = 0.0

        # Feature distillation (optional)
        if teacher_hidden is not None:
            proj_student = project(student_hidden)
            feat_loss = mse(proj_student, teacher_hidden)
        else:
            feat_loss = 0.0

        # Policy distillation (for RL teachers)
        if 'policy_probs' in batch:
            pi_T = batch['policy_probs']
            pi_S = student_output.policy_probs
            policy_loss = kl_div(pi_T, pi_S)
        else:
            policy_loss = 0.0

        # Mode consistency regularizer: same prompt different mode outputs consistent
        mode_consistency_loss = compute_mode_consistency(x, student_model)

        # Total loss
        total_loss = (alpha * lm_loss +
                      beta * kd_loss +
                      gamma * feat_loss +
                      delta * policy_loss +
                      eta * mode_consistency_loss)

        total_loss.backward()
        optimizer.step()

9. ハイパーパラメータ設計・スケジュールの例

以下は実務的な開始点(目安)。必ず validation/ablation でチューニングしてください。

  • バッチサイズ:512〜2048 tokens(大規模モデルでは大きめ)

  • 学習率:1e-5〜5e-5(AdamW)、LM精調に近い微小LR

  • weight decay:0.01

  • KD 温度 τ:

    • 初期:4〜8(thinking モードは高め)
    • 漸減スケジュールで最終は1
  • 損失重み(初期):

    • α (LM) = 1.0
    • β (KD_logit) = 0.5
    • γ (feature) = 0.1
    • δ (policy) = 0.3
    • η (mode_consistency) = 0.05
  • 学習スケジュール:

    • Warmup(1k steps)→一定 LR → cosine decay
  • チェックポイント:10k〜50k steps ごとに保存し、必ず evaluation rollback を実施


10. 評価指標とアブレーション実験(何を測るか)

蒸留の評価は単純なパープレキシティだけでは不十分。必須の評価セット:

  1. 専門タスクごとのベンチマーク(数学、agentタスク、CoT推論など)
  2. モードごとの応答品質(thinking vs non-thinking)
  3. 方策整合性テスト(RL教師が成功したタスクを生徒が達成するか)
  4. 一貫性チェック(同一プロンプトに対するモード切替で矛盾が生じないか)
  5. 回帰テスト(蒸留前の基礎モデルやスペシャリストの性能が劣化していないか)
  6. サンプル効率(同等性能に到達するためのステップ数/トークン数)
  7. 実行時レイテンシ・メモリ(運用面での改善)

アブレーション案:

  • Logit KD を外す vs 入れる
  • Feature distill の有無
  • Policy distill の有無
  • 温度スケジュールの有無
  • モードラベルの有無
  • 多教師サンプリング重みの変化

11. よくある問題と対策(落とし穴)

問題 A: 模倣だけで目標能力が伝わらない

  • 原因:教師が高次の戦略(長期のQ値など)を直接的に示していない
  • 対策:reward-to-go、Q値、内部評価スカラーを教師データに含める。policy-distill を導入。

問題 B: モード間で矛盾が生じる

  • 対策:mode consistency loss、強い条件付きトークン、mode-specific heads

問題 C: 蒸留でオーバーフィッティング

  • 特に feature distill で発生。
  • 対策:dropout強化、feature loss の重みを下げる、データ拡張

問題 D: 教師間で矛盾する出力(競合)

  • 対策:教師重み付け、コンフリクト解消のためのメタ学習(meta-weight learning)、あるいは ensemble scoring を利用して教師アンサンブルの合意点のみ学習させる

問題 E: 学習の不安定(RL由来)

  • 定期評価→巻き戻し・LR低下、trust-region constraint(KL制約)

12. 計算資源・ストレージ設計上の注意

  • 教師ロジットや hidden-states の保存は重い(特に1Tトークン相当のデータ)。圧縮(FP16→FP8)、サンプリング保存、必要レイヤのみ保存を検討。
  • バッチ処理:feature distill を含めると GPU メモリが急増するため、mixed-precision と gradient checkpointing を必須にする。
  • KVキャッシュの再利用:教師の Key/Value を将来の再調査のために保存する場合、保存フォーマットを標準化しておく(インデックス、トークン位置、対応付け)。
  • チェックポイント戦略:蒸留は長時間かかるため、incremental checkpoint + rolling evaluation が重要。

13. まとめ — 実務向けチェックリスト

短く実行に移すためのチェックリスト:

  • 各スペシャリストの出力(logits / hidden / policy info)を保存するフォーマットを定義したか
  • thinking / non-thinking のメタデータを保存しているか
  • Multi-teacher サンプリング戦略(重み・カリキュラム)を設計したか
  • 蒸留損失の初期重みと温度スケジュールを決めたか
  • feature distill のための projection 層を実装したか(教師・生徒次元合わせ)
  • policy distill 用の安全制約(KL bound, entropy reg)を導入したか
  • 評価スイート(専門タスク、モード整合性、一貫性)を用意したか
  • checkpoint と rollback 戦略を設計したか(定期評価基準を明確化)

付録:推奨実験プラン(短期〜中期)

  1. Baseline KD

    • まずは単一教師(best specialist)の logit-level KD から開始。評価基準を作る。
  2. Multi-teacher Mix

    • 複数教師を混ぜて同様にKD。教師重みは均等→弱点補完型へ変更。
  3. Feature Distill 導入

    • 最終層のみの feature distill を導入(低重み)。比較評価。
  4. Policy Distill 導入

    • RL教師の方策模倣を入れる。Offline RL の安全対策を検証。
  5. Mode-aware 蒸留

    • thinking/non-thinking を区別した学習。mode consistency を評価。
  6. Iterative 蒸留(Teacher Assistant)

    • 大きなスペシャリスト群を2〜3段階で統合する実験。

最後に(考察)

DeepSeek のように 多様なスペシャリスト群(しかも RL で強化された教師) を単一のジェネラリストに統合する蒸留は、典型的な KD を超えて「方策・内部表現・生成過程」まで写す高度な技術を要します。実務では、データの収集・メタデータ管理・安定化メカニズムが、損失関数やアーキテクチャ設計と同じくらい重要です。

Discussion