Zenn
📖

特化型llm(Doujinshi-1.8b)の開発報告書⑤:trlライブラリを用いたSFT(教師ありファインチューニング)

2025/03/19に公開

はじめに

沼津高専のpuwaerです。この度、R18に特化した大規模言語モデル(LLM)、Doujinshi-1.8bを開発しました。
この記事では、trlライブラリを用いたSFT(教師ありファインチューニング)の作成方法をプログラムを交えて解説します。

本記事では、trlライブラリを用いたSFT(教師ありファインチューニング)の手法について詳しく解説します。
また、以下のgithubリポジトリを参考に開発を進めました。 なお、trlライブラリの一部の引数が変更されており、バージョンが合わず動作しない問題が発生したため、修正した内容を紹介します。

下の流れでする合成データの具体的な手法について解説します。
1.使用機器
2.環境構築
3.プログラムの詳細
4.テータセットの形式
5.学習の実行
6.パラメータの解説

本記事で利用するプログラムは以下のGitHubリポジトリで管理しています。
https://github.com/puwaer/trl_sft

このプログラムで作成したモデル

1.使用機器

以下のようなpcを使い開発しました。
高専4年生でまだ研究室配属されていないため教授に頼みpcを貸してもらいました。
また、個人のpcはdeepspeedを用いたllmの継続事前学習の検証用にramが増設されています。

機器 CPU RAM GPU VRAM
研究室PC Intel i7-12700 64GB NVIDIA RTX A6000 48GB
個人PC Ryzen 7 5700X 96GB NVIDIA GeForce RTX 3060 12GB

2. 環境構築

anaconda環境の作成と有効化

conda create -n sft python=3.11 -y
conda activate sft

必要なパッケージのインストール

cd document/sft
pip install -r requirements.in

環境変数の設定

{path/to/your/miniconda3} または {path/to/your/anaconda3} を実際のパスに置き換えてください。

export LD_LIBRARY_PATH={path/to/your/miniconda3}/envs/sft/lib/python3.11/site-packages/nvidia/nvjitlink/lib:$LD_LIBRARY_PATH
export LD_LIBRARY_PATH={path/to/your/anaconda3}/envs/sft/lib/python3.11/site-packages/nvidia/nvjitlink/lib:$LD_LIBRARY_PATH

追加パッケージのインストール

pip install flash-attn --no-build-isolation
pip install --upgrade accelerate
pip install datasets

3. sft用の学習コード

trlライブラリを用いた学習用コードです。
src/train_chat.pyを示す

import logging
from dataclasses import dataclass
from typing import Optional, List

import torch
import wandb  # W&B をインポート
from peft import LoraConfig
from datasets import disable_caching, load_dataset, concatenate_datasets
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    HfArgumentParser,
    BitsAndBytesConfig,
)
from trl import SFTTrainer

disable_caching()

logger = logging.getLogger(__name__)

@dataclass
class SFTTrainingArguments:
    model_name_or_path: str
    data_files: List[str]
    eval_data_files: Optional[List[str]] = None
    tokenizer_name_or_path: Optional[str] = None
    use_fast: bool = True
    additional_special_tokens: Optional[List[str]] = None
    max_seq_length: int = 4096
    load_in_8bit: bool = False
    load_in_4bit: bool = False
    use_flash_attention_2: bool = False
    use_peft: bool = False
    peft_target_model: Optional[str] = "llm-jp"
    peft_target_modules: Optional[List[str]] = None
    peft_lora_r: int = 8
    peft_lora_alpha: int = 32
    peft_lora_dropout: float = 0.05
    # W&B 関連の引数を追加
    wandb_project: Optional[str] = None
    wandb_run_name: Optional[str] = None
    wandb_log_steps: Optional[int] = None

    def __post_init__(self):
        if self.load_in_8bit and self.load_in_4bit:
            raise ValueError("load_in_8bit and load_in_4bit are mutually exclusive")
        if self.peft_target_model and self.peft_target_modules is None:
            if self.peft_target_model == "llm-jp":
                self.peft_target_modules = ["c_attn", "c_proj", "c_fc"]
            elif self.peft_target_model == "llama":
                self.peft_target_modules = [
                    "q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"
                ]
            elif self.peft_target_model == "llama-all":
                self.peft_target_modules = [
                    "q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj", "lm_head", "embed_tokens"
                ]
            else:
                logger.warning(
                    f"peft_target_model '{self.peft_target_model}' is not supported, "
                    f"so peft_target_modules is set to None."
                )

    def from_pretrained_kwargs(self, training_args):
        if self.load_in_8bit:
            kwargs = {"load_in_8bit": True}
        elif self.load_in_4bit:
            kwargs = {
                "quantization_config": BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_use_double_quant=True,
                    bnb_4bit_quant_type="nf4",
                    bnb_4bit_compute_dtype=torch.bfloat16,
                )
            }
        elif training_args.bf16:
            kwargs = {"torch_dtype": torch.bfloat16}
        else:
            kwargs = {"torch_dtype": torch.float16}
        kwargs["use_flash_attention_2"] = self.use_flash_attention_2
        return kwargs

