📖

LLMのfine-tuningでカスタム指標での学習評価を行う

に公開

概要

大規模言語モデル (LLM) のファインチューニングでは、TrainerSFTTrainer 等を使って学習を行います。通常では、クロスエントロピー損失を最小化することを目指しますが、翻訳やコード生成などタスクによっては、BLEUやROUGEなどのカスタム指標や評価関数を定義することで、モデルの出力をより期待に近づけることが可能です。

今回の記事では、compute_metrics を利用して、BLEUスコアとROUGEスコアを評価指標として追加するサンプルコードを忘備録として紹介します。

データセットなどはGeminiなどに頼んでもらって作ったので、適当になっています。

サンプルコード (Google Colabで実行を確認済)

from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, BitsAndBytesConfig, EarlyStoppingCallback
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, TaskType
import numpy as np
import evaluate
import torch
import os
import re

# 評価時のテキストの正規化
def normalize_config(txt: str, sort_lines: bool = False) -> str:
  lines = [ln for ln in  txt.splitlines() if not ln.lstrip().startswith("#")]
  lines = [re.sub(r"\s+"," ",ln).strip() for ln in lines if ln.strip()]
  if sort_lines:
    lines.sort()
  return "\n".join(lines)

# --- モデル・トークナイザ ---
model_name = "meta-llama/Meta-Llama-3.1-8B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Set pad token
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    
# Configure 4-bit quantization
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4", # or "fp4"
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=False,
)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype="auto",
    quantization_config=bnb_config, 
)

# --- LoRA 設定 ---
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=32,
    lora_alpha=64,
    lora_dropout=0.1,
    target_modules = ["q_proj", "k_proj","lm_head","embed_tokens"]
)
model = get_peft_model(model, lora_config)

# --- データセット (GSM8K - 小学校レベルの算数問題) ---
dataset = load_dataset("gsm8k", "main", split="train[:500]")  # 最初の500サンプルを使用

def preprocess_function(examples):
    # GSM8Kは question と answer のフィールドを持つ
    # 簡単な指示形式に変換
    inputs = []
    targets = []
    
    for question, answer in zip(examples["question"], examples["answer"]):
        # 入力: 問題文を簡潔な指示形式に
        input_text = f"Solve this math problem: {question}"
        # 出力: 答えの部分(#### 以降が最終回答)
        # GSM8Kの答えは説明文と最終回答を含むので、短くして学習しやすくする
        if "####" in answer:
            # 最終回答部分を抽出
            final_answer = answer.split("####")[1].strip()
            target_text = f"The answer is {final_answer}"
        else:
            target_text = answer[:200]  # 長すぎる場合は切り詰め
        
        inputs.append(input_text)
        targets.append(target_text)

    model_inputs = tokenizer(inputs, max_length=128, truncation=True, padding="max_length")
    labels = tokenizer(targets, max_length=128, truncation=True, padding="max_length").input_ids

    # Align labels with input padding and set padding tokens to -100
    model_inputs["labels"] = [
        [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels
    ]

    return model_inputs

tokenized_dataset = dataset.map(preprocess_function, batched=True)

# evalデータセット。今回は動作を見るために取得幅を設定。
eval_dataset_subset = tokenized_dataset.select(range(500)) 

# --- BLEU, ROUGE 計算 ---
bleu = evaluate.load("bleu")
rouge = evaluate.load("rouge")

# 評価計算の前処理
def preprocess_logits_for_metrics(logits, labels):
    if isinstance(logits,tuple):
        logits = logits[0]
    return logits.argmax(dim=-1)

# カスタム指標の処理
def compute_metrics(eval_pred):
    predictions, labels = eval_pred

    if isinstance(predictions, tuple):
        predictions = predictions[0]
    if predictions.ndim == 3:
        predictions = np.argmax(predictions,axis=-1)

    # loss計算で使う部分だけ残す (-100の範囲は利用しない)
    true_preds = np.where(labels != -100, predictions, tokenizer.pad_token_id)
    true_labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

    decoded_preds = tokenizer.batch_decode(true_preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(true_labels, skip_special_tokens=True)

    # 正規化
    decoded_preds = [normalize_config(t,sort_lines=True) for t in decoded_preds]
    decoded_labels = [normalize_config(t,sort_lines=True) for t in decoded_labels]
    
    # BLEU
    bleu_result = bleu.compute(predictions=decoded_preds, references=decoded_labels)

    #ROUGE
    rouge_result = rouge.compute(predictions=decoded_preds, references=decoded_labels)
    return {"bleu": bleu_result["bleu"],
            "rouge": rouge_result["rougeL"]}

# --- 学習設定 ---
training_args = TrainingArguments(
    output_dir="./results",
    fp16=True,
    save_steps=50,
    eval_strategy="steps",
    eval_steps=50,
    max_steps=300,
    logging_steps=50,
    learning_rate=2e-4,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    num_train_epochs=1,
    weight_decay=0.01,
    logging_dir="./logs",
    metric_for_best_model="bleu",# compute_metricsで設定した指標を指定
    greater_is_better=True, # Bleuのような指標の場合は、数値が大きいほうが評価が高いためTrueに設定する
    report_to=None
)

# --- Trainer ---
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    eval_dataset=eval_dataset_subset, 
    tokenizer=tokenizer,
    compute_metrics=compute_metrics, # 作成したcompute_metricsを設定
    preprocess_logits_for_metrics=preprocess_logits_for_metrics, # loss計算前の前処理を行う
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)], # early_stoppingの設定
)

