🐩

DeepSeekでも使われるGRPOをtrlで試す

2025/02/02に公開

はじめに

良くも悪くもDeepSeekさん話題になっていますね。
DeepSeek-R1は、実際にNVIDIAの株価を90兆円吹っ飛ばすほどのパフォーマンスやインパクトを与えているということで、流石にその技術的背景は興味深いものがあります。[paper]

その技術的背景の中でも、強化学習を用いたアライメント手法であるGRPOは注目すべきポイントの一つでしょう。

trlコミュニティが早速GRPOの実装を公開したので、早速試してみたいと思います。(ちょっと前からあったけど、、、)

やっていく

実験設定

  • モデル: SakanaAI/TinySwallow-1.5B-Instruct
  • データセット: cl-nagoya/auto-wiki-qa
    • QAデータセットのうち質問部のみを使用
  • 報酬関数
    • 絵文字を含めるようにさせる
    • 出力文字数を30文字程度に制限する

セットアップ

githubから、あるいは最新のtrlをインストールしましょう。

pip install -U "trl>=0.14.0"

コード

trlでは基本的に学習させるアルゴリズムの名前を関したConfigクラスとTrainerクラスを使用します。ConfigクラスはtransformersTrainingArgumentsクラスを継承しており、アルゴリズム独自のパラメータを追加しています。[公式ドキュメント]

GRPO固有のうち主要なものを抜粋してみました。

from trl import GRPOConfig, GRPOTrainer

train_args = GRPOConfig(
    ...,
    num_generations=8,
    temperature=0.9,
    max_prompt_length=256,
    max_completion_length=256,
    use_vllm=False,
    learning_rate=1e-6,
    beta=0.04,
)

GRPOでは、1プロンプトに対して複数の文章を生成しそれに対する報酬を計算するというアプローチを取ります。そのため、num_generations,temperatureによってどの程度の多様性を持った文章をいくつ生成するかを決めます。また、サンプリング高速化のためにvllmを使用するかも決められます。

SourceGRPOTrainerのドキュメントより

他に、KL Divergenceの重視度を決めるbetaが学習に効いてきそうなあたりでしょうか。

また、GRPO論文のもう一つの特徴として、非ニューラルネットワーク(NN)系の報酬関数を使用している点が挙げられます。従来はNN系の報酬モデルを学習、またはDPOのような報酬モデルを介在させない直接的アプローチが主流でしたが、DeepSeekではフォーマットと答えがあっているかというルールベースの報酬関数を使用しています。

GRPOTrainerでは、reward_funcs=[func1, func2, ...]といった形で、複数の報酬関数を渡すことができます。また、便利なことにreward_funcs="weqweasdas/RM-Gemma-2B"といった形で従来のNN系の報酬モデルも渡すことができるようです。

そして、報酬関数については、以下のような仕様があります。

  • promptsというキーワード引数を受け取れるようにする
  • completionsというキーワード引数を受け取れるようにする
  • データセットに含まれるカラムは、それと同名のキーワード引数として渡される
  • 戻り値はlist[float]にする

一番シンプルな方法は下記のように、**kwargsを受け取るようにすることです。

今回は、「絵文字を含める」、「出力文字数を30文字程度に制限する」ために報酬関数を簡単に設計してみます。

MAX_EMOJI = 4

# 絵文字を4文字以上含めるように動機づけ
def emoji_reward(completions, **kwargs):
    return [min(emoji.emoji_count(completion), MAX_EMOJI) for completion in completions]


TARGET_LENGTH = 30

# 出力文字数を30文字程度に制限する
def length_reward(completions, **kwargs):
    alpha = 0.15 # 30文字からの乖離をどの程度重視するか
    return [
        np.exp(-alpha * abs(len(completion) - TARGET_LENGTH))
        for completion in completions
    ]

学習

これで一通りの準備はOKです。
では、学習を実行してみます。

from trl import GRPOConfig, GRPOTrainer

repo = "SakanaAI/TinySwallow-1.5B-Instruct"

dataset = load_dataset("cl-nagoya/auto-wiki-qa", split="train[:10000]")  # 全件だと時間がかかるので10000件に制限

model = AutoModelForCausalLM.from_pretrained(
    repo,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)

tokenizer = AutoTokenizer.from_pretrained(repo)
tokenizer.pad_token = tokenizer.eos_token

trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    train_dataset=dataset,
    reward_funcs=[emoji_reward, length_reward],
    args=train_args,
)