def load_datasets(data_files, tokenizer, max_seq_length=2048):
    datasets = []
    for data_file in data_files:
        dataset = load_dataset("json", data_files=data_file)
        dataset = dataset["train"]

        def tokenize_function(example):
            tokenized = tokenizer.apply_chat_template(
                example["messages"],
                tokenize=True,
                add_generation_prompt=False,
                truncation=True,
                max_length=max_seq_length,
                return_tensors="pt",
                return_dict=True
            )
            return {
                "input_ids": tokenized["input_ids"].squeeze(0).tolist(),
                "attention_mask": tokenized["attention_mask"].squeeze(0).tolist()
            }

        dataset = dataset.map(tokenize_function, remove_columns=dataset.column_names)
        datasets.append(dataset)
    return concatenate_datasets(datasets)

def main():
    parser = HfArgumentParser((TrainingArguments, SFTTrainingArguments))
    training_args, sft_training_args = parser.parse_args_into_dataclasses()

    # W&B の設定を適用
    if sft_training_args.wandb_project:
        training_args.report_to = ["wandb"]  # W&B にログを送信
        wandb.init(
            project=sft_training_args.wandb_project,
            name=sft_training_args.wandb_run_name,
            config=vars(training_args),  # TrainingArguments を W&B に記録
        )
        if sft_training_args.wandb_log_steps:
            training_args.logging_steps = sft_training_args.wandb_log_steps

    tokenizer_name_or_path = (
        sft_training_args.tokenizer_name_or_path or sft_training_args.model_name_or_path
    )
    logger.info(f"Loading tokenizer from {tokenizer_name_or_path}")
    tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_name_or_path,
        use_fast=sft_training_args.use_fast,
        additional_special_tokens=sft_training_args.additional_special_tokens,
        trust_remote_code=True,
    )

    chat_template = (
        "{{bos_token}}{% for message in messages %}"
        "{% if message['role'] == 'user' %}{{ '\\n\\n### 指示:\\n' + message['content'] }}"
        "{% elif message['role'] == 'system' %}{{ '以下は、タスクを説明する指示です。要求を適切に満たす応答を書きなさい。' }}"
        "{% elif message['role'] == 'assistant' %}{{ '\\n\\n### 応答:\\n' + message['content'] + eos_token }}"
        "{% endif %}"
        "{% if loop.last and add_generation_prompt %}{{ '\\n\\n### 応答:\\n' }}{% endif %}"
        "{% endfor %}"
    )
    tokenizer.chat_template = chat_template
    logger.info("Custom chat template applied to tokenizer")

    if tokenizer.bos_token is None:
        tokenizer.bos_token = "<|begin_of_text|>"
        logger.info(f"Set default bos_token: {tokenizer.bos_token}")
    if tokenizer.eos_token is None:
        tokenizer.eos_token = "<|end_of_text|>"
        logger.info(f"Set default eos_token: {tokenizer.eos_token}")

    logger.info("Loading data")
    train_dataset = load_datasets(sft_training_args.data_files, tokenizer, sft_training_args.max_seq_length)
    if sft_training_args.eval_data_files:
        eval_dataset = load_datasets(sft_training_args.eval_data_files, tokenizer, sft_training_args.max_seq_length)
        training_args.do_eval = True
    else:
        eval_dataset = None

    logger.info(f"Loading model from {sft_training_args.model_name_or_path}")
    kwargs = sft_training_args.from_pretrained_kwargs(training_args)
    model = AutoModelForCausalLM.from_pretrained(
        sft_training_args.model_name_or_path,
        trust_remote_code=True,
        **kwargs,
    )

    peft_config = None
    if sft_training_args.use_peft:
        logger.info("Setting up LoRA")
        peft_config = LoraConfig(
            r=sft_training_args.peft_lora_r,
            target_modules=sft_training_args.peft_target_modules,
            lora_alpha=sft_training_args.peft_lora_alpha,
            lora_dropout=sft_training_args.peft_lora_dropout,
            fan_in_fan_out=True,
            bias="none",
            task_type="CAUSAL_LM",
        )
        if training_args.gradient_checkpointing:
            for param in model.parameters():
                param.requires_grad = False
                if param.ndim == 1:
                    param.data = param.data.to(torch.float32)
            model.gradient_checkpointing_enable()
            model.enable_input_require_grads()

    logger.info("Setting up trainer")
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        processing_class=tokenizer,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        peft_config=peft_config,
    )

    logger.info("Training")
    trainer.train()

    logger.info("Saving model")
    trainer.save_model()

    # W&B を終了
    if sft_training_args.wandb_project:
        wandb.finish()

