🤔

日本語reasoningモデルを作る

2024/12/08に公開

1人ローカルLLMアドベントカレンダーの8日目です。

ローカルLLMやマルチモーダルモデルの学習やデータセット周りについて書いていく予定なので、興味がある方は明日以降も読んでいただけると嬉しいです!
https://qiita.com/advent-calendar/2024/local-llm

要約

  • 一般タスクと数学のCoTデータセットを作成した
  • それらのデータを使って日本語reasoningモデルを学習した

https://huggingface.co/Kendamarron/llm-jp-3-3.7b-o1-v0.1
https://huggingface.co/Kendamarron/Qwen2.5-7B-o1-ja-v0.1

目的

数日前にo1-proがリリースされ、HuggingFaceでも連日新たなreasoningモデルが続々と公開されています。

しかし、これらのモデルに(おそらく)日本語データで学習しているものがないため、日本語reasoningデータの作成とモデルの学習を行っていこうと思います。

実施内容

  1. 一般タスクと数学のCoTデータセットを作成
  2. それらのデータでモデルをSFT

データセットの作成について

一般タスク

一般タスクについてはシステムプロンプトに以下のものを設定し、応答を生成させています。

あなたは優秀で論理的なアシスタントです。
1. まずは<Thought></Thought>タグの中であなたの思考の過程を抜けがないように記載していきます。ユーザーからの指示に対して、ステップごとに詳細を詰めていきます。ここでは最終的にユーザーに提供するべき情報がすべて記載されるべきです。思考を進めていく中で、定期的に過去の思考過程を見直し、内容や方向性の修正を行います。
2. <Output></Output>タグの中に最終的にユーザーに提供する出力を記載します。ユーザーは<Output></Output>の中だけ見ることになるので、回答として必要となる情報はすべて記載されるべきです。 タグは必ず閉じなければなりません。

タグの出力が不安定なため、修正やフィルタリングを行う処理を行っています。

instructionにはAratako/Magpie-Tanuki-8B-annotated-96kをお借りして、約8400件のデータを作成しています。

https://huggingface.co/datasets/Kendamarron/Magpie-Tanuki-8B-CoT

数学タスク

  1. nvidia/OpenMathInstruct-2の問題を翻訳
  2. ステップごとに計算を進めるようにシステムプロンプトを設定して応答を生成
  3. 2の内容からユーザーに対する最終的な応答を生成
  4. 3の答えがnvidia/OpenMathInstruct-2のexpected_answerと一致しているかLLMが判定(答えの表記揺れ対策)し、不正解のレコードを除去

の流れで行っています。

こちらは約15000件のデータを作成しています。

https://huggingface.co/datasets/Kendamarron/OpenMathInstruct-2-ja-CoT

学習

  • llm-jp/llm-jp-3-3.7b-instruct
  • Qwen/Qwen2.5-7B-Instruct

の2つのモデルに対してフルパラメータで2epoch学習を行いました。

https://huggingface.co/Kendamarron/llm-jp-3-3.7b-o1-v0.1
https://huggingface.co/Kendamarron/Qwen2.5-7B-o1-ja-v0.1

結果

出力

ここでは、Qwen2.5-7B-o1-ja-v0.1の出力を見ていきます。

1から10までの整数を足すと?

<Thought>
1から10までの整数を足す問題を解くために、まずこの数列を確認します。

1, 2, 3, 4, 5, 6, 7, 8, 9, 10

この数列は等差数列であり、初項が1、末項が10、項数が10です。

等差数列の和の公式は以下の通りです。
[ S_n = \frac{n}{2} (a + l) ]
ここで、

  • ( S_n ) は数列の和
  • ( n ) は項数
  • ( a ) は初項
  • ( l ) は末項

この問題に適用すると、

  • ( n = 10 )
  • ( a = 1 )
  • ( l = 10 )

これらを公式に代入します。
[ S_{10} = \frac{10}{2} (1 + 10) ]
[ S_{10} = 5 \times 11 ]
[ S_{10} = 55 ]

したがって、1から10までの整数を足すと55になります。
</Thought>

<Output>
1から10までの整数を足すと55になります。
</Output>

猫を飼育するために気をつけるべきことを3つ教えてください。

<Thought>
まず、猫を飼育するために気をつけるべきことについて考える必要があります。猫は独立心が強く、自立した動物ですが、それでもいくつかの重要な点を考慮する必要があります。 1. 適切な食事: 猫は肉食動物であり、バランスの取れた食事を摂る必要があります。高品質の猫用フードを選び、適切な量を定期的に与えることが重要です。また、水の供給も十分に行う必要があります。 2. 適切な住居環境: 猫は静かで落ち着いた環境を好みます。安全で清潔なスペースを提供し、猫が自由に移動できるようにすることが重要です。また、猫用のトイレや寝床、遊具なども揃えておくと良いでしょう。 3. 健康管理: 定期的な健康チェックや予防接種、駆虫が必要です。また、猫の健康状態を定期的に確認し、異常が見られた場合には早めに獣医師に相談することが重要です。
</Thought>

<Output>
猫を飼育するために気をつけるべきこととして、以下の3点が挙げられます。まず、高品質の猫用フードを選び、適切な量を定期的に与えることが重要です。次に、安全で清潔な住居環境を提供し、猫が自由に移動できるようにすることが必要です。最後に、定期的な健康チェックや予防接種、駆虫を行い、猫の健康状態を常に確認することが大切です。
</Output>

割といい感じにreasoningできていそうです!

ベンチマーク

llm-jp-evalのmgsm(計算タスク)を使って評価を行っていきます。

testデータは250件で、reasoning能力を測る意味も込めて、出力フォーマットの指定はせず、応答分の最後に登場した数値が答えと一致しているかを正解の基準にしています。

モデル 正解数(全250問)
llm-jp/llm-jp-3-3.7b-instruct 23
Kendamarron/llm-jp-3-3.7b-o1-v0.1 48
Qwen/Qwen2.5-7B-Instruct 177
Kendamarron/Qwen2.5-7B-o1-ja-v0.1 181

llm-jp-3-3.7b-instructについては正解数が倍増しています。これには単純に数学のデータで学習させたことが影響しているかもしれません。

Qwen2.5-7B-Instructは元々のモデルがかなり丁寧にチューニングされていることもあり、正解数は誤差程度の差しかありませんでした。

まとめ

今回は日本語データでreasoningモデルを作ってみました。

まだデータが少なく精度に課題はありますが、もう少しデータを増やしてから作り直そうと思います。

Discussion