📝

How to Learn From Mistakes on Grade-School Math Problems

2024/08/31に公開

https://www.arxiv.org/abs/2408.16293

https://huggingface.co/papers/2408.16293

要約

  1. はじめに

言語モデルは推論タスクで高い性能を示すが時折間違いを犯す。近年、多段階プロンプトによる自己修正に関する研究が活発化している。本研究では事前学習段階で「エラー修正」データを直接組み込む有用性を検討する。合成数学データセットを用いてエラーを含むデータで学習したモデルが、エラーがないデータで学習したモデルより高い精度を達成できることを示す

  1. 合成数学データ

Ye et al. が導入した iGSM データセットを使用。このデータセットは GSM8K を模倣しつつ算術の難しさと常識知識を除去し、論理的推論に焦点を当てている。各問題はパラメーターの依存グラフを持ち、完全に検証可能な解答を生成できる

  1. 結果 0-1 : 言語モデルはリグレット時にリトライできる

事前学習済みモデルは内部状態に「リグレット」パターンを示す。これを利用して生成プロセスを改善できるか実験を行った。「リグレット時のリトライ」手法は推論精度を向上させたがビーム探索以上の改善は限定的だった

  1. 結果 2-6 : リトライデータでの事前学習

リトライデータ(エラーと即時修正を含む)で事前学習したモデルはエラーがないデータで学習したモデルより高い推論精度を達成した。高いエラー率のデータでも学習中にエラーを生成する傾向は見られなかった。この手法はビーム探索や「リグレット時のリトライ」とは根本的に異なることが示された

  1. 結果 7 : リトライデータでの微調整

エラーのないデータで事前学習したモデルに対しリトライデータを用いて LoRA 微調整を行った。しかしリトライデータで直接事前学習した場合ほどの精度向上は見られなかった。エラー修正スキルはエラーがない推論とは大きく異なり、少量のパラメーター更新では獲得できないことが示唆された

  1. 結果 8 : 偽の間違いでの事前学習

完璧なリトライデータの準備は現実的でない場合がある。そこで正解のみの数学問題に「偽の間違い」を追加する簡易的な方法を探索した。この手法でも精度向上が見られ、実用的なアプローチとなる可能性が示された

  1. 結論

本研究は言語モデルの事前学習データに間違いと修正を含めることの重要性を示した。エラー修正スキルは単純なビーム探索や微調整では獲得できない

Abstract

言語モデルは推論タスクで高い性能を示すが時折間違いを犯す。本研究では事前学習段階で「エラー修正」データを直接組み込む有用性を検討する

合成数学データセットを用いた実験によりエラーを含むデータで学習したモデルがエラーのないデータで学習したモデルより高い精度を達成できることを示した。この手法はビーム探索やエラー検出に基づくリトライとは異なりモデルに直接エラー修正能力を獲得させる

またエラー修正スキルはエラーがない推論とは大きく異なり、少量のパラメーター更新による微調整では獲得できないことも明らかになった

1 Introduction

言語モデルは様々なタスクで人間に近い性能を示すが問題解決スキルにはまだ不完全さがある。最近の研究では言語モデルの推論精度向上のため検証器を用いる手法が注目されている。特に言語モデル自身が自己検証を行う手法が興味深い

これらの研究から以下の疑問が生じる :

  1. モデルが後から修正できるなら、なぜ最初から間違えるのか
  2. なぜ生成中ではなく生成後に修正するのか

本研究では直接エラー修正を行うモデルの訓練可能性を探る。具体的には以下の問題に取り組む :

リトライデータ(エラーと即時修正を含む)での訓練が、エラーがないデータでの訓練と比較して高い推論精度を達成できるか

制御された実験を行うため iGSM データセットを使用する。このデータセットは小学校レベルの数学推論問題を大量に生成でき、エラーと修正を確実に作成できる利点がある

2 Synthetic Math Data From Prior Work

Ye et al. が導入した iGSM データセットは GSM8K を模倣しつつ、算術の難しさと常識知識を除去し、論理的推論に焦点を当てている。各問題は構造グラフとパラメーターの依存グラフを持ち、完全に検証可能な解答を生成できる

データセットの特徴 :

  • パラメーターは依存グラフを形成し、先行するパラメーターが計算されてから次のパラメーターが計算可能
  • 算術の難しさを排除するため計算は二項演算に分解される
  • op は解答に必要な演算数を表す
  • iGSM-med と iGSM-hard の 2 つのファミリーがあり、それぞれ訓練用と OOD テスト用のデータセットを含む
  • reask データも導入され、OOD 評価に使用される

著者らは GPT-4/GPT-4o でも op > 10 の問題は解けないことを示し、このデータセットが非自明な難しさを持つことを証明している

3. Result 0-1 : Language Models Can Retry Upon Regret

  • モデルは間違ったパラメーターを生成した直後に その誤りを認識できる
  • can_next プロービングを用いて この「リグレット」を検出できる
  • 「リグレット時のリトライ」手法を導入し生成プロセスを改善
  • 実験結果 : この手法は推論精度を向上させたがビーム探索以上の改善は限定的
  • エラー検出の精度が非常に高い場合にのみ大幅な精度向上が見られた

