🐕
huggingfaceのTRLについて
そもそもTRLとは?
- モデルをファインチューニングするための一連のツールを提供するライブラリ
TRLの手法
Trainer系
SFTTrainer(教師ありファインチューニング):
- SFTTrainerを使って、言語モデルやアダプターをカスタムデータセットで簡単にファインチューニングできます。
from transformers import AutoModelForCausalLM
from datasets import load_dataset
from trl import SFTTrainer
dataset = load_dataset("imdb", split="train")
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
trainer = SFTTrainer(
model=model,
train_dataset=dataset,
dataset_text_field="text",
max_seq_length=512,
)
trainer.train()
RewardTrainer(人間のフィードバックからの強化学習):
- RewardTrainerを使って、人間の好みに基づいた報酬モデルを簡単に訓練できます。
from peft import LoraConfig, task_type
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from trl import RewardTrainer, RewardConfig
model = AutoModelForSequenceClassification.from_pretrained("gpt2")
peft_config = LoraConfig(
task_type=TaskType.SEQ_CLS,
inference_mode=False,
r=8,
lora_alpha=32,
lora_dropout=0.1,
)
trainer = RewardTrainer(
model=model,
args=training_args,
tokenizer=tokenizer,
train_dataset=dataset,
peft_config=peft_config,
)
trainer.train()
PPOTrainer:
- PPOTrainerを使って、PPOアルゴリズムを用いてファインチューニングした言語モデルをさらに最適化できます
* https://huggingface.co/docs/trl/ppo_trainer
from trl import PPOTrainer
ppo_trainer = PPOTrainer(
model=model,
config=config,
train_dataset=train_dataset,
tokenizer=tokenizer,
)
DPOTrainer:
- DPOTrainerを使って、直接嗜好最適化(DPO)で訓練できます。(まだよくわかっていません)
* https://huggingface.co/blog/dpo-trl
* https://huggingface.co/docs/trl/dpo_trainer
dpo_trainer = DPOTrainer(
model,
model_ref,
args=training_args,
beta=0.1,
train_dataset=train_dataset,
tokenizer=tokenizer,
)
dpo_trainer.train()
よくわからない系
TextEnvironment:
- TextEnvironmentを使って、RLのツールとともにモデルを訓練できます。
- LangChainnのAgentのようなもの?
- https://huggingface.co/docs/trl/text_environments
Best-of-N Sampling:
- アクティブなモデルから予測をサンプリングする別の方法として、Best of Nサンプリングを使えます。
- LLMをfine-turningせず、複数の出力を吐き出させる方法?
参考サイト
- 公式サイト
- RLHF を使用して LLaMA をトレーニングするための実践ガイド
Discussion