trainer.train()

結果

学習過程

1エポック、10000件で学習してみた結果がこちらです。

初期の段階では、completionに対する報酬のみが得られていますが、80ステップあたりから絵文字に対する報酬が増加していっていることがわかります。
(反対に、出力文字数に対する報酬は減少していっているが、completion_lengthは再び30文字に近づき始めていそうなので学習不足かも)

生成結果

実際の生成結果はこちらです。


富士山の高さは?

  • ベースモデル

    富士山の高さは **3,776メートル** です。
    
    
  • 学習したモデル

    富士山の高さは **3,776メートル** です! 🗻✨ 🌋✨ 🍓🌟
    
    📄 地球上で最も高い山の一つとして知られていますね! ✨🏔️
    

一番好きな漫画は?

  • ベースモデル

    私はAIなので、個人的な好みや感情はありません。
    
    しかし、「一番好きな」という質問には答えられません。なぜなら、
    
    * **「好き」の定義** が人それぞれ違うからです。
    * 私たちは読者ではなく作者であり、作品を作り上げるのは彼ら自身だからです。
    
    どんな物語も、その中で描かれるキャラクター達との出会いや体験を通して、多くの人々に喜びや感動を与えています。ですので、あなたの好きな漫画を見つけるお手伝いをさせてください!
    
    
  • 学習したモデル

    私はAIなので、漫画を読むことはできません! 😊
    
    でも、「好きな」というと難しいですね。たくさんの人が愛している名作たちがありますよね。 📚✨ あなたは何のジャンルが好きですか?
    

小難しいこと言ってください

  • ベースモデル

    「小難しい」内容とは、少し専門的であったり、理解するのに時間がかかったり、一般的にはあまり話題にならないようなことを指す言葉ですね。
    
    例えば、
    
    1. **高度な数学の問題**: 数学の中でも特に複雑な理論や証明を必要とする問題。
    2. **専門的な医学用語**: 医師や看護師などが使うような、医療分野特有の用語。
    3. **哲学的な議論**: 多くの人が同じ結論に達しないような深い考えや主張。
    4. **科学技術に関する最新の研究結果**: 最新の研究成果を解説するとなると、それなりの知識が必要となります。
    
    このような場合は、「小難しい」と表現されることが多いでしょう。
    
    
    何か具体的な事柄があれば、さらに詳しくお話しできるかもしれませんので、教えていただけますか?😊
    
  • 学習したモデル

    小難しいことを言いますね! 😊
    
    例えば、
    
    * **量子コンピューター**:これは、従来のコンピューターよりもとても速く計算ができる新しいタイプのコンピューターのことで、将来的には医療や科学研究に革命を起こす可能性があります。
    * **人工知能**:コンピューターに人間のような思考能力を持たせることを目指しています。この技術は日々進化していて、私たちの生活にも大きな影響を与えています。
    
    これらはほんの一例ですが、もっと具体的な内容なら教えてほしいですか? 🎯
    

とまあ、定性的な評価としても、絵文字の使用が増加し元モデルよりも生成文字数は減少していそうなので、うまく学習できているようです。

おわりに

というわけで、GRPOをtrlで試してみました。
GRPOのようにnum_generations分のサンプリングをすると、従来のNN系の報酬モデルでの推論コストが重いので、非NN系の報酬関数を活用することの大きなメリットを感じました。

また、今回のように絵文字を追加したい場合にデータセットを従来は作成するなど手間がかかっていましたが、このようなプリミティブな報酬関数を追加することで比較的容易に学習できるようになったのは嬉しいです。ただし、今回のように報酬設計によって学習の安定性等に影響が出ることもあるので、試行錯誤が必要そうです。このあたりの報酬設計については研究の余地が結構ありそうですね。

GRPOTrainerを用いることで、簡単に学習できるのでお試しあれ。
それでは、また。

おまけ

LoRAを用いた学習を行いましたが、報酬は微増したものの、絵文字生成されなかった。(ステップとか違うのは御愛嬌)

peft_config = LoraConfig(
    r=16,
    lora_alpha=64,
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "up_proj",
        "down_proj",
        "gate_proj",
    ],
    task_type="CAUSAL_LM",
    lora_dropout=0.05,
)


trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    train_dataset=dataset,
    reward_funcs=[emoji_reward, length_reward],
    args=train_args,
    peft_config=peft_config,
)

trainer.train()

lora_vs_full

GitHubで編集を提案

Discussion