🦄
ローカルLLMの応答文を好みの出力に整えるDPOの学習
背景と目的
- LLMの応答文をユーザーの嗜好に沿った応答に整える手法が提案されています。
- LLMは従来ユーザーの入力に沿った応答を返すように微調整(SFT,Supervised Fine Tuning)されていますが、しばしば応答が淡白であったり、品質の低い応答を返すなど、ユーザーの嗜好に沿わない応答文を生成することがあります。
- この問題に対して、従来はOpenAIが報告したRLHFのように報酬モデルを学習し、好みを実数で表現する関数を獲得した後、この好みを最大化するような応答が得られるように強化学習することが必要でした。
- しかし、上記のDPOという手法では、RLHFの定式化から、式変形と学習後の報酬モデルに、ある制約項を追加することで、報酬モデルなしに好みの応答文を学習できることを理論的に示しました。
- その効果を実際に実装して確かめてみましょう。
動作環境
検証機 | GPU | OS |
---|---|---|
Paperspaceのインスタンスを利用 | NVIDIA RTX A6000 48GB | Ubuntu22.04 |
torch==2.1.2+cu123
trl==0.7.4
transformers==4.37.2
peft==0.6.2
datasets==2.15.0
bitsandbytes==0.42.0
wandb==0.16.3
学習モデル
llm-jp/llm-jp-1.3b-v1.0
のフルパラメータチューニングを例として使用します。
学習データ
shi3z/anthropic_hh_rlhf_japanese
- 簡単化のため、160,000件中80,000件を無作為に抽出して学習します。
学習時情報
学習時間 | VRAM使用量 |
---|---|
3hours49min | 35165MiB |
ソースコード
指定したモデルをDPOでフルパラメータチューニングするクラスDPOTrainerKit
を実装しました。学習が終了した際にはwandb
からメールで通知が届くようになっています。
DPOTrainerKitクラス(263行)
import json
import wandb
import torch
from typing import Dict
from datasets import Dataset, load_dataset
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer
from transformers import TrainingArguments
from trl import DPOTrainer
class DPOTrainerKit(object):
def __init__(
self,
project_name: str="llm_jp_1b",
model_name: str="llm-jp/llm-jp-1.3b-v1.0",
output_model_dir:str="./out/dpo_model",
output_dir: str="./out/dpo_logs",
fp16: bool=False,
bf16: bool=True,
num_train_epochs: int=1,
dataloader_num_workers: int=4,
save_total_limit: int=1,
push_to_hub: bool=False,
auto_find_batch_size: bool=False,
per_device_train_batch_size: int=8,
gradient_accumulation_steps: int=4,
optim: str="adamw_torch",
# optim: str="paged_adamw_32bit",
learning_rate: float=5e-4,
lr_scheduler_type: str="cosine",
max_grad_norm: float=0.3,
warmup_ratio:float=0.03,
weight_decay:float=0.001,
save_steps:int=50,
logging_steps:int=50,
report_to:str="wandb",
beta:float = 0.1,
max_length:int=300,
max_prompt_length:int=300
):
wandb_config = {
"model_name": project_name,
"epoch": num_train_epochs,
"optim": optim,
"learning_rate": learning_rate,
"scheduler": lr_scheduler_type,
"max_outout_length": max_length,
"max_prompt_length": max_prompt_length
}
wandb.login(ご自身のログイン用のAPI keyに差し替えてください)
wandb.init(
project=f"{project_name}_dpo",
name=project_name,
config=wandb_config
)
self.output_dir = output_dir
self.output_model_dir = output_model_dir
self.fp16 = fp16
self.bf16 = bf16
self.num_train_epochs = num_train_epochs
self.dataloader_num_workers = dataloader_num_workers
self.save_total_limit = save_total_limit
self.push_to_hub = push_to_hub
self.auto_find_batch_size = auto_find_batch_size
self.per_device_train_batch_size = per_device_train_batch_size
self.gradient_accumulation_steps = gradient_accumulation_steps
self.optim = optim
self.learning_rate = learning_rate
self.lr_scheduler_type = lr_scheduler_type
self.max_grad_norm = max_grad_norm
self.warmup_ratio = warmup_ratio
self.weight_decay = weight_decay
self.save_steps = save_steps
self.logging_steps = logging_steps
self.report_to = report_to
self.beta = beta
self.max_length = max_length
self.max_prompt_length = max_prompt_length
self.model_name = model_name
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
load_in_8bit=False,
load_in_4bit=False,
low_cpu_mem_usage=True,
device_map={"": 0} # モデル全体をGPU0にロード
)
self.model.config.use_cache = False # キャッシュ (学習時はFalse)
self.model.config.pretraining_tp = 1 # 事前学習で使用したテンソル並列ランク
self.model_ref = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
load_in_8bit=False,
load_in_4bit=False,
low_cpu_mem_usage=True,
device_map={"": 0} # モデル全体をGPU0にロード
)
self.model_ref.config.pretraining_tp = 1 # 事前学習で使用したテンソル並列ランク
self.train_dataset, self.eval_dataset = build_dataset()
self.tokenizer = AutoTokenizer.from_pretrained(
model_name,
use_fast=False, # Fastトークナイザーの有効化
add_eos_token=True, # データへのEOSの追加を指示
trust_remote_code=True
)
self.tokenizer.pad_token = self.tokenizer.unk_token
self.tokenizer.padding_side = "right" # fp16でのオーバーフロー問題対策
return
def notify(self):
message = {
"model_name": self.model_name,
"output_dir": self.output_dir,
"bf16": self.bf16,
"num_train_epochs": self.num_train_epochs,
"dataloader_num_workers": self.dataloader_num_workers,
"save_total_limit": self.save_total_limit,
"push_to_hub": self.push_to_hub,
"auto_find_batch_size": self.auto_find_batch_size,
"per_device_train_batch_size": self.per_device_train_batch_size,
"gradient_accumulation_steps": self.gradient_accumulation_steps,
"optim": self.optim,
"learning_rate": self.learning_rate,
"lr_scheduler_type": self.lr_scheduler_type,
"max_grad_norm": self.max_grad_norm,
"warmup_ratio": self.warmup_ratio,
"weight_decay": self.weight_decay,
"save_steps": self.save_steps,
"logging_steps": self.logging_steps,
"report_to": self.report_to
}
message = json.dumps(message, indent=4)
wandb.alert(
title='学習が完了しました🎉',
text=f"{message}",
)
wandb.finish()
return
def run(self):
training_args = TrainingArguments(
output_dir=self.output_dir,
fp16=self.fp16,
bf16=self.bf16,
num_train_epochs=self.num_train_epochs,
dataloader_num_workers=self.dataloader_num_workers,
save_total_limit=self.save_total_limit,
push_to_hub=self.push_to_hub,
auto_find_batch_size=self.auto_find_batch_size,
# max_steps=1000, # 学習ステップ数
per_device_train_batch_size=self.per_device_train_batch_size,
gradient_accumulation_steps=self.gradient_accumulation_steps,
optim=self.optim,
learning_rate=self.learning_rate,
lr_scheduler_type=self.lr_scheduler_type,
max_grad_norm=self.max_grad_norm, # 最大法線勾配 (勾配クリッピング)
warmup_ratio=self.warmup_ratio, # 線形ウォームアップのステップ比率 (0から学習率まで)
weight_decay=self.weight_decay, # bias/LayerNormウェイトを除く全レイヤーに適用するウェイト減衰
save_steps=self.save_steps, # 何ステップ毎にチェックポイントを保存するか
logging_steps=self.logging_steps, # 何ステップ毎にログを記録するか
report_to=self.report_to # レポート
)
dpo_trainer = DPOTrainer(
self.model,
self.model_ref,
args=training_args,
beta=self.beta,
max_length=self.max_length,
max_prompt_length=self.max_prompt_length,
train_dataset=self.train_dataset,
eval_dataset=self.eval_dataset,
tokenizer=self.tokenizer,
loss_type='sigmoid' # DPO'sigmoid'/SLiC'hinge'/IPO'ipo'/HALOs'kto'
)
dpo_trainer.train()
dpo_trainer.model.save_pretrained(self.output_model_dir)
self.notify()
return
def build_dataset():
train_dataset = get_hh("train", sanity_check=True)
eval_dataset = get_hh("test", sanity_check=True)
print(train_dataset)
print(eval_dataset)
print("--prompt--\n", train_dataset[2]["prompt"])
print("--chosen--\n", train_dataset[2]["chosen"])
print("--rejected--\n", train_dataset[2]["rejected"])
return train_dataset, eval_dataset
def get_hh(
split: str,
sanity_check: bool=False,
silent: bool=False,
cache_dir: str=None
) -> Dataset:
dataset = load_dataset(
"shi3z/anthropic_hh_rlhf_japanese",
"train",
cache_dir
)
dataset = dataset["train"].train_test_split(
test_size=0.025,
shuffle=False
)[split]
if sanity_check:
dataset = dataset.select(range(min(len(dataset), 80000)))
def split_prompt_and_responses(sample) -> Dict[str, str]:
return extract_anthropic_prompt(sample)
return dataset.map(split_prompt_and_responses)
def extract_anthropic_prompt(sample):
text0 = sample["chosen"]
text0 = text0.replace("\\n\\n人間:", "User: ")
text0 = text0.replace("\\n\\nAssistant:", "\n\nAssistant: ")
text0 += "<EOD|LLM-jp>"
text1 = sample["rejected"]
text1 = text1.replace("\\n\\n人間:", "User: ")
text1 = text1.replace("\\n\\nAssistant:", "\n\nAssistant: ")
text1 += "<EOD|LLM-jp>"
search_term = "\n\nAssistant: "
search_term_idx0 = text0.rfind(search_term)
search_term_idx1 = text1.rfind(search_term)
return {
"prompt": text0[: search_term_idx0 + len(search_term)],
"chosen": text0[search_term_idx0 + len(search_term):],
"rejected": text1[search_term_idx1 + len(search_term):],
}
if __name__ == "__main__":
kit = DPOTrainerKit()
kit.run()
所感
インストラクションチューニングなしのモデルに対するDPOなので、応答結果は支離滅裂なものとなりました。応答例は割愛させていただきます。
また、今回密かに、インストラクションチューニングなしにいきなりDPOをすればLLMのチューニングは事足りるのではないかという仮説を隠し持っていたのですが、どうやらインストラクションチューニングは必要であることが経験的にですが分かりました。
本来であれば、インストラクションチューニングしたモデルに対して、DPOTrainerKitで追加学習すると嗜好性のある応答が確認できると思います。
DPOは参照モデルこそ必要ではありますが、RLHFの学習を短縮する非常に優秀な学習手法です。例えば、自身の好みの応答を嗜好されるデータ(Huggingface datasetで言うchosen、つまりペルソナを持った特定の語尾やキャラクター性を持った応答)、通常のモデルの応答を嗜好されないデータ(Huggingface datasetで言うrejected)とすることで、キャラクター一貫性を持った応答文の生成が可能になると考えられます。
AITuberやキャラクターの応答文生成を検討されている皆様は、ぜひご検討していただけると面白い結果が得られるのではないかと思います。
参考
ソースコード作成にあたり、以下の記事を参考にいたしました。ありがとうございました。
Discussion