Open1

GRPO を用いた推論のための LLM のトレーニング後

MygMyg

https://huggingface.co/learn/cookbook/en/fine_tuning_llm_grpo_trl#4-check-the-model-performance

Hugging Face 公式 Open-Source AI Cookbook の 「Post training an LLM for reasoning with GRPO in TRL」 をベースに、GRPO(Group Relative Policy Optimization)TRL で用いて推論系タスク向けに RL ポストトレーニング する手順を、日本語で実務視点に整理したものです。Colab などの制約環境でも最小例を動かし、次にどこを拡張すべきかまでを示します。

目的:<think>…</think> / <answer>…</answer> 形式を守らせつつ、正答への誘導を行う 形式報酬+正解度報酬 を用いた GRPO 学習の最小構成を理解する。


概要(何ができる?)

  • 小型の Qwen2-0.5B-Instruct をベースに LoRA で効率学習。
  • 数学系の推論データセット AI-MO/NuminaMath-TIR を 5% サブセットで使用。
  • 2 種の報酬関数(フォーマット遵守 / 正解判定)で GRPO 学習を 1 epoch だけ実行。
  • 評価では タグ出力は正しく学習する一方、解答は未正解という結果。最小構成でも挙動は確認でき、ここからモデル規模・データ量・学習反復・報酬設計を強化していくのが次の一手。

背景:GRPO とは

GRPO(Group Relative Policy Optimization) は、PPO から 価値関数(value model)を省いた変種で、各プロンプトに対して G 本の生成を行い、グループ内比較に基づく相対的 advantage を用いて方策を更新します。オンライン学習(生成→報酬→更新を繰り返す)で、推論タスクにおける長い思考(test-time compute のスケール)を促すのに適しています。

  • 代表的なポイント

    • 価値関数が不要で メモリ効率が高い。
    • std で正規化した相対 advantage を用いる実装が一般的だが、Dr.GRPO など近年の研究では長さバイアス低減のための選択肢も提案されている。
    • KL 正則化 β は最近のオープン研究では 0(無効) とする設定も多い。必要に応じて有効化可能。

データセットとプロンプト整形

  • 使用データセット:AI-MO/NuminaMath-TIRproblem, solution, messages を含む数学推論データ)。
  • DeepSeek-R1 の手順に倣い、system で以下の方針を与える:
A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
The assistant first thinks about the reasoning process in the mind and then provides the user
with the answer. The reasoning process and answer are enclosed within <think> </think> and
<answer> </answer> tags, respectively, i.e., <think> ... </think><answer> ... </answer>
User と Assistant の会話です。ユーザーが質問し、Assistant がそれを解きます。
Assistant はまず頭の中で推論過程を考え、その後ユーザーに解答を提示します。
推論過程と解答はそれぞれ <think> </think> と <answer> </answer> のタグで囲みます。すなわち、<think> ... </think><answer> ... </answer> の形式です。
  • データは prompt(system + user)と solution を残し、評価時の正解判定に用います。Cookbook の例では学習サンプル数は 3,622 行(train の 5%)です。

依存関係のセットアップ

最小例では以下がテスト済みバージョンとして示されています。

pip install -U trl peft math_verify
# Tested with:
# transformers==4.47.1, trl==0.14.0, datasets==3.2.0, peft==0.14.0, accelerate==1.2.1, math_verify==0.3.3

Hugging Face Hub への push を行う場合は notebook_login() を実行します。


ベースモデルと LoRA 設定

  • ベース:Qwen/Qwen2-0.5B-Instruct
  • LoRA 設定例:r=8, alpha=32, dropout=0.1, target_modules=["q_proj","v_proj"]
  • 可変パラメータ数の目安:540,672(全体の ≒0.109%)
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2-0.5B-Instruct",
    torch_dtype="auto",
    device_map="auto",
)

lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=8, lora_alpha=32, lora_dropout=0.1,
    target_modules=["q_proj","v_proj"],
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

報酬関数(フォーマット+正解度)

1) フォーマット報酬

<think>…</think><answer>…</answer>両方が揃っているかを正規表現でチェックし、満たせば 1、そうでなければ 0。

import re

def format_reward(completions, **kwargs):
    pattern = r"^<think>.*?</think>\s*<answer>.*?</answer>$"
    contents = [c[0]["content"] for c in completions]
    return [1.0 if re.match(pattern, x) else 0.0 for x in contents]

2) 正解度報酬

math_verify を使い、solution と生成結果から抽出した数学表現を比較して 正解=1.0 / それ以外=0.0 を返す簡易採点。

from math_verify import LatexExtractionConfig, parse, verify