結論

  1. 「リグレット時のリトライ」は推論精度を向上させるがその効果は限定的
  2. この手法は推論プロセスを複雑にし、理想的な「一般知能」フレームワークからは外れる
  3. エラー修正は単純なリトライや再生成とは異なるスキルであることが示唆された

3.1 Result 0 : Models Can Be “Regretful” After Making Mistakes

事前学習済みモデルの内部状態に「リグレット」パターンが存在することを示した。can_next プロービングを用いてモデルが次に計算可能なパラメーターを 99% の精度で予測できることを発見。また間違いを犯した直後、モデルは約 60% の確率でその間違いを認識していることが分かった

3.2 Result 1 : Let Models Retry Upon Regret

can_next プロービングを用いてモデルの生成プロセスを制御する実験を行った。各文生成後にエラーを検出した場合、前の文の末尾に戻って再生成を行う。この「リグレット時のリトライ」手法は推論精度を向上させビーム探索を上回る結果を示した。しかし

  • 改善は限定的で、エラー検出器の精度に大きく依存する
  • 推論複雑性が増加し、単一モデルによる自動回帰的デコーディングという理想から離れる

これらの結果はエラー修正が単純な再生成や確率的探索とは異なるスキルであることを示唆している

4 Result 2-6 : Pretrain with Retry Data

リトライデータ(エラーと即時修正を含む)を用いた事前学習の効果を検証した

結果 2-3

  • エラー率が高いほど(一定範囲内で)モデルの性能が向上した
  • エラー部分にラベルマスクを適用する必要はなかった

結果 4

  • リトライデータで学習したモデルはテスト時にほとんどリトライを行わなかった
  • retry_rate が非常に高い場合のみリトライ回数が増加した

結果 5

  • リトライデータで学習したモデルは最短の解答を出力する能力を維持した

結果 6

  • エラー修正スキルはビーム探索や「リグレット時のリトライ」とは根本的に異なることが示された
  • 例えば iGSM-med_op=23_pq の場合「リグレット時のリトライ」では 78% から 80% への精度向上にとどまったがリトライデータでの事前学習では 95% まで向上した

これらの結果はエラー修正が単純な確率的探索とは異なる本質的なスキルであることを示唆している。モデルの推論能力を真に向上させるには訓練データにエラーと修正を含める必要があると考えられる

5 Result 7 : Finetune with Retry Data

エラーが無いデータで事前学習したモデルに対しリトライデータを用いて微調整を行う実験を実施した

主な結果

  • LoRA などのパラメーター効率が良い微調整手法を用いてもリトライデータで直接事前学習した場合ほどの精度向上は見られなかった
  • LoRA のランクを小さくした場合、エラーが無いデータで事前学習したモデルよりも性能が低下した
  • 十分な量のリトライデータを用いた全パラメーター微調整では精度向上が見られたが、これは実質的に事前学習の継続と同等である

これらの結果から以下の結論が導かれた

エラー修正スキルは元のエラーがない推論とは大きく異なるスキルであり、エラーがないデータで事前学習されたモデルに対する LoRA 微調整では獲得できない

この知見はエラー修正能力を獲得するには リトライデータを事前学習段階で組み込む必要があることを示唆している。微調整段階での導入では不十分である可能性が高い

6 Result 8 : Pretrain with Fake Mistakes

完璧なリトライデータの準備が現実的でない場合を想定し、より簡易的な手法を探索した

提案された 2 つのアプローチ

  1. retry_weak : 解答の後半から文をランダムに選び、リトライとして挿入
  2. retry_miss : 問題文から未出のパラメーターをランダムに選び、リトライとして挿入

主な結果

  • retry_weak 形式のデータを用いた事前学習で精度が大幅に向上した
  • retry_miss データは retry_weak に比べ精度向上が小さかった
  • どちらの手法でもモデルは最短解答を生成する能力を維持した

これらの結果は完璧なリトライデータがなくても簡易的な方法で生成した「偽の間違い」データが有効であることを示している。特に retry_weak 手法は実装が容易で実用的なアプローチとなる可能性が高い

この知見は将来の大規模言語モデル開発において補助モデルを用いて数学データに偽の間違いを挿入するなど、事前学習データの拡張方法に示唆を与えている

7 Conclusion

本研究は言語モデルの事前学習データに間違いと即時修正を含めることの有効性を示した

主な知見

  • リトライデータで学習したモデルは同量のエラーがないデータで学習したモデルより高い推論精度を達成
  • リトライデータでの学習は安全で、モデルはエラーを生成しやすくならない
  • エラー修正スキルはビーム探索や「リグレット時のリトライ」とは本質的に異なる
  • このスキルはエラーがない推論とは大きく異なり、LoRA 微調整では獲得できない
  • 完璧なリトライデータがなくても簡易的な「偽の間違い」データで効果が得られる

これらの結果は将来の大規模言語モデル開発においてリトライデータを事前学習に含めることの重要性を示唆している。エラー修正を微調整や生成時の工夫で獲得させようとするアプローチは効果が限定的である可能性が高い

本研究は制御された合成データを用いて実験を行ったが得られた知見は実際の言語モデル開発に応用できる可能性がある。特に補助モデルを用いて数学データに偽の間違いを挿入するなど、事前学習データの拡張方法に示唆を与えている

Discussion