特化型llm(Doujinshi-1.8b)の開発報告書⑤:trlライブラリを用いたSFT(教師ありファインチューニング)
はじめに
沼津高専のpuwaerです。この度、R18に特化した大規模言語モデル(LLM)、Doujinshi-1.8bを開発しました。
この記事では、trlライブラリを用いたSFT(教師ありファインチューニング)の作成方法をプログラムを交えて解説します。
本記事では、trlライブラリを用いたSFT(教師ありファインチューニング)の手法について詳しく解説します。
また、以下のgithubリポジトリを参考に開発を進めました。 なお、trlライブラリの一部の引数が変更されており、バージョンが合わず動作しない問題が発生したため、修正した内容を紹介します。
- github: llm-jp-sft
下の流れでする合成データの具体的な手法について解説します。
1.使用機器
2.環境構築
3.プログラムの詳細
4.テータセットの形式
5.学習の実行
6.パラメータの解説
本記事で利用するプログラムは以下のGitHubリポジトリで管理しています。
このプログラムで作成したモデル
- huggingface: Doujinshi-1.8b-instruct
1.使用機器
以下のようなpcを使い開発しました。
高専4年生でまだ研究室配属されていないため教授に頼みpcを貸してもらいました。
また、個人のpcはdeepspeedを用いたllmの継続事前学習の検証用にramが増設されています。
機器 | CPU | RAM | GPU | VRAM |
---|---|---|---|---|
研究室PC | Intel i7-12700 | 64GB | NVIDIA RTX A6000 | 48GB |
個人PC | Ryzen 7 5700X | 96GB | NVIDIA GeForce RTX 3060 | 12GB |
2. 環境構築
anaconda環境の作成と有効化
conda create -n sft python=3.11 -y
conda activate sft
必要なパッケージのインストール
cd document/sft
pip install -r requirements.in
環境変数の設定
{path/to/your/miniconda3}
または {path/to/your/anaconda3}
を実際のパスに置き換えてください。
export LD_LIBRARY_PATH={path/to/your/miniconda3}/envs/sft/lib/python3.11/site-packages/nvidia/nvjitlink/lib:$LD_LIBRARY_PATH
export LD_LIBRARY_PATH={path/to/your/anaconda3}/envs/sft/lib/python3.11/site-packages/nvidia/nvjitlink/lib:$LD_LIBRARY_PATH
追加パッケージのインストール
pip install flash-attn --no-build-isolation
pip install --upgrade accelerate
pip install datasets
3. sft用の学習コード
trlライブラリを用いた学習用コードです。
src/train_chat.py
を示す
import logging
from dataclasses import dataclass
from typing import Optional, List
import torch
import wandb # W&B をインポート
from peft import LoraConfig
from datasets import disable_caching, load_dataset, concatenate_datasets
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
TrainingArguments,
HfArgumentParser,
BitsAndBytesConfig,
)
from trl import SFTTrainer
disable_caching()
logger = logging.getLogger(__name__)
@dataclass
class SFTTrainingArguments:
model_name_or_path: str
data_files: List[str]
eval_data_files: Optional[List[str]] = None
tokenizer_name_or_path: Optional[str] = None
use_fast: bool = True
additional_special_tokens: Optional[List[str]] = None
max_seq_length: int = 4096
load_in_8bit: bool = False
load_in_4bit: bool = False
use_flash_attention_2: bool = False
use_peft: bool = False
peft_target_model: Optional[str] = "llm-jp"
peft_target_modules: Optional[List[str]] = None
peft_lora_r: int = 8
peft_lora_alpha: int = 32
peft_lora_dropout: float = 0.05
# W&B 関連の引数を追加
wandb_project: Optional[str] = None
wandb_run_name: Optional[str] = None
wandb_log_steps: Optional[int] = None
def __post_init__(self):
if self.load_in_8bit and self.load_in_4bit:
raise ValueError("load_in_8bit and load_in_4bit are mutually exclusive")
if self.peft_target_model and self.peft_target_modules is None:
if self.peft_target_model == "llm-jp":
self.peft_target_modules = ["c_attn", "c_proj", "c_fc"]
elif self.peft_target_model == "llama":
self.peft_target_modules = [
"q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"
]
elif self.peft_target_model == "llama-all":
self.peft_target_modules = [
"q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj", "lm_head", "embed_tokens"
]
else:
logger.warning(
f"peft_target_model '{self.peft_target_model}' is not supported, "
f"so peft_target_modules is set to None."
)
def from_pretrained_kwargs(self, training_args):
if self.load_in_8bit:
kwargs = {"load_in_8bit": True}
elif self.load_in_4bit:
kwargs = {
"quantization_config": BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
)
}
elif training_args.bf16:
kwargs = {"torch_dtype": torch.bfloat16}
else:
kwargs = {"torch_dtype": torch.float16}
kwargs["use_flash_attention_2"] = self.use_flash_attention_2
return kwargs
def load_datasets(data_files, tokenizer, max_seq_length=2048):
datasets = []
for data_file in data_files:
dataset = load_dataset("json", data_files=data_file)
dataset = dataset["train"]
def tokenize_function(example):
tokenized = tokenizer.apply_chat_template(
example["messages"],
tokenize=True,
add_generation_prompt=False,
truncation=True,
max_length=max_seq_length,
return_tensors="pt",
return_dict=True
)
return {
"input_ids": tokenized["input_ids"].squeeze(0).tolist(),
"attention_mask": tokenized["attention_mask"].squeeze(0).tolist()
}
dataset = dataset.map(tokenize_function, remove_columns=dataset.column_names)
datasets.append(dataset)
return concatenate_datasets(datasets)
def main():
parser = HfArgumentParser((TrainingArguments, SFTTrainingArguments))
training_args, sft_training_args = parser.parse_args_into_dataclasses()
# W&B の設定を適用
if sft_training_args.wandb_project:
training_args.report_to = ["wandb"] # W&B にログを送信
wandb.init(
project=sft_training_args.wandb_project,
name=sft_training_args.wandb_run_name,
config=vars(training_args), # TrainingArguments を W&B に記録
)
if sft_training_args.wandb_log_steps:
training_args.logging_steps = sft_training_args.wandb_log_steps
tokenizer_name_or_path = (
sft_training_args.tokenizer_name_or_path or sft_training_args.model_name_or_path
)
logger.info(f"Loading tokenizer from {tokenizer_name_or_path}")
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_name_or_path,
use_fast=sft_training_args.use_fast,
additional_special_tokens=sft_training_args.additional_special_tokens,
trust_remote_code=True,
)
chat_template = (
"{{bos_token}}{% for message in messages %}"
"{% if message['role'] == 'user' %}{{ '\\n\\n### 指示:\\n' + message['content'] }}"
"{% elif message['role'] == 'system' %}{{ '以下は、タスクを説明する指示です。要求を適切に満たす応答を書きなさい。' }}"
"{% elif message['role'] == 'assistant' %}{{ '\\n\\n### 応答:\\n' + message['content'] + eos_token }}"
"{% endif %}"
"{% if loop.last and add_generation_prompt %}{{ '\\n\\n### 応答:\\n' }}{% endif %}"
"{% endfor %}"
)
tokenizer.chat_template = chat_template
logger.info("Custom chat template applied to tokenizer")
if tokenizer.bos_token is None:
tokenizer.bos_token = "<|begin_of_text|>"
logger.info(f"Set default bos_token: {tokenizer.bos_token}")
if tokenizer.eos_token is None:
tokenizer.eos_token = "<|end_of_text|>"
logger.info(f"Set default eos_token: {tokenizer.eos_token}")
logger.info("Loading data")
train_dataset = load_datasets(sft_training_args.data_files, tokenizer, sft_training_args.max_seq_length)
if sft_training_args.eval_data_files:
eval_dataset = load_datasets(sft_training_args.eval_data_files, tokenizer, sft_training_args.max_seq_length)
training_args.do_eval = True
else:
eval_dataset = None
logger.info(f"Loading model from {sft_training_args.model_name_or_path}")
kwargs = sft_training_args.from_pretrained_kwargs(training_args)
model = AutoModelForCausalLM.from_pretrained(
sft_training_args.model_name_or_path,
trust_remote_code=True,
**kwargs,
)
peft_config = None
if sft_training_args.use_peft:
logger.info("Setting up LoRA")
peft_config = LoraConfig(
r=sft_training_args.peft_lora_r,
target_modules=sft_training_args.peft_target_modules,
lora_alpha=sft_training_args.peft_lora_alpha,
lora_dropout=sft_training_args.peft_lora_dropout,
fan_in_fan_out=True,
bias="none",
task_type="CAUSAL_LM",
)
if training_args.gradient_checkpointing:
for param in model.parameters():
param.requires_grad = False
if param.ndim == 1:
param.data = param.data.to(torch.float32)
model.gradient_checkpointing_enable()
model.enable_input_require_grads()
logger.info("Setting up trainer")
trainer = SFTTrainer(
model=model,
args=training_args,
processing_class=tokenizer,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
peft_config=peft_config,
)
logger.info("Training")
trainer.train()
logger.info("Saving model")
trainer.save_model()
# W&B を終了
if sft_training_args.wandb_project:
wandb.finish()
if __name__ == "__main__":
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s %(name)s:%(lineno)d: %(levelname)s: %(message)s",
)
main()
4. データセットの形式
トレーニングデータは、JSONLファイルの形式で提供されます。各行は、以下のような構造のJSONオブジェクトです:
{
"ID": "magpie_sft_v1.0-0",
"messages": [
{"role": "system", "content": "以下は、タスクを説明する指示です。要求を適切に満たす応答を書きなさい。"},
{"role": "user", "content": "日本の伝統的な文化や風習について、特に地域ごとの違いや特徴を知りたいと思っています。..."},
{"role": "assistant", "content": "もちろん、その2つの祭りの特徴について詳しく説明いたします。..."}
]
}
各メッセージには、「role」と「content」の2つのフィールドがあります。「role」は「system」、「user」、または「assistant」のいずれかです。「content」は、メッセージの内容です。
5. 学習の実行
実行コマンド例
学習を実行する際は、データセットとモデルを指定したパスに格納してください。ここでは、通常のsft(Supervised Fine-Tuning)と、chat_template
を使用したチャット形式のトレーニングの実行コマンドを紹介します。なお、この記事で示したPythonコードはchat_template
を使用したもので、src/train_chat.py
に記載されています。通常のSFT用のコードはsrc/train.py
にあります。
1. 通常のトレーニング
通常のSFTを実行する場合のコマンドです。
python src/train.py \
--num_train_epochs 1 \
--per_device_train_batch_size 8 \
--per_device_eval_batch_size 4 \
--gradient_accumulation_steps 8 \
--learning_rate 1e-5 \
--warmup_ratio 0.1 \
--lr_scheduler_type cosine \
--bf16 \
--max_seq_length 4096 \
--data_files ./data/sample_data.jsonl \
--model_name_or_path ./model/sample_model \
--output_dir results/
chat_template
を使ったチャットトレーニング
2. src/train_chat.py
では、以下のchat_template
が使用されています。このテンプレートにより、ユーザー、システム、アシスタントの役割に応じた会話形式のデータが生成されます。
chat_template = (
"{{bos_token}}{% for message in messages %}"
"{% if message['role'] == 'user' %}{{ '\n\n### 指示:\n' + message['content'] }}"
"{% elif message['role'] == 'system' %}{{ '以下は、タスクを説明する指示です。要求を適切に満たす応答を書きなさい。' }}"
"{% elif message['role'] == 'assistant' %}{{ '\n\n### 応答:\n' + message['content'] + eos_token }}"
"{% endif %}"
"{% if loop.last and add_generation_prompt %}{{ '\n\n### 応答:\n' }}{% endif %}"
"{% endfor %}"
)
実行コマンド:
python src/train_chat.py \
--num_train_epochs 1 \
--per_device_train_batch_size 8 \
--per_device_eval_batch_size 4 \
--gradient_accumulation_steps 8 \
--learning_rate 5e-5 \
--warmup_ratio 0.1 \
--lr_scheduler_type cosine \
--bf16 \
--data_files ./data/sample_data.jsonl \
--model_name_or_path ./model/sample_model \
--output_dir results/output_model/ \
--wandb_project "sample-sft" \
--wandb_run_name "test_1" \
--wandb_log_steps 10
6.パラメータの解説
基本的なパラメータの解説は以下の記事でしています。
sftでは以下のようなlearnig_rateを使用することが多い
以下のパラメータで指定する
設定項目 | 説明 |
---|---|
learning_rate | 学習率を5e-5 (0.00005)に設定し、モデルの更新ステップを細かく制御。 |
warmup_ratio | ウォームアップ期間を全体の10%(0.1 )に設定し、学習率を徐々に増加させる比率。 |
lr_scheduler_type |
cosine スケジューラを指定し、学習率をコサイン関数に基づいて徐々に減少させる方式。 |
use_peft | falseの場合はフルパラメータ学習を行い、trueの場合はLoRA学習を行います。 (デフォルトはfalse) |
おわりに
本記事では、trlライブラリを用いたSFT(教師ありファインチューニング)の手法を解説しました。
sftの記事はたくさんありますが、大体がlora学習なためフルパラメータ学習は珍しいと思います。また、小さいモデルだとchat_templateを使用するより精度が上がるのでおすすめです。
また、分からないことがありましたら、気軽にTwitterのDMに質問してください。
開発支援のお願い
現在、開発を続けていますが、クラウドGPUの価格が高く、十分な計算リソースを確保できずにいます。そのため、思い通りに開発が出来ていません。
また、オープンソースの理念を大切にしており、プログラム・データセット・モデルを有料で公開するつもりはありません。そのため、金銭的に余裕のある方に支援していただけると大変助かります。
TwitterのDMやご支援いただける方は、以下のプラットフォームよりお願いいたします。
Discussion