trainer.train()

実行結果例は以下の通りになります。

Step Training Loss Validation Loss Bleu Rouge
50 3.675800 1.851547 0.168976 0.669293
100 1.064900 1.166520 0.178154 0.750333
150 0.927200 1.179824 0.196318 0.750300
200 0.996300 1.130214 0.156965 0.751700
250 0.935200 1.160782 0.224630 0.751833
300 1.041300 1.092885 0.223699 0.751367

Compute_metrics

Trainer には compute_metrics と呼ばれる引数が存在します。
ここにカスタム指標を計算する関数を設定することで、カスタム指標で学習の評価を行ってくれます。

今回はcompute_metricsという関数でBlueRougeの計算を行う処理を実装しています。
複数の指標を設定することが可能です。
指標を追加したい場合は、その計算を行う処理と返り値にその指標のキーと値を設定すれば、それを使って評価を行うことができます。

# カスタム指標の処理
def compute_metrics(eval_pred):
    predictions, labels = eval_pred

    if isinstance(predictions, tuple):
        predictions = predictions[0]
    if predictions.ndim == 3:
        predictions = np.argmax(predictions,axis=-1)

    # loss計算で使う部分だけ残す (-100の範囲は利用しない)
    true_preds = np.where(labels != -100, predictions, tokenizer.pad_token_id)
    true_labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

    decoded_preds = tokenizer.batch_decode(true_preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(true_labels, skip_special_tokens=True)

    # 正規化
    decoded_preds = [normalize_config(t,sort_lines=True) for t in decoded_preds]
    decoded_labels = [normalize_config(t,sort_lines=True) for t in decoded_labels]
    
    # BLEU
    bleu_result = bleu.compute(predictions=decoded_preds, references=decoded_labels)

    #ROUGE
    rouge_result = rouge.compute(predictions=decoded_preds, references=decoded_labels)

    # 評価指標を返す ("指標名":値を追加すれば良い)
    return {"bleu": bleu_result["bleu"],
            "rouge": rouge_result["rougeL"]}

後は、

  1. TrainingArgumentmetric_for_best_model に作成した指標 (ここではbleuに設定) の表示名を設定
  2. Trainerにて、compute_metrics に作成した関数を設定

と設定すれば、独自の指標で学習の評価を行ってくれます。
(他にもGPUのメモリ利用率を減らすために、評価前にデータの前処理を行う関数 preprocess_logits_for_metrics を設定したり、指標の値の変化が見られなくなったら学習をストップする early_stopping を行うためにcallbacks を設定したりしています。)

# --- 学習設定 ---
training_args = TrainingArguments(
    output_dir="./results",
    fp16=True,
    save_steps=50,
    eval_strategy="steps",
    eval_steps=50,
    max_steps=300,
    logging_steps=50,
    learning_rate=2e-4,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    num_train_epochs=1,
    weight_decay=0.01,
    logging_dir="./logs",
    metric_for_best_model="bleu",# compute_metricsで設定した指標を指定
    greater_is_better=True, # Bleuのような指標の場合は、数値が大きいほうが評価が高いためTrueに設定する
    report_to=None

# --- Trainer ---
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    eval_dataset=eval_dataset_subset, 
    tokenizer=tokenizer,
    compute_metrics=compute_metrics, # 作成したcompute_metricsを設定
    preprocess_logits_for_metrics=preprocess_logits_for_metrics, # loss計算前の前処理を行う
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)], # early_stoppingの設定
)
)

参考文献

Discussion