def accuracy_reward(completions, **kwargs):
    solutions = kwargs["solution"]
    contents = [c[0]["content"] for c in completions]
    rewards = []
    for content, sol in zip(contents, solutions):
        gold = parse(sol, extraction_mode="first_match",
                     extraction_config=[LatexExtractionConfig()])
        pred = parse(content, extraction_mode="first_match",
                     extraction_config=[LatexExtractionConfig()])
        try:
            rewards.append(float(verify(pred, gold)) if len(gold) else 1.0)
        except Exception:
            rewards.append(0.0)
    return rewards

GRPO の主要ハイパーパラメータ

Cookbook の最小例では、学習を軽くするために以下を絞っています。まずはこの設定で挙動を確認し、その後に段階的に伸ばします。

from trl import GRPOConfig

training_args = GRPOConfig(
    output_dir="Qwen2-0.5B-GRPO-test",
    learning_rate=1e-5,
    remove_unused_columns=False,   # reward で solution にアクセス
    gradient_accumulation_steps=16,
    num_train_epochs=1,
    bf16=True,
    # 生成に関わるパラメータ(まずは小さく)
    max_completion_length=64,  # default: 256
    num_generations=4,         # default: 8
    max_prompt_length=128,     # default: 512
    # ロギング/保存
    report_to=["tensorboard"],
    logging_steps=10,
    push_to_hub=True,
    save_strategy="steps",
    save_steps=10,
)

生成長(max_completion_length)と Gnum_generations)は 計算コストに直結します。まずは小さく、挙動を確認してから増やすのがおすすめです。


学習の実行

from trl import GRPOTrainer

trainer = GRPOTrainer(
    model=model,
    reward_funcs=[format_reward, accuracy_reward],
    args=training_args,
    train_dataset=train_dataset,
)

trainer.train()
trainer.save_model(training_args.output_dir)
trainer.push_to_hub(dataset_name="AI-MO/NuminaMath-TIR")

TensorBoard では、報酬や損失の推移を確認できます。


評価(最小例の結果)

保存済みの学習済みモデル sergiopaniego/Qwen2-0.5B-GRPO を読み込み、テストの一問で推論します。結果の要点:

  • <think>/<answer> のタグは正しく出力できるように。
  • ただし 解答は誤り(最小設定・小型モデル・短時間学習・難易度高めのデータ、という前提)。
  • 参考値:推論時間 ≈ 2.09 秒、生成トークン数 55

うまくいかないときのチェックリスト

  • Tokenizer の padding_side は "left":TRL のオンライン生成では左パディング推奨です。tokenizer.pad_token も設定されているか確認。
  • 生成長と G を上げすぎていないか:OOM/極端な遅さの原因になりがち。まずは短く。
  • 報酬関数の安定性:例外(try/except)で 0 を返すなど、ロバストに。NaN が出ていないか。
  • scale_rewards の検討:質問難度バイアスが疑われる場合は無効化(scale_rewards=False)も試す。
  • 長さバイアス対策loss_type="dr_grpo" や、mask_truncated_completions=True を試す。
  • 生成高速化pip install trl[vllm]vLLM 併用。生成がボトルネックになりやすい。

精度を上げるための次の一手

  1. モデルを大きく:0.5B → 1.5B/3B/7B…。

  2. データを増やす:サブセットからフルデータへ。難易度を段階化する カリキュラム学習 も有効。

  3. 学習反復を増やす:エポック数・イテレーション数を増やす。

  4. 報酬の改善

    • 形式+正解に加え、ステップ整合性(推論過程の検証)自己検証ユニットテスト型(コード) などを導入。
    • マルチタスク報酬(数理+コードなど)で汎用化。
  5. vLLM 併用で生成高速化、num_generations を増やして探索力を付与。

  6. Dr.GRPO / DAPO の損失を検討し、長さバイアスを抑制。


参考リンク

  • Hugging Face Cookbook: Post training an LLM for reasoning with GRPO in TRL
  • TRL ドキュメント:GRPOTrainer(理論・ハイパラ・Dr.GRPO/DAPO)
  • Open-R1(Hugging Face Science による再現プロジェクト)
  • Phil Schmid の mini-R1(Countdown Game) 再現記事
  • The Illustrated DeepSeek-R1(Jay Alammar)
  • DeepSeekMath / DeepSeek-R1 論文・リポジトリ

本記事は Cookbook の内容を日本語で再構成し、実務で詰まりやすいポイント(トークナイザ、生成長、報酬設計、速度最適化)を追記しました。最小例で挙動を掴んだら、モデル規模・データ量・生成回数 G・報酬多様化 を軸にスケールさせていきましょう。