🐕

huggingfaceのTRLについて

2023/10/22に公開

そもそもTRLとは?

  • モデルをファインチューニングするための一連のツールを提供するライブラリ

TRLの手法

Trainer系

SFTTrainer(教師ありファインチューニング):


from transformers import AutoModelForCausalLM
from datasets import load_dataset
from trl import SFTTrainer

dataset = load_dataset("imdb", split="train")
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")

trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=512,
)

trainer.train()

RewardTrainer(人間のフィードバックからの強化学習):


from peft import LoraConfig, task_type
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from trl import RewardTrainer, RewardConfig

model = AutoModelForSequenceClassification.from_pretrained("gpt2")
peft_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    inference_mode=False,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
)

trainer = RewardTrainer(
    model=model,
    args=training_args,
    tokenizer=tokenizer,
    train_dataset=dataset,
    peft_config=peft_config,
)

trainer.train()

PPOTrainer:

from trl import PPOTrainer

ppo_trainer = PPOTrainer(
    model=model,
    config=config,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
)

DPOTrainer:

dpo_trainer = DPOTrainer(
    model,
    model_ref,
    args=training_args,
    beta=0.1,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
)
 
dpo_trainer.train()

よくわからない系

TextEnvironment:

Best-of-N Sampling:

  • アクティブなモデルから予測をサンプリングする別の方法として、Best of Nサンプリングを使えます。

参考サイト

Discussion