if __name__ == "__main__":
    logging.basicConfig(
        level=logging.DEBUG,
        format="%(asctime)s %(name)s:%(lineno)d: %(levelname)s: %(message)s",
    )
    main()

4. データセットの形式

トレーニングデータは、JSONLファイルの形式で提供されます。各行は、以下のような構造のJSONオブジェクトです:

{
  "ID": "magpie_sft_v1.0-0",
  "messages": [
    {"role": "system", "content": "以下は、タスクを説明する指示です。要求を適切に満たす応答を書きなさい。"},
    {"role": "user", "content": "日本の伝統的な文化や風習について、特に地域ごとの違いや特徴を知りたいと思っています。..."},
    {"role": "assistant", "content": "もちろん、その2つの祭りの特徴について詳しく説明いたします。..."}
  ]
}

各メッセージには、「role」と「content」の2つのフィールドがあります。「role」は「system」、「user」、または「assistant」のいずれかです。「content」は、メッセージの内容です。

5. 学習の実行

実行コマンド例

学習を実行する際は、データセットとモデルを指定したパスに格納してください。ここでは、通常のsft(Supervised Fine-Tuning)と、chat_templateを使用したチャット形式のトレーニングの実行コマンドを紹介します。なお、この記事で示したPythonコードはchat_templateを使用したもので、src/train_chat.pyに記載されています。通常のSFT用のコードはsrc/train.pyにあります。

1. 通常のトレーニング

通常のSFTを実行する場合のコマンドです。

python src/train.py \
    --num_train_epochs 1 \
    --per_device_train_batch_size 8 \
    --per_device_eval_batch_size 4 \
    --gradient_accumulation_steps 8 \
    --learning_rate 1e-5 \
    --warmup_ratio 0.1 \
    --lr_scheduler_type cosine \
    --bf16 \
    --max_seq_length 4096 \
    --data_files ./data/sample_data.jsonl \
    --model_name_or_path ./model/sample_model \
    --output_dir results/

2. chat_templateを使ったチャットトレーニング

src/train_chat.pyでは、以下のchat_templateが使用されています。このテンプレートにより、ユーザー、システム、アシスタントの役割に応じた会話形式のデータが生成されます。

chat_template = (
    "{{bos_token}}{% for message in messages %}"
    "{% if message['role'] == 'user' %}{{ '\n\n### 指示:\n' + message['content'] }}"
    "{% elif message['role'] == 'system' %}{{ '以下は、タスクを説明する指示です。要求を適切に満たす応答を書きなさい。' }}"
    "{% elif message['role'] == 'assistant' %}{{ '\n\n### 応答:\n' + message['content'] + eos_token }}"
    "{% endif %}"
    "{% if loop.last and add_generation_prompt %}{{ '\n\n### 応答:\n' }}{% endif %}"
    "{% endfor %}"
)

実行コマンド:

python src/train_chat.py \
    --num_train_epochs 1 \
    --per_device_train_batch_size 8 \
    --per_device_eval_batch_size 4 \
    --gradient_accumulation_steps 8 \
    --learning_rate 5e-5 \
    --warmup_ratio 0.1 \
    --lr_scheduler_type cosine \
    --bf16 \
    --data_files ./data/sample_data.jsonl \
    --model_name_or_path ./model/sample_model \
    --output_dir results/output_model/ \
    --wandb_project "sample-sft" \
    --wandb_run_name "test_1" \
    --wandb_log_steps 10

6.パラメータの解説

基本的なパラメータの解説は以下の記事でしています。

sftでは以下のようなlearnig_rateを使用することが多い
sftにおける学習率

以下のパラメータで指定する

設定項目 説明
learning_rate 学習率を5e-5(0.00005)に設定し、モデルの更新ステップを細かく制御。
warmup_ratio ウォームアップ期間を全体の10%(0.1)に設定し、学習率を徐々に増加させる比率。
lr_scheduler_type cosineスケジューラを指定し、学習率をコサイン関数に基づいて徐々に減少させる方式。
use_peft falseの場合はフルパラメータ学習を行い、trueの場合はLoRA学習を行います。 (デフォルトはfalse)

おわりに

本記事では、trlライブラリを用いたSFT(教師ありファインチューニング)の手法を解説しました。
sftの記事はたくさんありますが、大体がlora学習なためフルパラメータ学習は珍しいと思います。また、小さいモデルだとchat_templateを使用するより精度が上がるのでおすすめです。
また、分からないことがありましたら、気軽にTwitterのDMに質問してください。

開発支援のお願い

現在、開発を続けていますが、クラウドGPUの価格が高く、十分な計算リソースを確保できずにいます。そのため、思い通りに開発が出来ていません。
また、オープンソースの理念を大切にしており、プログラム・データセット・モデルを有料で公開するつもりはありません。そのため、金銭的に余裕のある方に支援していただけると大変助かります。
TwitterのDMやご支援いただける方は、以下のプラットフォームよりお願いいたします。

Discussion

ログインするとコメントできます