LLMを強化学習: 進化が早すぎる!RLHFライブラリtrlの変わること変わらないこと
この記事はCyberAgent AI Lab Advent Calendar 202420日目の記事です。
はじめに
注意
このブログ執筆時点のtrlのversionはv0.12.2です.
自己紹介
初めまして,サイバーエージェントAI Lab強化学習チームでリサーチエンジニアをしている坂本です.弊チームでは現在,LLMを人間の評価から強化学習するRLHF(Reinforcement Learning with Human Feedback [Ouyang+, 2022])の研究に取り組んでいます.私はリサーチエンジニアとして,主に実装や実験環境の整備を担当しています.今回のブログでは,毎日10本近くの新しい論文が登場するこの分野で,私たちを助けてくれる,強力だけど癖のあるライブラリ「trl」 の活用方法をご紹介したいと思います.
この記事ではRLHFの手法についての詳しい説明は行いません
今,強化学習がアツい
12月20日現在,世界中の人たちが注目する「12 Days of OpenAI」では,Day2にて「OpenAI's Reinforcement Fine-Tuning Research Program」が発表されました.これは,LLMを対象にした強化学習の応用と研究をさらに加速させる取り組みです.
なぜLLMに強化学習が必要なのか?その理由は,LLMが単に膨大なテキストデータを学習するだけでは,必ずしも人間にとって望ましい応答や動作を保証できないからです.強化学習は,モデルが人間のフィードバックや報酬に基づいて振る舞いを改善するための強力な手法として注目されています.たとえば,ChatGPTやGenimiも,RLHFを用いることで,ユーザーフレンドリーな応答生成能力を獲得しました.
このような理由から世界のBig Techや著名な研究機関ではRLHFの研究が盛んに行われており,毎日10本近くの論文が新たに投稿されています.
RLHFの概略
RLHFのプロセスは,大きく以下の3つのステップに分けられます:
-
SFT(Supervised Fine-Tuning)
LLMの事前学習モデルに対して,手動でラベル付けされたデータを用いて教師あり学習を行います.これにより,LLMにChatの形式に従って応答することを学習させます.
-
報酬モデル(Reward Model, RM)の学習
モデルが生成する複数の応答候補に対して,人間が「どれが優れているか」を評価し,優劣関係(ランキング)をデータとして収集します.人間が好んだ応答を,”chosen response”や”win response”,人間が好まなかったものを”rejected response”, “lose response”と呼びます.このアノテーションデータを用いて,応答の良し悪しをスコア化する報酬モデルを構築します.人間が好んだ応答の報酬を人間が好まなかった応答よりも高く出力するようにモデルを学習させます.
-
強化学習(Reinforcement Learning)
報酬モデルを利用して,LLMを強化学習で最適化します.このステップでは,PPO(Proximal Policy Optimization [Schulman+, 2017])などのアルゴリズムがよく使われます.報酬を最大化するような応答を学習させることで,LLMを人間の好みにチューニングします.
また,最近注目されている手法に,DPO(Direct Preference Optimization [Rafailov+, 2023]) があります.DPOは,報酬モデルと強化学習のプロセスを統合し,より簡単にモデルを最適化する新しいアプローチです.
trl (Transformers Reinforment Learing)
trl
は,RLHFやその派生手法のためのライブラリで,以下の特徴を持っています:
- RLHFに必要な主要な手法(PPO, DPO, SFTなど)が実装されている
- Hugging FaceのTransformersとの統合が進んでおり,事前学習済みモデルを簡単に活用できる
- 学習プロセスに必要なツールが揃っているため,試行錯誤が容易
一方で分野の進展に合わせてその進化のスピードは凄まじく,1ヶ月もすればライブラリが更新され,仕様が変更されます.(このブログを書いている最中にもtrlのupdateがあり筆者は涙目になりました)
このライブラリの早い変化に追従するためには,trlライブラリで変わること,変わらないことを認識することがとても重要です.このブログでは,そんな変わること,変わらないこと(変わりにくいこと)を紹介します.trlのexampleを眺めつつ,このブログを参照していただけると嬉しいです.
変わらないこと(変わりにくいこと)
LLMはChatを知らない
LLMは基本的に「次のトークンを予測する確率モデル」として動作します.つまり、文字列をトークンに変換し,それまでのトークン列を基に次のトークンを予測・生成する仕組みです.このモデルには,文字列がどのように構造化されているかを理解する能力は備わっていないため,特に「ユーザーからの入力(プロンプト)」と「それに対する応答」を明示的に区別する必要があります.
たとえば、以下のような生の文字列がある場合を考えてみましょう:
強化学習とは何ですか? 強化学習は、エージェントが報酬を最大化する行動を学習する方法です。
この文字列だけでは、「どちらが,どこまでがユーザーの発言で、どちらが,どこまでがモデルの応答なのか」をLLMが正確に理解するのは難しいです.特に,ユーザーの入力とモデルの応答を明示的に区別するための情報が不足しています.
そこで、次のように「発話者」を明示する形式に変換します:
User: 強化学習とは何ですか?
Assistant: 強化学習は、エージェントが報酬を最大化する行動を学習する方法です。
この形式に変換することで,LLMは「User」が質問し,「Assistant」がそれに応答していることを認識できるようになります.この変換によって,モデルはプロンプト(ユーザーの入力)と応答を正しく区別し,より適切な学習や推論を行うことが可能になります.
LLMはそれぞれのChat形式を持つ
現在公開されているLLMの多くにはこの会話の形式が予め設定されています.言語モデルはモデルとセットで文字列をトークンに変換するTokenizerを持ちます.このTokenizerが持つapply_chat_template
を使うことで,そのLLMに設定された会話形式に変換することができます.
では,apply_chat_temple
を使うためにはどうすれば良いのでしょうか?その前にも当然前処理が必要です.具体的には,”role”として,誰の発話なのか,”content”として,それがどんな内容かという「会話の形式」に変換します.
具体例
messages = [
{"role": "system", "content": "強化学習とは何ですか?"},
{"role": "user", "content": "強化学習は、エージェントが報酬を最大化する行動を学習する方法です。"}
]
このような形式に変換した後で,Tokenizerのapply_chat_template
を適用します.
例えばCALM3は以下のような形式に変換してくれます.
<|im_start|>system
強化学習とは何ですか?<|im_end|>
<|im_start|>user
強化学習は、エージェントが報酬を最大化する行動を学習する方法です。<|im_end|>
<|im_start|>assistant
このように「会話の形式」をそのモデルに設定された形式に変換できます.
Trainerごとのデータセット形式
学習を開始する前にデータセットは,TypesとFormatを手法に合わせる必要があります.trlのドキュメントには手法ごとに必要なフォーマットがまとめれています.
例えばTypesは手法ごとに異なり,SFTでは,Language modeling (指示と応答) が,報酬モデル(Reward Model, RM)やDPOの学習では,Preference(指示と返答が2つ,人間がどちらを好んだかのラベル)が,PPOではPrompt-only(指示のみ) が必要です.詳しくWhich dataset type to use?を確認してください.
またFormatは「会話の形式(Conversational format)」のものもしくは「apply_chat_template
を適用したStandard format」である必要があります.例えばpreferenceでは以下のようになります.
# Standard format
## Explicit prompt (recommended)
preference_example = {"prompt": "The sky is", "chosen": " blue.", "rejected": " green."}
# Implicit prompt
preference_example = {"chosen": "The sky is blue.", "rejected": "The sky is green."}
# Conversational format
## Explicit prompt (recommended)
preference_example = {"prompt": [{"role": "user", "content": "What color is the sky?"}],
"chosen": [{"role": "assistant", "content": "It is blue."}],
"rejected": [{"role": "assistant", "content": "It is green."}]}
## Implicit prompt
preference_example = {"chosen": [{"role": "user", "content": "What color is the sky?"},
{"role": "assistant", "content": "It is blue."}],
"rejected": [{"role": "user", "content": "What color is the sky?"},
{"role": "assistant", "content": "It is green."}]}
Alpaca Farm
具体例:RLHFのベンチマークとしてよく用いられるデータセットがAlpaca Farmです.Alpaca-farmのデータセットはLanguage modeling Typeのものが次の形式で,
Preference Typeのものが次の形式で与えられています.
Alpaca farmのデータセットをtrlで使用したい場合は,データセットの形式を変更する必要があります.Alpaca farmの中身を見ると,指示 (prompt)がInstructionとInputに分かれていることがわかります.そのためLanguage modeling Type場合は次のような処理が必要になります.この例ではAlpaca Farmデータセットを「会話の形式(Conversational format)」に変換しています.
prompt_format_dict = {
"prompt_noinputs": "{instruction}",
"prompt_inputs": "{instruction} \n{input}"
}
def _prompt_format(example, prompt_dict):
formatted_expample = ""
if example["input"] is None or len(example["input"]) == 0:
formatted_expample = prompt_dict["prompt_noinputs"].format_map(example)
else:
formatted_expample = prompt_dict["prompt_inputs"].format_map(example)
return formatted_expample
def alpaca_farm_prompt_response_formatting_function(examples):
examples_list = expand_dict_to_list_of_dicts(examples)
prompts = [_prompt_format(example, prompt_format_dict) for example in examples_list]
messages = [
[
{"role": "user", "content": prompt},
{"role": "assistant", "content": output},
]
for prompt, output in zip(prompts, examples["output"])
]
new_examples = {
"text": messages,
}
return new_examples
from datasets import load_dataset
dataset = load_dataset("tatsu-lab/alpaca_farm", name="alpaca_instructions", split="sft")
dataset = dataset.map(alpaca_farm_prompt_response_formatting_function, batched=True)
Preference Typeは次のように変換します.
prompt_format_dict = {
"prompt_noinputs": "{instruction}",
"prompt_inputs": "{instruction} \n{input}"
}
def _prompt_format(example, prompt_dict):
formatted_expample = ""
if example["input"] is None or len(example["input"]) == 0:
formatted_expample = prompt_dict["prompt_noinputs"].format_map(example)
else:
formatted_expample = prompt_dict["prompt_inputs"].format_map(example)
return formatted_expample
def alpaca_farm_pairwise_formatting_function(examples):
examples_list = expand_dict_to_list_of_dicts(examples)
prompts = [_prompt_format(example, prompt_format_dict) for example in examples_list]
chosen_messages = []
rejected_messages = []
for prompt, output_1, output_2, preference in zip(
prompts, examples["output_1"], examples["output_2"], examples["preference"]
):
chosen, rejected = (output_1, output_2) if preference == 1 else (output_2, output_1)
chosen_messages.append(
[
{"role": "user", "content": prompt},
{"role": "assistant", "content": chosen},
]
)
rejected_messages.append(
[
{"role": "user", "content": prompt},
{"role": "assistant", "content": rejected},
]
)
new_examples = {
"prompt": prompts,
"chosen": chosen_messages,
"rejected": rejected_messages,
}
return new_examples
from datasets import load_dataset
dataset = load_dataset("tatsu-lab/alpaca_farm", name="alpaca_human_preference", split="preference")
dataset = dataset.map(alpaca_farm_pairwise_formatting_function, batched=True)
このようによく使われるデータセットでも事前にapply_chat_template
を適用可能な形「会話の形式(Conversational format)」にフォーマットする必要があります.このような処理の必要性は今後もあまり変わることはないでしょう.
変わること
trl
の進化速度は非常に速く,よく破壊的な変更が行われます.実際に筆者がこのブログを執筆中にもv0.13.0がリリースされましたが,断腸の思いで無視しました.ではどんな部分がこの破壊的変更の影響を受けるのでしょうか?おそらくTrainerが受け取るデータセットのTypeが変わることはほとんどありません.手法自体が変わらない限り,その手法に必要なデータが変わることはありませんし,そんなことはほとんど起こらないでしょう.
一方で,受け取れるFormat,上記ではStandard formatやConversational formatと紹介した部分が変わる可能性が大いにあります.v0.12.2現在はtrlがStandardと呼ぶ形式に変換しておけば問題はありません.しかし,過去trlは手法によってこれに違いがありました.具体的には,SFTはtokenizeしてからデータセットをTrainerに渡す必要がある,もしくはそのkeyを指定しなければいけない一方で,DPOなどは現在と変わらないStandard形式でした.しかも,それに関するドキュメントにまとまっておらず,exampleから推測する必要がありました.さらにその形式は常に変わっており,今まで問題のなかった形式がversion upがあると途端にダメになって学習不能になることもよく起こります.
今はドキュメントがあるから・整備されているから安心かといえばそうではありません.trlのドキュメントには次のような記述があります.つまり「会話形式を受け付けない」とっているわけですね.
- TRL trainers only support standard dataset formats, for now. If you have a conversational dataset, you must first convert it into a standard format. For more information on how to work with conversational datasets, refer to the Working with conversational datasets in TRL section.
しかしDPOのソースコードを確認すると会話形式も対応していることが分かります.ドキュメントだけを信じず,常にコードをチェックすることが重要です.
終わりに
このブログでは,RLHFとそれを実現するためのライブラリtrl
について,変わること・変わらないことを解説しました。trl
は非常に便利で強力なツールですが,仕様変更の頻度が高く,最新の情報を追い続ける必要があります.特に公式ドキュメントとソースコードを両方確認する習慣をつけることで,予期せぬ問題を防ぐことができます.
RLHFの研究や実装に挑戦する中で、このブログが皆さんの参考になれば幸いです。また,これからもこの分野は急速に進化し続けるため,今後のアップデートにも注目していきましょう.
[付録]RLHFの評価
文章の生成
RLHFを用いた学習の評価は,単にlossが下がっているかだけでは不十分です.実際にモデルが生成する応答を評価する必要があります.モデルに文章を生成するときには,apply_chat_template
する際にadd_generation_prompt=True
を指定してここから先が返答だとわかりやすくする必要があります.
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
model = AutoModelForCausalLM.from_pretrained("cyberagent/calm3-22b-chat", device_map="auto", torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained("cyberagent/calm3-22b-chat")
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
messages = [
{"role": "system", "content": "あなたは親切なAIアシスタントです。"},
{"role": "user", "content": "AIによって私たちの暮らしはどのように変わりますか?"}
]
input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(model.device)
output_ids = model.generate(input_ids,
max_new_tokens=1024,
temperature=0.5,
streamer=streamer)
評価方法
生成した文章に対して以下のような評価方法がよく使われます:
-
GPT-4評価
他のオープンソースモデルと比較して性能を評価する手法です.OpenAIのGPT-4を用いて生成テキストを評価します.多くのベンチマークでこの方法が用いられます.ベンチマークとして有名なものにAlpaca Eval, MT benchがあります.
-
報酬モデルによる評価(研究ではおすすめ)
公開されている報酬モデルを用いて生成した文章をスコアリングします.この方法はAPIコストがかからず,再現性が高いのがメリットです.具体的な手順は,prompt-onlyのデータセットから応答文を生成し,その文章に対してreward modelでスコアを計算します.reward modelはreward benchからいい感じのモデルを選択しましょう.
-
人間評価
面倒ですが,これが一番です.面倒くさがらずに人手で学習したモデルを評価しましょう.
人間の評価ってどう集めるの?
RLHFには,人間の評価が欠かせません.その収集に役立つのが,Label Studio です.Label Studioは,オープンソースのアノテーションツールで,RLHF用のカスタマイズされたUIを簡単に構築できます.モデルの出力を比較し,どちらが優れているかを簡単にラベル付けすることが可能です.
Discussion