Tanuki-8B, 8x8B - 事後学習の軌跡
GENIAC 松尾研LLM開発プロジェクトメンバーのSomeyaです。
Team「たぬき」Phase 2では、日本語対話能力に優れた大規模言語モデルTanuki-8BとTanuki-8x8Bを開発しました。本記事では、このプロジェクトにおいて、モデルの性能を向上させるために実施した事後学習の取り組みをまとめます。
はじめに
本プロジェクトでは、事後学習の手法として主にSFTとDPOに取り組み、事前学習データの量や品質、回答形式、学習条件など、さまざまな要素がモデル性能に与える影響を検討しました。事後学習に取り組んだ1か月間において、試行錯誤の回数は300回を超えています。事後学習の検討を始めた初期は、日本語対話性能を測るJapanese MT-Benchで3点台だったスコアが、最終的には7点程にまで上昇し、大幅な性能向上を実現しました。3点台のTanukiがどのようにして7点のTanukiへと成長したのか、その過程を振り返り、得られた知見をご紹介します。
※時系列で追っていくので、長くなります。この取り組みで分かったことを手短に把握したい方は得られた知見のセクションをご覧ください。
前書き
本題に入る前に、このプロジェクトで用いた事後学習の手法、取り組みの方針、評価方法等について簡単にご説明します。
Supervised Fine-Tuning (SFT) の役割
SFTは、事前学習後のベースモデルに対して実施する教師あり学習です。SFTでは、特定の入力(指示)と期待される出力(適切な応答)のペアを用いて、指示に対して適切な応答を生成するようにモデルを学習します。多様な指示に対する応答を学習することで、モデルは汎用的な指示理解能力を獲得し、より柔軟な対話が可能になります。
SFTの効果を理解するために、ベースモデルとSFT後のモデルの出力例を比較してみましょう。
ベースモデルの応答は、質問を単に繰り返すだけの無意味な出力となっています。一方、SFT後のモデルの応答は、質問に対して具体的で適切な情報を提供しており、指示された200字以内という制限も守っています。
ベースモデルの段階では、指示に対する応答方法を理解できず、質問文を繰り返したり、支離滅裂な応答をしてしまうことがよくあります。SFTを行うことで、ユーザーの指示に対してより的確に応答できるようになります。
SFTの訓練では事前学習と同様にNext Token Predictionを行います。ただし、SFTでは特に回答部分の生成を調整するのが目的であり、LLMに入力された文章(指示+応答)のうち、応答部分のみでロスを計算することが一般的です。本プロジェクトでも主にこの手法でSFTを行いました。
Direct Preference Optimization (DPO) の役割
DPOは、人間の選好に基づいてモデルを最適化する手法です。SFT後のモデルに対してDPOを実施することで、より自然で人間の期待に沿った応答を生成するモデルを作成することができます。
DPOの学習データは通常、以下の形式で準備します。
- prompt:ユーザーからの質問や指示
- chosen:好ましい(LLMに生成させたい)応答
- rejected:好ましくない(LLMに生成させたくない)応答
以下は、DPOデータの例です。
SFTとの主な違いは、好ましくない応答の生成を抑制するように、負のフィードバックを行う点です。例えば、表中の例1のように、詳細で丁寧な応答をchosenとし、簡潔な応答をrejectedとしたデータでDPOを行うと、より詳細な応答をするモデルになると考えられます。また、例2のように、共感的で具体的なアドバイスを含む応答をchosen、感情を考慮しない不親切な応答をrejectedとすることで、より親身に寄り添う会話ができるモデルを作ることができます。
また、DPOデータの生成に学習対象のモデルを使用することで、モデル自身の弱点や傾向が直接的に改善されることが期待されます。例えば、以下のような方法があります。
- rejetedの生成に学習対象のモデルを使用し、より高性能なモデルでchosenを生成する
- 学習対象のモデルで複数の応答を生成し、モデルや人手で順位付けする(順位が高い応答をchosen、低い応答をrejectedとする)
本プロジェクトでは、最終的に、開発中のTanukiモデルをDPOデータの生成に使用することで、上記2つの学習方法を取り入れました。
学習コード
本プロジェクトでは、効率的に事後学習を試行するために、sftlabとpolabという独自のツールを開発し、使用しました。
sftlab・polabの特徴は以下のとおりです。
- sftlabは、llm-jp-sftをベースとしたプロジェクトで、TRLのSFTTrainerを使用しています。
- polabは、複数のPreference Optimization系の学習手法(DPO, ROPO etc.)に対応しています。基本的な使用方法はsftlabと同様です。
- どちらも学習コードを変更せずに、configファイルからモデル、データ、パラメータを設定できます。
- 主要なデータセット形式に対応する前処理関数が実装されており、huggingface上の様々なデータ形式に柔軟に対応することができます。
- WandBの保存先をディレクトリ構造と紐付けて自動で作成し、ログ管理を容易にしています。
- 実験に使ったconfigやコードをWandBに保存し、実験の再現性を高めています。
sftlab・polabを使用することで、効率的な実験サイクルとチーム全体での知見の蓄積・共有が可能になりました。これが、短期間で大幅な性能向上を実現できた要因の一つと言えるでしょう。
評価方法
Japanese MT-Bench (JMT-Bench)
Tanukiの性能を評価するために、JMT-Benchのスコアを指標として使用しました。JMT-Benchは、日本語におけるLLMの応答能力や作文性能を評価するためのベンチマークです。
JMT-Benchでは、以下の8つのカテゴリにわたる問題を用いて、モデルの性能を総合的に評価します。
このような問題をモデルに回答させ、その回答を高性能なLLM(gpt-4など)で評価する手法(LLM-as-a-Judge)が一般的です。本プロジェクトでは、Nejumi LLMリーダーボード Neoを改変して使用し、gpt-4oによる10段階評価を行いました。Nejumi LLMリーダーボード Neoではデフォルトでgpt-4を使用しますが、gpt-4oの方がapi料金が安価なため、評価コストを下げる(= 試行回数を増やす)ためにgpt-4oを採用しました。
また、JMT-Benchの特徴として、全問題がマルチターン形式となっています。2ターン目の応答生成時には、1ターン目の質問・応答と2ターン目の質問がモデルに入力されます。
定性評価
JMT-Benchスコアを使用すると定量的にスコアの傾向を見ることができますが、gpt-4レベルのモデルに評価させても、時折不適切な判定が起こることがあります。人間の目でもJMT-Benchの出力を確認し、適切に応答できているか確認しました。
学習対象のモデル
8Bモデルとそれをベースにして構築した8x8Bモデルの2種類を学習しました。これらのモデルの関係は以下の通りです。
- 8Bモデル: Phase1で学習させた8Bモデルをベースに継続事前学習を実施したもの
- 8x8Bモデル: ある程度学習させた8Bモデルを8つに複製してMixture of Experts (MoE) を構築し、継続事前学習を実施したもの
両モデルにはそれぞれ長所があり、どちらを重視するかは当初から決定していませんでした。8Bモデルは学習速度が速く、試行錯誤を重ねやすいという利点があります。一方、8x8Bモデルはパラメータ数が多く、スケーリング則に従うとより高い性能が期待できます。プロジェクトの最優先事項は、モデルの種類や学習手法を問わず、とにかく高性能なLLMを作ることでした。そのため、両方のモデルの長所を活かしながら並行して検討を進め、ポテンシャルを見極めながら方向性を決めていきました。
8Bよりも8x8Bの方が10倍以上学習コスト(GPU数×時間)がかかるため、基本的には8Bで試行錯誤を繰り返して、良い条件を見出してから8x8Bの学習を行う方向で進めました。8Bと8x8Bは事前学習で学んだデータの重複が多いため、事後学習の条件についても似たような傾向を示すのではないかと考えられるためです。
事後学習の最適化サイクル
JMT-Benchにおいて、国内トップスコアを出すを目標に、各々の開発メンバーが自律的に取り組みました。チーム開発でありながら、コンペのように個人の創意工夫を重視する要素を取り入れたアプローチです。
基本的に以下のサイクルで進めていきました。
- 事後学習に使用可能なモデルのチェックポイントやデータセットをリスト化
- 各メンバーが自由な発想で学習条件(データ・学習パラメータ・学習手法等)を考えて学習・評価を実施
- 学習条件・結果をスプレッドシートにまとめ、チーム内で共有
- チーム内で結果を分析し、苦手分野や足りない能力を調査
- 苦手分野を踏まえたデータ作成 and/or チェックポイントの更新
- 1~5を時間とリソースの許す限り繰り返す
図にすると以下のような感じです。
それぞれのフェーズを個別に検討するのではなく、評価結果のフィードバックを受けて臨機応変にデータ作成・事前学習・事後学習の方向性を変更していきました。
また、事前学習でより多くの知識を獲得するために、事前学習はプロジェクト期限(8月13日)直前の8月上旬まで継続する方針でした。そのため、事前学習のチェックポイントを活用し、事前学習と並行して、事後学習の試行を進めました。SFT, DPOともに基本的にはフルパラメータでの学習を行いました。(後で説明するように、プロジェクト期間終盤ではLoRAでの8x8B学習が活躍しました)
データ作成の方法については、本記事で詳しくは取り上げませんが、主な取り組みとして、高性能なLLM(calm3-22b-chatやNemotron-4-340B-Instruct等)を使用したデータ生成(データ合成)が行われました。
Tanukiの成長過程を時系列で振り返る
さて、前書きが長くなりましたが、振り返っていきたいと思います。
全体のスコア遷移
まずは全体のスコア遷移を見てみます。以下の図は、8BモデルのJMT-Benchスコア(平均値)と試行回数の関係をプロットしたものです。ベースモデルのチェックポイントごとに色分けしています。
Phase1のTanuki-8Bは3.84点なので実質この値がスタートラインになります。Phase1では8Bモデルに対して100万件を超える大量のデータでSFTを行いました。Phase2での目標は、calm3-22b-chatの7.3点を超えて日本一になることです。当初は、(個人的には)かなり無茶な目標だと思っていましたが、最終的には8x8Bでcalm3に匹敵するスコアを達成しました。
#1: 7/11 マイナスからのスタート(8B-SFT 2.76)
記念すべき?初トライでは、8B-base#1(Phase1のベースモデルを約1TBトークン追加学習したもの)に対して約5万件のデータでSFTを実施しました。
avg | coding | extraction | humanities | math | reasoning | roleplay | stem | writing | |
---|---|---|---|---|---|---|---|---|---|
#1 | 2.76 | 1.70 | 1.85 | 3.60 | 1.95 | 2.60 | 4.00 | 2.90 | 3.45 |
データセットはこの時点までにプロジェクト内で作成された比較的高品質な日本語とコードのデータを使用しました。しかし、出力結果を見てみると繰り返しが多く発生しており、チューニングが不十分だったようです。
Phase1のTanku-8B-Instructは3.84点だったので、Phase1のTanukiよりも低いスコアからのスタートとなりました。
#2: 7/13 Phase1のTanukiに追いつく(8B-SFT 3.87)
~数十万件の学習では上手く行かなかったので、Phase1に倣って150万件のデータを2epoch学習させました。
#2ではPhase1のデータ+Phase2で作成済みのデータ(Nemotron-4で生成したものなど)を使用しました。Phase1のデータはMixtral-8x22B-Instructで生成されたデータが中心で、日本語が不自然な低品質なデータも多く含まれています。
結果は以下のようになりました。
avg | coding | extraction | humanities | math | reasoning | roleplay | stem | writing | |
---|---|---|---|---|---|---|---|---|---|
#2 | 3.87 | 3.10 | 3.85 | 4.40 | 1.60 | 4.05 | 5.05 | 4.40 | 4.50 |
Phase1 | 3.84 | 2.90 | 3.70 | 4.00 | 2.65 | 3.35 | 5.20 | 3.95 | 5.00 |
大量のSFTデータで学習することで、回答の形式が改善され、#1と比べると多くのカテゴリでスコアが向上し、Phase1のTanuki(DPOあり)と同等の性能になりました。#2と同じデータセット構成で、50万件×1epoch学習させたモデルと、#2のモデル(150万件×2epoch)の出力例を比較してみます。
50万件の学習では、途中から同じ文の繰り返しが発生してしまっていますが、#2では回答の形式が改善され、自然な文章を生成できています。
当初は、個人的には50万件あればデータ量は十分だと思っていたのですが、ほとんど対話できないベースモデルが上手く対話できるようになるためには、より多くのデータが必要だったようです。この段階では、内容の質を向上させることよりも、まずはQAの基本的な形式を正しく学習させることが優先課題という認識になりました。
#3: 7/16 codingメインの高品質データ(8B-SFT 4.30)
MagpieとEvol-Instructいう手法で大量生成されたcodingデータセット(Aratako/Synthetic-JP-EN-Coding-Dataset-567k)を使用しました。このときのデータ全74万件中56万件をcodingデータが占めています。Phase1の合成データは使用せずに、高品質(と思われる)データを中心に学習しました。
avg | coding | extraction | humanities | math | reasoning | roleplay | stem | writing | |
---|---|---|---|---|---|---|---|---|---|
#3 | 4.30 | 3.90 | 4.50 | 5.55 | 2.95 | 2.40 | 5.30 | 5.20 | 4.60 |
codingデータが中心であったにもかかわらず、humanities、roleplay、stemなどの他カテゴリでもスコアが向上しました。codingデータの学習が他の分野にも良い影響を与えた可能性がありそうです。
また、このとき使用したベースモデル(8B-base#2)では、事前学習に通常使用されるWebの雑多な文章よりも高品質な合成データが学習されていました。ベースモデルの能力が上がったこともあり、#2の150万件よりも少ないデータで指示追従能力を獲得できました。この時期の実験結果から、高性能なベースモデルに対しては、低品質なデータを大量に学習させるよりも、より少ない高品質なデータを学習させた方が性能が上がりやすそうということも分かってきました。
#4: 7/18 5点越え(8B-DPO 5.20)
#3のSFTモデルに対してDPOを適用した結果、初めて5点を超えました。
avg | coding | extraction | humanities | math | reasoning | roleplay | stem | writing | |
---|---|---|---|---|---|---|---|---|---|
#4 | 5.20 | 4.10 | 4.55 | 8.00 | 2.80 | 3.10 | 6.00 | 6.05 | 7.00 |
DPO前後で0.9点と大幅に上昇しています。特に、humanities(5.55→8.00)とwriting(4.60→7.00)のスコアの伸びが大きく、これら作文系のカテゴリでDPOが効果的であったことが分かります。また、このときのDPOデータにはcodingはほとんど含まれていないのですが、SFTで獲得したcoding能力を維持できています。
一方、mathとreasoningはなかなか上がらず、これらの論理的思考力を要する分野の性能向上が課題となりました。
7/21頃 マルチターンへの対応に苦戦
DPOで5点を超えたあたり(SFTでは4点台後半)でスコアが停滞しました。この時期にスコアが伸び悩んだ要因として、マルチターン対話に課題があったことが挙げられます。
当時SFTの学習に使っていたマルチターンのデータの一例を以下に示します。
1ターン目のQAと2ターン目のQAはそれぞれ音楽に関連する会話をしていますが、それぞれターンが独立しており、1ターン目のQAを見なくてもq2に回答できてしまいます。このデータセットを学習に使用したところ、1ターン目で提供された情報や設定が、2ターン目で完全に無視されることがありました。
以下は、このデータセットで学習させたモデルのJMT-Benchでの回答です。
a2では、q2で言及された文脈(「上記の質問」=q1)を無視し、関係のない話題を生成してしまいました。この問題の発覚後は、マルチターンの各ターンが連動したデータを作成し、マルチターン対話への対応能力の改善を図りました。
最終的にこの問題は改善されたものの、プロジェクト期間中にTanukiのマルチターンの性能を十分に向上させることはできませんでした。現在のTanukiはマルチターンの会話が苦手なので、それらの性能向上は今後の課題です。
#5: 7/27 Nemotronの数学・推論データで6点(8B-DPO 6.0)
プロジェクト内では、6/14(土)にNVIDIAよりリリースされたNemotron-4-340B-Instructの推論環境を整備し、合成データ作成に活用する試みが進められていました。Nemotron-4は論理的な問題に対する能力が高いため、主に数学・推論データの生成に活用し、事後学習でそれらのデータの検証を行っていきました。
以下は、それらのデータでDPOを実施した結果(DPO前後の比較)です。
avg | coding | extraction | humanities | math | reasoning | roleplay | stem | writing | |
---|---|---|---|---|---|---|---|---|---|
#5(DPO後) | 6.0 | 4.15 | 5 | 8.65 | 3.6 | 4.4 | 7.35 | 7.15 | 7.7 |
DPO前 | 5.66 | 3.5 | 4.6 | 8.2 | 3.75 | 4 | 7.4 | 7.3 | 6.55 |
性能向上を期待していたmathとreasoningのスコアはほとんど上がりませんでした。データ量・作成方法ともに検討が必要そうです。これらの問題は正解が明確に決まっているため、例えば、途中の計算式が合っていても最終回答が誤っていると低スコア(1~3点)となる特徴があります。論理的な推論が苦手なLLMにとってかなり難しいタスクであり、最後まで課題として残りました。(問題によっては人間にとっても難しいかもしれません)
また、実際の出力を確認すると、Nemotron-4のデータ投入によって、数字やマークダウン形式を使用した箇条書き項目を生成するような形式の変化が見られました。gptはこのような構造化された回答に対して(内容が同じであっても)高評価を付ける傾向があり、回答を形式化することでスコアが向上した可能性があります。
一方で、Nemotron-4のデータを学習させることで、文章のスタイルが形式以外でも堅苦しくなったという変化もありました。生成された日本語の文章を見ると、calm3の方がより自然で親しみやすい表現を生成する傾向がありました。
これらの結果&その後の検証から、論理系のタスクはNemotron-4 or WizardLM-2-8x22B、作文系のデータはcalm3で生成、というようにタスクの特性に応じてデータ生成に使用するモデルを使い分けると良さそうということが分かってきました。
#6: 8/3 8x8B のポテンシャルが見え始める(8x8B-SFT 6.17)
8x8Bの事前学習がある程度進んだため、事後学習も動き始めました。
以下は、8x8Bと8Bで同じデータを学習した結果です。
avg | coding | extraction | humanities | math | reasoning | roleplay | stem | writing | |
---|---|---|---|---|---|---|---|---|---|
#6 (8x8B) | 6.17 | 5.40 | 4.70 | 8.80 | 4.45 | 4.80 | 7.30 | 7.35 | 6.55 |
8B | 5.86 | 3.65 | 4.95 | 8.50 | 4.45 | 4.05 | 7.20 | 7.30 | 6.75 |
全体的には8x8Bの方がスコアが高く、特にcoding能力には大きな差がありそうです。
ラストスパートに向けて、8x8Bへの計算リソース配分を増やしていくことになりました。
#7: 8/4 学習率を下げてスコアup(8B-SFT 6.01, 8x8B-SFT 6.39)
8BモデルのSFTにおいて、学習率を5e-6から1e-6に下げたところ、JMT-Benchスコアが5.85から6.01に上昇しました。
avg | coding | extraction | humanities | math | reasoning | roleplay | stem | writing | |
---|---|---|---|---|---|---|---|---|---|
#7-1 (1e-6) | 6.01 | 3.25 | 5.20 | 8.50 | 3.85 | 5.00 | 7.05 | 7.55 | 7.65 |
5e-6 | 5.85 | 3.70 | 5.05 | 8.65 | 4.10 | 4.00 | 7.10 | 7.25 | 6.95 |
8x8Bでも学習率を下げてスコアが上がりました。
avg | coding | extraction | humanities | math | reasoning | roleplay | stem | writing | |
---|---|---|---|---|---|---|---|---|---|
#7-2 (5e-7) | 6.39 | 5.45 | 5.30 | 9.20 | 4.65 | 4.35 | 7.35 | 7.65 | 7.20 |
5e-6 | 5.99 | 5.00 | 4.30 | 9.15 | 3.55 | 4.20 | 7.50 | 7.60 | 6.65 |
学習率が高すぎると、ベースモデルが既に獲得している有用な知識や能力を損なう破滅的忘却を引き起こすことが懸念されます。ベースモデルが高性能になってきたので、モデルの知識を維持しつつ、SFTでは出力形式の微調整程度に留めるのが良さそうということが分かってきました。
#8: 8/6 事前学習が偉大すぎる(8B-SFT 6.36)
Tanukiの苦手分野が分かってきたので、codingや論理系のデータを中心に追加事前学習を行いました。以下は、追加事前学習前後のベースモデルについて、同じ条件でSFTした比較結果です。
avg | coding | extraction | humanities | math | reasoning | roleplay | stem | writing | |
---|---|---|---|---|---|---|---|---|---|
#8 | 6.36 | 4.80 | 5.20 | 8.55 | 4.75 | 4.35 | 7.90 | 7.75 | 7.60 |
追加学習前 | 5.86 | 3.45 | 5.25 | 8.60 | 4.60 | 3.85 | 7.20 | 7.45 | 6.50 |
追加事前学習によって大量のデータを学習させることで、苦戦していたreasoningとcodingの性能が改善してきました。
#9: 8/9 8B-SFT ベストスコア(8B-SFT 6.69)
8Bの事前学習が完了後、SFTデータの組み合わせなどを試行錯誤し、最終的には6.69が8B-SFTのベストスコアになりました。最終的には公式のNejumi LLMリーダーボード Neoでの性能も確認する必要があったため、gpt-4でも評価しました。gpt-4oとgpt-4での評価結果はそれぞれ以下のとおりです。
avg | coding | extraction | humanities | math | reasoning | roleplay | stem | writing | |
---|---|---|---|---|---|---|---|---|---|
#9-4o | 6.69 | 4.75 | 5.90 | 8.65 | 5.75 | 5.90 | 7.25 | 8.15 | 7.20 |
#9-4 | 6.67 | 4.45 | 4.80 | 9.30 | 4.75 | 6.85 | 7.60 | 7.65 | 7.95 |
SFTに使用されたデータセットは以下のとおりです。calm3で生成されたデータを日本語の教師データとして使用しました。
- kanhatakeyama/ramdom-to-fixed-multiturn-Calm3: 10078 samples
- kanhatakeyama/AutoMultiTurnByCalm3-22B: 55000 samples (1ターン目のみ使用)
- Synthetic-JP-EN-Coding-Dataset-687k: 50000 samples(ライセンス確認中のため、本記事執筆時点では非公開)
- Aratako/Synthetic-JP-EN-Coding-Dataset-567kと同様の手法(Magpie+Evol-Instruct)で作成されたcodingデータ
- Synthetic-Calm3-MT-Coding-137k: 25000 samples(ライセンス等確認中のため、本記事執筆時点では非公開)
- Synthetic-JP-EN-Coding-Dataset-687kの日本語データをマルチターン化したもの
学習パラメータ等は以下のとおりです。
data 件数 | lora/full | epochs | learning rate |
---|---|---|---|
140078 | full | 1 | 5e-6 |
#10: 8/10 8x8B-SFT ベストスコア(8x8B-SFT 6.69)
8x8Bでも8B-SFTベスト(#9)と同じデータを学習させ、ベストスコアを達成しました。
avg | coding | extraction | humanities | math | reasoning | roleplay | stem | writing | |
---|---|---|---|---|---|---|---|---|---|
#10-4o | 6.69 | 5.00 | 5.80 | 9.05 | 5.70 | 4.50 | 8.05 | 7.65 | 7.80 |
#10-4 | 7.19 | 6.15 | 6.10 | 8.50 | 5.05 | 5.55 | 9.00 | 8.65 | 8.50 |
gpt-4o評価では8B-SFTと同スコアでしたが、gpt-4評価の結果は8B→6.67, 8x8B→7.19となっており、性能差がありそうです。8Bはgpt-4o評価に過剰適合してしまったのかもしれません。
学習パラメータ等は以下のとおりです。
data 件数 | lora/full | epochs | learning rate |
---|---|---|---|
140078 | full | 1 | 5e-7 |
#11: 8/11 Tanuki-8B-dpo-v1.0の誕生(8B-DPO 6.63)
プロジェクト終了まであと2日となったところで、#9のモデルに対してDPOを行った8Bモデルがgpt-4評価で最高スコアを達成しました。このモデルがTanuki-8B-dpo-v1.0としてリリースされることになります。
avg | coding | extraction | humanities | math | reasoning | roleplay | stem | writing | |
---|---|---|---|---|---|---|---|---|---|
#12-4o | 6.63 | 5.2 | 5.4 | 9.05 | 4.65 | 4.9 | 7.55 | 8.15 | 8.15 |
#12-4 | 7.24 | 5.4 | 6.65 | 9.1 | 3.9 | 5.75 | 8.75 | 9.35 | 9.05 |
データセットには、以下が使用されました。(ライセンス等確認中のため、本記事執筆時点では非公開)
- aya-ja-evol-instruct-calm3-dpo: 30295 samples
- aya_datasetの日本語データに対してEvol-instructを適用し、質問を複雑化
- chosenをcalm3, rejectedを過去のTanuki-8Bで生成
学習パラメータ等は以下のとおりです。
data 件数 | lora/full | epochs | learning rate | beta |
---|---|---|---|---|
30295 | full | 1 | 5e-7 | 0.01 |
8/11 LoRAで8x8B高速サイクル回転作戦に切り替え
ここで、以下の理由から8Bの学習を終了し、残り2日間は8x8Bの事後学習に注力することになりました。
- 8Bのスコアが伸び悩んできた
- 8x8Bの試行錯誤が足りていない(=性能向上の余地がある?)
- 事前学習で同じデータを学習させた際、8x8Bの方が8Bよりもlossが0.1以上小さかったため、8x8Bの方が性能が高いと期待される
しかし、これまで実施していた8x8Bモデルのフルパラ事後学習には8ノード(= H100×64枚)で5時間以上かかり、試行錯誤の回数を増やすことが困難でした。
そこで、より効率的に実験を進めるため、LoRAを用いた学習戦略に切り替えることにしました。
LoRAを採用することで、フルパラメータ学習の10分の1以下の計算コストで学習が可能となり、より多くの実験を行えるようになります。また、LoRAとフルパラで学習させたときの性能差が小さいことも確認できていたため、LoRAの方が学習効率が良いと考えられました。
残りの期間は#11の8x8B-SFTモデルに対してLoRAでのDPOをメインで実施し、8x8Bモデルの性能向上に取り組みました。
#12: 8/12 最後はハイパラ職人によってTanuki-8x8B-dpo-v1.0が誕生(8x8B-DPO 7.04)
LoRAに切り替えたことで試行錯誤のサイクルが加速する中(1日で50run以上)、開発期間終了まで30時間を切ったところでTanuki-8x8B-dpo-v1.0が誕生しました。ハイパーパラメータの微調整も含めて丁寧にチューニングすることによって性能を引き出し、高スコアを達成しました。
Tanuki-8x8B-dpo-v1.0とcalm3を同じリーダーボードで評価した結果は以下のとおりです。
avg | coding | extraction | humanities | math | reasoning | roleplay | stem | writing | |
---|---|---|---|---|---|---|---|---|---|
#13-4o | 7.04 | 5.00 | 6.20 | 9.00 | 6.90 | 5.75 | 8.25 | 7.55 | 7.70 |
calm3-4o | 7.31 | 5.40 | 7.05 | 9.25 | 6.60 | 5.85 | 8.55 | 8.05 | 7.75 |
#13-4 | 7.91 | 6.75 | 6.90 | 9.30 | 5.75 | 7.35 | 8.95 | 9.40 | 8.85 |
calm3-4 | 7.81 | 5.20 | 8.15 | 9.40 | 6.10 | 6.80 | 8.55 | 9.35 | 8.90 |
gpt-4oとgpt-4評価でぶれがあるものの、avgスコアを平均するとTanukiは7.48、calm3は7.56なので、(JMT-Benchに関しては)日本トップレベルに到達したと言って良さそうです。懸命に追い続けてきたcalm3にここまで近づけたとは、非常に感慨深いものがあります。
データセットには以下が使用されました。(ライセンス等確認中のため、本記事執筆時点では非公開)
- aya-ja-evol-instruct-calm3-dpo: 10000 samples
- aya_datasetの日本語データに対してEvol-instructを適用し、質問を複雑化したもの(8Bで使用されたものと同じ)
- synth-dpo-basic-reasoning-nemotron-4: 10000 samples
- Tanuki-8BのSFTモデルに推論問題の回答を複数生成させ、Nemotron-4で好ましい回答を判定させたもの
- synth-dpo-basic-math-nemotron-4: 10000 samples
- Tanuki-8BのSFTモデルに計算問題の回答を複数生成させ、Nemotron-4で好ましい回答を判定させたもの
学習パラメータ等は以下のとおりです
data 件数 | lora/full | epochs | learning rate | beta |
---|---|---|---|---|
30078 | lora | 2 | 2e-6 | 0.1 |
ちなみに、本プロジェクトではJasterなどの選択問題系のタスクは考慮していないため、それらを含めた総合的な評価ではcalm3の性能には及びません。
しかしながら、ChatbotArenaようなシステムでの対話試験&人手評価の結果ではcalm3を上回り、gpt-4o-miniやGemini-1.5-flashといった海外モデルに匹敵する性能を示しました。
詳細は、以下の記事をご参照ください。
このように、JMT-Benchを主要な指標としてモデルを最適化し続けたことで、Tanukiを対話能力に優れたモデルへと成長させることができました。
得られた知見
様々な検証の結果、分かったことをまとめていきます。
高性能ベースモデルにはデータ量より質
Phase1~Phase2初期段階では、データの質よりも量を重視する戦略を採用していましたが、ベースモデルの性能向上に伴い、この手法の限界が明らかになりました。
以下は、モデルのチェックポイントを変えて、その他はLoRAとFullでそれぞれ同じ条件で学習させた結果です。データはPhase1で使用された中~低品質のデータ10万件を使用しました。
学習トークン数が増え、ベースモデルの性能が上がると、LoRAとフルパラのスコアが逆転する現象が見られました。フルパラSFTでは、過度な調整によって既存の重要な知識がうまく取り出せなくなった可能性があります。一方、LoRAではフルパラメータ学習と比較してパラメータの変化が限定的であるため、この問題が起こりづらかったものと考えられます。
LIMA(Less Is More for Alignment)論文でも指摘されているように、データの質は量に勝るという知見があります。しかし、「高品質」なデータの定義も難しいポイントです。
一般的に、人手で丁寧に作成・校正したデータは「高品質」と考えられがちです。しかし、私たちの実験では以下のように異なる結果が得られました。
- ChatBotArena的なシステムで人間が高評価した回答をデータに加えてSFTしたところ、追加前と比べてスコアが低下した(5.83→5.66)
- ChatBotArenaのデータを人手で確認・校正したところ、さらにスコアが低下した(5.66→5.57)
- 人手で一から作成されたichikaraデータセットを使用した際も良い結果が出なかった
さらに、LLMを使ってデータを品質フィルタリングする手法(Ask-LLMなど)も試しましたが、こちらも簡単ではありませんでした。フィルタリングによってデータの多様性が失われ、結果としてモデルの汎化性能が低下した可能性が考えられます。
本プロジェクトでは、calm3やNemotron-4といった高性能なモデルが生成したデータを中心に事後学習に使用することで、性能を上げることができました。
Out-of-Distribution(OOD)に要注意
実験を進める中で、事前学習と事後学習のデータ分布の一貫性が重要であることが分かってきました。つまり、事前学習で使用したデータ分布に対して、事後学習で使用するデータがOut-of-Distribution(OOD)にならないことが重要と考えられます。
前項と重複する部分がありますが、OODを引き起こしてしまったと考えられる失敗例として以下があります。
- 人手で作成・校正したデータ(ichikara, chatbotarena)でスコア低下した
- 人手でキュレーション(文頭の枕詞「もちろん!」を削除するなど)したデータでもスコア低下した
事前学習の終盤では、複数のLLMの生成データを主に学習させており、Tanukiもそれらの分布に近づいていたと考えられます。そのため、人手で作成したデータを使用したり、LLMの生成文章に人手を加えると、それによって分布のズレが発生し、学習が上手く行かず、性能が低下してしまったのかもしれません。
OODの問題を回避するために、最終的には以下の工夫を行いました。
- 事前学習の段階で、ありとあらゆるSFTデータを学習させる(つまり、SFTデータは実質学習済みの状態にする)
- DPOのchosenデータの生成には、事前学習の合成データ生成に使用したモデル(calm3, Nemotron-4等)を活用
これによって事後学習データが事前学習データの分布から大きく逸脱することを防ぎ、OODの問題を最小限に抑えることができました。
また、高品質なデータの定義はベースモデルによって異なりそうです。人手で作成した高品質なデータを使用して性能を上げるには、事前学習の段階でそれに近い分布を学習しておく必要があると考えられます。
OODを回避するには、事前学習から事後学習まで一貫したデータ戦略が必要です。特に、事前学習で使用したデータの分布や特性を十分に理解し、それに整合する形で事後学習のデータを選択または生成することが、安定した性能向上につながると考えられます。
対話性能の向上には長くて丁寧な応答が有効
実験を重ねる中で、長めの応答中心のデータセットを学習に使用することが対話性能の向上に効果的であることが分かりました。
最初の頃はNemotron-4で回答生成したdatabricks-dolly-15k-jaデータセットをSFTに使用していましたが、このデータセットには以下のように回答が短めのデータが多く含まれていました。
指示(入力) | 応答 |
---|---|
アップル社の製品ラインナップを教えてください。 | アップル社は、腕時計であるApple Watch、スマートフォンであるiPhone、ノートパソコンのMacBook、デスクトップパソコンのiMacなど、様々な製品を販売しています。 |
このような短い応答を含んだデータセットで学習したモデルと、dollyを除き、より長く丁寧な応答を中心としたデータセットで学習したモデルの出力を比較してみました。以下は、その比較結果の一例です。
dollyありで学習したモデルは指示に対して端的に回答できているものの、小説の序章にしては短すぎるかもしれません。一方、長い応答をメインで学習させたモデルは、(内容関してつっこみどころはあるかもしれませんが)全体として序章らしい構成と雰囲気を捉えています。
状況や質問の性質に応じて簡潔な回答が適切な場合もあるかもしれませんが、多くの場合、情報が少ないよりは多い方が有利です。人間は提供された情報を自身のニーズに合わせて取捨選択できるためです。
長い応答中心の学習データ構成にすることで、Tanukiは対話能力の高いモデルになりました。
事前学習は偉大
事後学習でスコアが伸び悩んだ際、苦手分野のデータを中心に追加事前学習を行うことで、スコアが大幅に改善されることがありました。
以下は、JMT-Benchのカテゴリごとのスコア遷移を示したグラフです。
特に、MathとCondingでは、ベースモデルの更新に伴い、(特にオレンジ色のbase#5あたりから)スコアが顕著に伸びてていることが分かります。base#5はbase#4にcodingとmathを含んだQAデータを追加事前学習させたモデルです。
base#4と#5で試行されたSFTモデルのmathスコアとcodingスコアの平均値は以下のとおりです。
coding | math | |
---|---|---|
base#4 | 3.62 | 4.06 |
base#5 | 4.69 | 5.07 |
SFTの学習条件が変わっているものも含まれているため単純な比較はできませんが、追加事前学習によって性能向上が加速し、coding とmathそれぞれ約1点ずつスコアが上がりました。
事前学習とSFTはどちらもnext token predictionで次のトークンを予測するタスクを行っているため、QAデータの学習において本質的な違いはないと考えられます。今回は使用していたライブラリの特性上、事前学習の方がSFTよりも数十倍以上高速に学習を進められたため、事前学習でQAデータを大量に学習させました。これによって知識の取り出し方を効率的に学習できたものと考えられます。
その他の知見
以下は、その他の細々とした知見共有です。(詳細な分析ができていないものもあります)
- ハイパラ調整は結構大事
- 学習率の変更(5e-6 → 5e-7)で0.4点up(8x8-DPO-LoRA)
- DPOのbetaの変更(0.5 → 0.1)で0.6点up(8x8-DPO-LoRA)
- ベースモデルの学習段階によってDPOの最適なバッチサイズが変化
- 序盤(6点くらいまで)は大きめ(512〜1024)で学習が安定化
- 終盤は小さめ(256)でも成功
- 言語・タスクに応じたデータ生成手法の選定
- コード(日・英)ではMagpieによる合成データが有効だった
- 日本語文章の生成はMagpieで上手く行かない場合があった(calm3, Nemotron-4を使用)
- calm3, Nemotron-4を使用したPersona手法(職業ロールプレイ等)では成功例あり
- rejectedデータの質の向上
- TanukiでDPOのrejectedデータを生成した際に、繰り返しなどが発生し、rejectedデータが低品質すぎてしまうことがあった(特に、数学・推論系の高難易度タスクで発生)
- そのようなデータを使用すると、モデルが単純に繰り返しパターンを避けるだけの学習をしてしまう可能性がある
- 人手 or ルールベースで繰り返し文のフィルタリングを実施することでrejectedデータの質を改善
おわりに
Tanuki-8B-v1.0とTanuki-8x8B-v1.0が完成するまでに、絶え間ない試行錯誤がありました。改めてSlackのスレッドに積み重なった1500件を超える投稿を読み返しながら振り返ると、日々の議論や試行錯誤の過程が鮮明に蘇りました。プロジェクト中、実際に試してみたものの上手くいかなかった学習方法や、作成したけれど最終的には採用されなかったデータセットが山ほどあります。しかし、これらは決して無駄ではなく、より適切な方向性を探るために必要な取り組みでした。理論的には上手く行くと言われていることを実際に試してみると上手くいかなかったり、何が良いのか分からないことが多々あり、試行錯誤の重要性を実感しました。
本記事で共有した知見が、LLM開発に興味を持つ方々や、類似のプロジェクトに取り組む方々にとって有益な情報となれば幸いです。
謝辞
本記事を執筆するにあたって、Atsushi SaitoさんよりOODに関する議論やDPOについての知見を共有いただきました。
この成果は、NEDO(国立研究開発法人新エネルギー・産業技術総合開発機構)の助成事業
「ポスト5G情報通信システム基盤強化研究開発事業」(JPNP20017)の結果得られたものです。
東京大学 松尾・岩澤研究室が運営する松尾研LLMコミュニティのLLM開発プロジェクト[GENIAC] の開発記録、情報発信になります。 各種リンクはこちら linktr.ee/matsuolab_community
Discussion