🥰

人間の好みを学習するDPOを理解してみよう!

2024/12/11に公開

この記事は、LLM・LLM活用 Advent Calendar 2024 の 11 日目の記事になります。

https://qiita.com/advent-calendar/2024/large-language-model

今回は、LLM(大規模言語モデル)を人間の好みに合わせて調整したいときに使うDPO(Direct Preference Optimization)についてお話しします。

DPOって何?簡単に言うと...

人間の「いいね!」「それはちょっと...」という評価を直接AIに学習させる方法です。例えば:

# 従来のAIの返答
user_input = "今日は調子が悪いです"
ai_response = "それは残念ですね。"

# DPOで改善後のAIの返答
user_input = "今日は調子が悪いです"
ai_response = "つらいですね。具体的にどんな症状がありますか?ゆっくり休むことをおすすめします。"

なぜDPOが必要なの?

従来のAIには以下のような課題がありました:

  • 機械的な返答が多い
  • 時として不適切な発言をする
  • 文脈を考慮しない返答をする

DPOの仕組み(超シンプル版)

Step 1: 人間の好みデータを集める

# 好みデータの例
preferences = {
    "良い回答": "体調管理は大切ですね。ゆっくり休んでください。",
    "改善が必要な回答": "病院に行けば?",
}

Step 2: AIに学習させる

  • 良い回答の確率を上げる
  • 良くない回答の確率を下げる
  • これだけ!

従来手法(RLHF)との違い

LLMの振る舞いを制御する従来手法のRLHF(Reinforcement Learning from Human Feedback)には、以下のような課題がありました:

  • 3段階の複雑なプロセスが必要
    1. 教師ありファインチューニング
    2. 報酬モデルの学習
    3. 強化学習による最適化
  • 多大な計算資源と時間が必要
  • 報酬モデルの学習が不安定になりやすい
  • 期待通りの結果が得られないことがある

これらの課題を解決するため、よりシンプルで効率的な手法としてDPOが登場しました。DPOは報酬モデルを使用せず、直接人間の好みを学習できる画期的な手法です。

シンプルな表で比較してみましょう:

項目 RLHF DPO
実装の複雑さ 複雑 シンプル
計算コスト 高い 低い
学習の安定性 やや不安定 安定
セットアップ時間 長い 短い

DPOに必要なデータセットは?

DPOでAIを賢く育てるには、「いい回答」と「よくない回答」のセットが必要です。

データセットの基本構造

このデータセットは、以下の3つの要素で構成されます:

  1. 質問文(prompt):AIに与える質問やお題
  2. いい回答(chosen):人間が「これはいいね!」と評価した回答
  3. よくない回答(rejected):人間が「これはちょっと...」と評価した回答

具体例で見てみよう

# データセットの例
dataset = {
    "質問(prompt)": "体調が悪いときはどうすればいいですか?",
    "いい回答(chosen)": "まずは十分な休息を取り、必要に応じて医療機関に相談することをお勧めします。",
    "よくない回答(rejected)": "頑張って仕事を続けましょう。"
}

実際のデータセット例

現在、以下のような公開データセットがあります:

https://huggingface.co/datasets/cyberagent/chatbot-arena-ja-calm2-7b-chat-experimental

https://huggingface.co/datasets/llm-jp/hh-rlhf-12k-ja?row=10

https://huggingface.co/datasets/weblab-GENIAC/aya-ja-evol-instruct-calm3-dpo-masked

このように、DPOでは人間の判断基準をAIに教えるためのデータが重要です。質の高いデータセットを用意することで、より賢いAIを育てることができます。

DPOの実装例

TRLライブラリで簡単実装

Hugging Faceが提供するTRLライブラリを使えば、数行のコードでDPOが実装できます。

from trl import DPOTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer

# モデルとトークナイザーの準備
model = AutoModelForCausalLM.from_pretrained("llama2-7b-japanese")
tokenizer = AutoTokenizer.from_pretrained("llama2-7b-japanese")

# トレーニングデータの例
training_data = {
    "prompt": ["今日の調子はどうですか?"],
    "chosen": ["お気遣いありがとうございます。元気です。"],
    "rejected": ["別に。"]
}

# DPOトレーナーの設定
trainer = DPOTrainer(
    model=model,
    tokenizer=tokenizer,
    beta=0.1,  # モデルの変化量を調整(小さいほど控えめな変化)
)

# 学習実行
trainer.train()

# モデルの保存
trainer.save_model("my_dpo_model")

実装時のポイント

項目 説明
モデル選択 日本語対応のベースモデルを選ぶ
データ準備 良い回答と悪い回答のペアを用意
学習率 小さめの値(例:1e-5)から始める
バッチサイズ GPUメモリに応じて調整(4-8程度)

このように、DPOの実装は思ったより簡単です。基本的には「良い応答」の確率を上げ、「悪い応答」の確率を下げる、というシンプルな考え方で実装できます。

まとめ

  • DPOは人間の好みを直接AIに反映できる
  • 従来手法(RLHF)より実装が簡単
  • データの質が成功の鍵

DPOについては今更感がありますが、整理するために良い機会になりました。DPO以外にもKTOやCPOといった派生手法もあります。

関連情報

https://zenn.dev/matsuolab/articles/d76e5faaf4e18b#▶︎データ作成

https://qiita.com/jovyan/items/6767c9fd944a636fdf88

https://note.com/npaka/n/n23576a1211a0

https://arxiv.org/abs/2404.14723

Discussion