【論文紹介】LLMに何度も解答を生成させて正答率を上げる反復サンプリング ~ Large Language Monkeys
こんにちは。ZENKIGENデータサイエンスチームの栗原です。
OpenAIから新たなモデルOpenAI o1が発表され話題になっています。
モデル構築における技術や手法の詳細は明らかにされていませんが、強化学習やChain-of-Thought、推論時の計算量によるスケーリングといったトピックが関連していると思われています。今回はこれらトピックの中から、推論計算量のスケーリングに関する研究として、『Large Language Monkeys: Scaling Inference Compute with Repeated Sampling』を紹介したいと思います。
概要
- これまで言語モデルは学習の計算量を大きくすることで性能を向上させてきましたが、本研究は推論(解答生成)回数を多くする(反復サンプリング)ことで正答を得られる率を上げられることを示しました。推論コストの低い小さなモデルに対し反復サンプリングを行うことで、GPT-4oやClaude 3.5 Sonnetより費用的に安くより正確な解答を得ることができました。
- 正答率とサンプル数の間にはしばしば対数線形の関係があり、指数冪乗則でモデル化でき、推論時間のスケーリング則が存在することが示唆されました。
- 今後の研究の方向性として、反復サンプリングにより得られた解答候補の中からどれが正しい解答かを検証する高精度な検証器の構築が重要だと述べられています。
1. 導入
大規模言語モデルの性能は、モデルのパラメータ数、事前学習データセットのサイズ、広範な事後学習により飛躍的に向上してきました。
推論性能を上げる方向性としてこれまでは学習時の工夫が様々行われてきましたが、推論時は基本的に入力に対し1度しか推論させない制限をとってきました。
本研究では何度も推論させる(反復サンプリング)ことで推論性能を向上できないか調査しました。
図1: 本論文での反復サンプリングのフロー。Step 1でLLMに入力Promptに対する多くの解答候補を生成させる(temperatureを正の値として)。Step 2でドメイン依存の検証器で解答候補の中から最終解答を決定する。(論文より引用)
反復サンプリングの有効性を示すには以下の2観点があると述べられています。
-
Coverage
サンプル(解答候補)を使用して問題を解決できる割合。 -
Precision
生成されたサンプルの中から最終解答を選択する状況下で正しいサンプルを特定できる割合。
上記観点に着目し、推論コストをかける(何度も解答生成させる)ことで実際どの程度推論性能が上がっていくか見ていきましょう。
2. 反復サンプリングのスケーリング
本研究がターゲットとするタスクは正誤が明確なタスクです(文章の質の評価などではなく数学やプログラミングなど正誤が明らかなもの)。
具体的には以下の5つのタスクです。
-
GSM8K
小学生レベルの算数の問題のデータセット。ランダムに128の問題を取得し評価。 -
MATH
GSM8Kよりも難しい数学の問題のデータセット。ランダムに128の問題を取得し評価。 -
MiniF2F-MATH
"proof checking language"(証明支援系の言語)で形式化された数学の問題のデータセット。Lean4を言語として使用し、MATHデータセットから形式化された130問で評価。 -
CodeContests
競技プログラミングのデータセット。モデルにはPython3で解答するように強制。 -
SWE-bench Lite
実際のGitHubのissueを集めたデータセット。説明文とリポジトリのスナップショットで構成されており、モデルはコードベース内の1ファイルを編集して問題を解決することが求められる。
これらタスクのうち、MiniF2F-MATH、CodeContests、SWE-bench Lite にはそれぞれ自動検証ツールが存在し、正誤を自動判定でき、Coverageは一般的に利用される
GSM8K と MATH では、適切な検証器は存在しないためどのサンプルが正しい最終回答を出力して "pass" するかを確認してCoverageを測ります。
それぞれの問題
2.1 反復サンプリングはタスク横断で有効
それでは、反復サンプリングがCoverageをどのように向上させるか見ていきます。
Llama-3-8B-Instruct と Llama-3-70B-Instruct を使用し、CodeContests、MiniF2F、GSM8K、MATHの各問題に対して10,000個のサンプルを生成させます。
SWE-bench Lite は Llama-3 モデルでは必要なコンテキスト長がモデルの最大コンテキスト長を超えるため、DeepSeek-Coder-V2-Instruct を使用して検証します。
SWE-bench の問題をLLMに解かせるためには、LLMにコード編集をさせるためのツールを利用するのが一般的であり、本研究ではオープンソースの Moatless Tools を使用し、問題を解くためのLLMとツールのやり取りは250回に制限して行ったとのことです。
以下に結果を示します。
図2: 5タスクにおける、解答生成サンプル数を増やしていったときのCoverageの変化と、GPT-4oとSOTAモデルの1回の解答生成精度の比較。(論文より引用)
図2左はSWE-bench Liteの結果で、青実線が今回の反復サンプリングによるサンプル数を増やしていくに従ってCoverageがどのように変化していくかの結果、黒点線が現状の1回の解答生成でのSOTAのモデル精度、赤点線が1回の解答生成でのGPT-4oの精度です。
1回の生成では DeepSeek-Coder-V2-Instruct はSOTAモデルやGPT-4oより低いCoverageですが、生成サンプル数を増やしていくとSOTAモデルを上回るCoverageを示しています。
図2右はCodeContests、MiniF2F、GSM8K、MATHそれぞれのタスクにおいて Llama-3系のモデル(青実線: 8B, 緑実線: 70B)の生成サンプル数を増やしていった際のCoverageの変化です。
赤点線は1回の解答生成でのGPT-4oの精度であり、どちらのモデルサイズのLlama-3系モデルも生成サンプル数を増やしていくことでGPT-4oを上回っていることがわかります。
2.2 反復サンプリングはモデルサイズ・モデルファミリー横断で有効
反復サンプリングはLLMのモデルサイズやモデルファミリーに依存せず有効であることを述べています。
以下のモデルファミリーとサイズで検証しています。
- Llama 3: Llama-3-8B, Llama-3-8B-Instruct, Llama-3-70B-Instruct
- Gemma: Gemma-2B, Gemma-7B
- Pythia: Pythia-70M, 160M, 410M, 1B, 1.4B, 2.8B, 6.9B, 12B
結果は以下の通りです。
図3: 様々なモデルファミリー・モデルサイズでの解答生成サンプル数ごとのCoverageの変化。(論文より引用)
基本的に解答生成サンプル数を増やすことでCoverageが上がっていますが、CodeContestsにおいてはPythiaモデルはどのサイズにおいても10,000サンプル生成してもCoverageはゼロのままだったとのことです。
考察として、PythiaはLlamaやGemmaよりもコーディングに特化したデータで学習されていないためではないか、と述べられています。
2.3 反復サンプリングは性能とコストのバランスを取る助けとなる
2.1, 2.2の結果より、弱いモデルでも反復サンプリングを行うことで強いモデルの1回の解答生成能力を上回れることがわかりました。
実用する上でもう一つ気になる点としては、どの程度コストをかけるとどの程度の性能が得られるのかということです。
本研究では、コストの指標として FLOPs と API料金 を利用して検証しています。
FLOPs
Llama-3系においてFLOPsを用いた検証をしています。
Llama-3系は密なTransformerであり、パラメータの大部分が行列の乗算に使用されるので、FLOPs は以下で近似できるとしています。
以下に、MiniF2F、CodeContests、MATH、GSM8Kそれぞれのタスクでの FLOPs の増加に伴う Coverage の変化の結果を示します。
図4: 4タスクにおける、Llama-3系でのFLOPsの増加に伴うCoverageの変化。(論文より引用)
MiniF2F、MATH、GSM8Kの3タスクでは、8Bモデルの方が70Bモデルよりコスト効率が高く、CodeContestsは70Bモデルの方がコスト効率が高い結果です。
著者らは、FLOPsだけでは、システム効率の他の側面を無視した粗雑なコスト指標になる可能性があると注意書きしています。
APIコスト
SWE-bench Liteのタスクに対して、Claude-3.5 Sonnet と GPT-4o の1回の解答生成コストと、DeepSeek-Coder-V2-Instruct の反復サンプリングのコストを比較します。
結果は以下の通りです。
表1: SWE-bench Liteタスクに対する、Claude-3.5 Sonnet と GPT-4o の1回の解答生成APIコストと、DeepSeek-Coder-V2-Instruct の反復サンプリングのAPIコスト比較。(論文より引用)
DeepSeek 1回の解答生成では GPT-4o や Claude-3.5 Sonnet といった強力なモデルに性能は及びませんが、5回サンプル生成することでIssueの解決率は GPT-4o と Claude 3.5 Sonnet を上回り、全体のコストとしては GPT-4o の3分の一以下、Claude 3.5 Sonnet の4分の一以下となっています。
3. 反復サンプリングの特徴
LLMの損失とその学習計算量との関係は、スケーリング則として特徴づけられています。
本研究では、このスケーリング則に着想を得て、Coverageとサンプル予算との関係を特徴づけるモデル化ができるか調査します。
3.1 反復サンプリングのスケーリング則
GPT-4のテクニカルレポートにある、コーディング問題に対するモデルの平均対数通過率と学習計算量の関係が冪乗則でうまくモデル化できる報告を参考に、Coverage
ここで、
そして、Coverage を直接予測するために以下の式変形をします。
以下に実際の Coverage(青線)と 式の描画(赤線)を並べた図を示します。すると、MiniF2F-MATH(右下)はあまりフィットしていないですが、それ以外ではある程度グラフが重なり、スケーリング則がある可能性が示唆されています。
図5: Coverage とサンプル数の関係の実測値と式描画を並べたもの。(論文より引用)
3.2 モデルファミリー内でのCoverage曲線の類似性
図3左を見ると、同じモデルファミリー内ではCoverage曲線が左右にシフトして入るものの似た傾斜のS字曲線を描いていることが見て取れます。
これをさらに調査するために、同じファミリー内でのCoverage曲線を重ねて図示した結果を以下に示します。
図6: 同じモデルファミリー内のモデルのCoverage曲線を重ね合わせたもの。全ての曲線が点
すると、曲線の類似性が高いことが見てとれ、ここから同じモデルファミリー内では、Coverage
4. Precision向上の必要性
ここまではCoverageに焦点を当ててきましたが、検証は常にモデルが生成したサンプル候補の中から正しい解答を特定できる条件下で行っていました。
ここでは、モデルが生成したサンプル集合の中から正しいものを特定できるかの調査を行います。
4.1 一般的な検証器は生成サンプル数に対して必ずしもスケールしない
今回検証している5タスクのうち、GSM8K と MATH には自動的に解答を検証するツールがありません。
ここでは、解答を決定するためのいくつかの判定アプローチを検証します。
-
多数決
最も生成数が多い解答を採用する。 -
報酬モデル+最高スコア
それぞれの解答に報酬モデル(ArmoRM-Llama3-8B-v0.1)でスコア付し、最高スコアのサンプルを解答とする。 -
報酬モデル+多数決
それぞれの解答に報酬モデルでスコア付し、解答生成数で重み付けした結果での最高スコアのサンプルを解答とする。
以下にそれぞれの検証器による正解判定成功率の変化の結果を示します。
図7: それぞれの検証器による生成サンプル数の増加に伴う正解判定の成功率の変化。青線は最適な検証器があった場合のCoverage。(論文より引用)
3つの方法全てで成功率はサンプル数の増加とともに初めは上昇しますが、100サンプル付近で頭打ちになっています。
多数決の場合、サンプル数が増加すると各解答に割り当てられる票の割合が安定し、それにより成功率が頭打ちになります。
GSM8K や MATH の一部の問題では、正しい解答が1%以下でしか生成されていない問題もあり(図8参照)、そのような問題は多数決では正解を判定できません。
図8: GSM8K と MATH の各問題において、モデルが生成した10,000サンプルの中での正解サンプルの割合(x軸が問題番号、y軸がその問題における正解サンプル率。正解サンプル率でソートしている。)。緑色の棒が多数決で正解をピックできた問題。赤色の棒が多数決で正解をピックできなかった問題。(論文より引用)
報酬モデルによる判定は現状あまり良くなく(図7参照)、検証器の構築が難しいことがわかります。
GSM8K や MATH では、正誤は最終解答のみで判断され途中の思考過程は考慮されませんが、検証器の構築においては思考過程が利用できるかもしれません。
そこで、Llama-3-8B-Instruct が GSM8K の問題に対して正しく解答した105の思考過程を手動で評価し、まぐれで正解したのではなく思考過程も正しい割合を調査しました。
その結果が以下の表2になります。
表2: GSM8K の問題に対するLlama-3-8B-Instructの解答における思考過程(CoT)の妥当性人間評価。問題ごとに3回解答生成を実施し評価。実際の問題ごとの結果はこちら(著者が用意したスプレッドシート)から確認できる。(論文より引用)
90%以上が思考過程も正しいことがわかりました。
ここから、検証器の構築に思考過程を活用できる可能性が示唆されています。
4.2 検証器の注意点
検証器は証明チェッカーのような完全性はなく不完全である可能性が高く、誤って正解と判定したり正解を見逃したりする恐れがあります。
セクション2.1の検証時に遭遇した検証器の不完全さの例が紹介されています。
4.2.1 SWE-bench Lite における不安定なテスト
SWE-bench Lite において解答候補を生成する際、問題の11.3%で同じ解答候補を検証器に通しても検証器の結果が一貫しないことがありました。
また、データセット内の正解の解決策でさえも不正解と判定されることがあったとのことです。
論文内の付録では実際にテストが不安定であった問題のIDが掲載されています(ここでは省略します)。
4.2.2 CodeContests における誤検出
CodeContests では、問題によっては複数の正解が許容されているものの、検証器が特定の解答のみ正解とする場合がありました。
また、CodeContests のテストケースは問題文の変数の値をプログラム的に変更させて自動生成させている場合が多いのですが、その変更が入力仕様を満たしておらず(ex. 正の整数が与えられるべきなのにゼロが与えられているなど)、正しい解答でもテストに通らず不正解とされることがあったとのことです。
5. 議論
反復サンプリングの改善
本研究では、反復サンプリングはシンプルな設計(毎回同じプロンプトとハイパラ)で行いましたが、以下のような改善案が考えられると述べられています。
-
解答の多様性を持たせる
現状、多様性を生み出すメカニズムとしては temperature を正にすることしか採用していませんが、AlphaCodeでやられているような生成毎に異なるメタデータタグを与えるなどの工夫は考えられます。 -
マルチターン
CodeContests や MiniF2F では解答の自動検証ツールがあるため、解答生成後に自動検証ツールを通した結果をモデルにフィードバックし再度解答生成させるマルチターンも有効ではないかと述べられています。一回の最終解答を得るまでのコストは増加しますが、成功可能性も上がると考えられ、このトレードオフにも興味があります。 -
これまでの生成サンプルへのアクセス
現状、各生成は独立で行なっていますが、既存の生成サンプルにアクセスできるとそれ以降の生成に有益な何かがあるかもしれない、と述べられています。
反復サンプリングを推論システムに利用する恩恵
チャットボット(応答の低遅延が重視され、それを達成するためにバッチサイズを小さくしたりハードウェアの利用率を低くする必要がある)とは異なり、反復サンプリングは全体的なスループットとハドーウェアの利用率最大化を重視するのと、系列間での重複を活用する既存のattention最適化の恩恵を受けられるため、単純に多くの並列リクエストをチャットボット向けAPIに送るよりも低コストで実現可能と述べられています。
検証器の改善
検証器の改善は重要な課題であると述べられています。
一つのアプローチとして、非構造化タスクの検証器の開発においては、Leanのような形式言語に変換して証明チェッカーを適用できるようにするコンバーターの設計が考えられるとしています。
お知らせ
少しでも弊社にご興味を持っていただけた方は、お気軽にご連絡頂けますと幸いです。まずはカジュアルにお話を、という形でも、副業を検討したいという形でも歓迎しています。
Discussion