人間の好みを学習するDPOを理解してみよう!
この記事は、LLM・LLM活用 Advent Calendar 2024 の 11 日目の記事になります。
今回は、LLM(大規模言語モデル)を人間の好みに合わせて調整したいときに使うDPO(Direct Preference Optimization)についてお話しします。
DPOって何?簡単に言うと...
人間の「いいね!」「それはちょっと...」という評価を直接AIに学習させる方法です。例えば:
# 従来のAIの返答
user_input = "今日は調子が悪いです"
ai_response = "それは残念ですね。"
# DPOで改善後のAIの返答
user_input = "今日は調子が悪いです"
ai_response = "つらいですね。具体的にどんな症状がありますか?ゆっくり休むことをおすすめします。"
なぜDPOが必要なの?
従来のAIには以下のような課題がありました:
- 機械的な返答が多い
- 時として不適切な発言をする
- 文脈を考慮しない返答をする
DPOの仕組み(超シンプル版)
Step 1: 人間の好みデータを集める
# 好みデータの例
preferences = {
"良い回答": "体調管理は大切ですね。ゆっくり休んでください。",
"改善が必要な回答": "病院に行けば?",
}
Step 2: AIに学習させる
- 良い回答の確率を上げる
- 良くない回答の確率を下げる
- これだけ!
従来手法(RLHF)との違い
LLMの振る舞いを制御する従来手法のRLHF(Reinforcement Learning from Human Feedback)には、以下のような課題がありました:
- 3段階の複雑なプロセスが必要
- 教師ありファインチューニング
- 報酬モデルの学習
- 強化学習による最適化
- 多大な計算資源と時間が必要
- 報酬モデルの学習が不安定になりやすい
- 期待通りの結果が得られないことがある
これらの課題を解決するため、よりシンプルで効率的な手法としてDPOが登場しました。DPOは報酬モデルを使用せず、直接人間の好みを学習できる画期的な手法です。
シンプルな表で比較してみましょう:
項目 | RLHF | DPO |
---|---|---|
実装の複雑さ | 複雑 | シンプル |
計算コスト | 高い | 低い |
学習の安定性 | やや不安定 | 安定 |
セットアップ時間 | 長い | 短い |
DPOに必要なデータセットは?
DPOでAIを賢く育てるには、「いい回答」と「よくない回答」のセットが必要です。
データセットの基本構造
このデータセットは、以下の3つの要素で構成されます:
- 質問文(prompt):AIに与える質問やお題
- いい回答(chosen):人間が「これはいいね!」と評価した回答
- よくない回答(rejected):人間が「これはちょっと...」と評価した回答
具体例で見てみよう
# データセットの例
dataset = {
"質問(prompt)": "体調が悪いときはどうすればいいですか?",
"いい回答(chosen)": "まずは十分な休息を取り、必要に応じて医療機関に相談することをお勧めします。",
"よくない回答(rejected)": "頑張って仕事を続けましょう。"
}
実際のデータセット例
現在、以下のような公開データセットがあります:
このように、DPOでは人間の判断基準をAIに教えるためのデータが重要です。質の高いデータセットを用意することで、より賢いAIを育てることができます。
DPOの実装例
TRLライブラリで簡単実装
Hugging Faceが提供するTRLライブラリを使えば、数行のコードでDPOが実装できます。
from trl import DPOTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer
# モデルとトークナイザーの準備
model = AutoModelForCausalLM.from_pretrained("llama2-7b-japanese")
tokenizer = AutoTokenizer.from_pretrained("llama2-7b-japanese")
# トレーニングデータの例
training_data = {
"prompt": ["今日の調子はどうですか?"],
"chosen": ["お気遣いありがとうございます。元気です。"],
"rejected": ["別に。"]
}
# DPOトレーナーの設定
trainer = DPOTrainer(
model=model,
tokenizer=tokenizer,
beta=0.1, # モデルの変化量を調整(小さいほど控えめな変化)
)
# 学習実行
trainer.train()
# モデルの保存
trainer.save_model("my_dpo_model")
実装時のポイント
項目 | 説明 |
---|---|
モデル選択 | 日本語対応のベースモデルを選ぶ |
データ準備 | 良い回答と悪い回答のペアを用意 |
学習率 | 小さめの値(例:1e-5)から始める |
バッチサイズ | GPUメモリに応じて調整(4-8程度) |
このように、DPOの実装は思ったより簡単です。基本的には「良い応答」の確率を上げ、「悪い応答」の確率を下げる、というシンプルな考え方で実装できます。
まとめ
- DPOは人間の好みを直接AIに反映できる
- 従来手法(RLHF)より実装が簡単
- データの質が成功の鍵
DPOについては今更感がありますが、整理するために良い機会になりました。DPO以外にもKTOやCPOといった派生手法もあります。
関連情報
Discussion