初心者でもわかる、gpt-o1の仕組みを再現する方法の解説
この記事について
近年、openaiのgpt-o1による高い推論能力に驚かされた方も多い一方、商用gptはclosedな開発がなされていることが多く、中身はブラックボックスでした。この現状の良し悪しは置いておいて、研究はopenであるべきであるという考えの方々もおり、なんとかしてo1を再現する研究が行われていたようです。
Alibabaの出したmarco-o1と呼ばれるモデルは、o1のような高い推論能力を獲得したモデルだそうで、頭の良い方が記事を沢山出しているので、詳しく知りたい方はそちらを参照してください。自分は、記事を読んでもイマイチ理解しきれないところがありましたので、論文や公開されていたコードを見たりして、自分なりの理解ができましたので共有いたします。
LLMについて
LLMもあまりよく分からないという方向けで、簡単にLLMについて説明します。Large Language Modelの略で、でかい言語モデルみたいなイメージで十分です。重要なのはLLMに何を入力して、何が出力されるのかということなので、そちらを見ていきましょう。
LLMへの入力
inputには3種類あり、system, user, assistantと分けられます。それぞれ、
- system
- LLMに持たせる役割(例:あなたは有名な医者であり、患者の症状を聞き、診断をします。)
- user
- ユーザー(私たち)が入力する言葉(例:昨日転んでから、足首が痛いです。)
- assistant(過去のLLMの回答の履歴)
- LLMの回答(例:捻挫だと思われます。冷やして、固定して安静にしてください。)
- 出力じゃないの?と思うかもしれませんが、出力でもありますが、ここではこれまでLLMが出してくれた回答の履歴という意味で入力の1種としています。
そして、これらの会話はhistory(履歴)として保存されて、ユーザーが新しい書き込みをしたときに、一緒にLLMに入力されます。ChatGPTを使ったことがある方はわかると思いますが、今までの会話の流れを覚えてくれていますよね。これは、今までの会話の履歴をuser/assistantでラベルをつけて毎回GPTに入力しているというわけです。
LLMの出力
では、LLMは何を出力するのでしょう?もちろん最終結果は回答となるassistantの出力なんですが、そこに至るまでの過程を簡単に見ていきましょう。具体的な例を見ると分かりやすいかもしれないです。
(例)
ユーザー:明日の天気は?
このように聞いた時、「ユーザー:明日の天気は?」がLLMに入力されます。
出力は実は1ワード(1トークン)で
(回答)明日
のようになります。次に、「ユーザー:明日の天気は?アシスタント:明日」がLLMに入力されます。
出力は
(回答)の
となり、次に、「ユーザー:明日の天気は?アシスタント:明日の」がLLMに入力され、最終的には
「ユーザー:明日の天気は?アシスタント:明日の天気は晴れです!」みたいになるんですね。
ワードの選び方
では、「明日」や「の」はどのように選択されているのでしょうか?ここの選択方法は1通りの正しいやり方があるわけではないと思いますが、一例として一番シンプルな方法を紹介します。今、LLMが「明日」まで出力したとします。すると、次に「明日」に続く文字を考えるわけですが、以下の候補が上位に出てきました。
- 「に」
- 「の」
- 「は」
そして、それぞれの文字にはスコアがLLMによって付与されます。
- 「に」:0.1
- 「の」:0.5
- 「は」:0.4
このスコアは合計すると1になるので確率値とも呼ばれますが、要するにLLMがどれが次に来る言葉にふさわしいかを自分でスコア付けしたものと考えてください。この場合ですと「の」が0.5なのでこれが選択されて「明日の」と続いていくことになります。ここはとても重要ですので、しっかり理解してください。
Reasoningタスク
続いて、Reasoningタスクについて見ていきましょう。こちらも初学者向けです。Reasoningタスクとは、ロジックが必要な難しい問題をLLMに解いてもらうことを指します。例えば、
「ユーザー:今私はリンゴを8つ、みかんを5つ持っていて、弟は私のリンゴの数より4つ少ないみかんを持っており、その後私はみかんを二つたべ、弟はみかんを1つ食べました。母は、現在の私のみかんの数を3倍にした数から、弟の現在のみかんの数を2倍した数の差だけリンゴを持っていることがわかりました。では母は何個りんごを持っているでしょうか?」
のような面倒くさい問題を解かせることを考えます。答えは3かな、多分笑。では、LLMにこれを入力したときに、どういった回答が来るのが望ましいでしょうか?「3」と回答してくれるのが良いのでしょうか?普通はそうですよね?でも先ほどの単語の選び方を理解していると、「3」とだけ回答するのが良いことではないとわかると思います。
どういうことかというと、もしLLMの出力の候補が以下のようだった場合
- 「3」:0.3
- 「2」:0.25
- 「1」:0.25
- 「0」:0.2
なんかあんまり自信がなさそうに見えますよね?「3」であっているんだけど、なんか適当に答えてない?みたいな。実はこれは的を射た指摘であり、適当に答えてたまたま当たってるだけなんじゃないかという疑念が晴れないのです。
では、どうなるのが望ましいのでしょうか?実はReasoningというので、「3」という回答を出すまでの理由付けも一緒にLLMが出力してくれると望ましいのです。つまり、解答例としては
「弟は初め、あなたの持っているりんご8個より4つ少ないみかんを持っていたので、8−4=4により、4個のみかんを持っていたことになります。続いてあなたは2つのみかんを食べたので、5−2=3により、3つのみかんになり、弟は4−1=3により、3つのみかんを所持することになります。最終的に、母はあなたが持っているみかんの3倍、つまり3✖️3=9から、弟の持つみかんの数の2倍3✖️2=6を引いた数9−6=3個のリンゴを持っていることになります。したがって答えは3です。」
のような回答を出してくれると3に辿り着くまでの過程が明らかになり、しかも最後の「3です」の3がLLMによって出力される時の数値も
- 「3」:0.999
- 「2」:0.0004
- 「1」:0.0003
- 「0」:0.0003
のように「3」が圧倒的に高いスコアになってくれるだろうということですね。
まとめると、LLMが自分なりにロジックを持って問題解決ができると嬉しく、そういったロジックが必要なタスクをReasoningタスクとよび、その際に、LLMの予測スコアを見るとそこにもロジックが反映されているということです。
Chain of Thought(COT)
では、どのようにしてそのようなロジックをLLMが獲得するのでしょうか?通常のLLMは大量のWebデータや会話データなどから対話能力(昔の話を覚えながら、普通に人と会話する能力、知識を覚えてそれを話す)を持っていますが、難しいロジックをそのまま解くことは苦手です。
こういった能力の獲得のため、instruction-tuningと呼ばれる追加の学習が行われます。その中でも特にChain of Thought(COT)について話していきます。
そもそもCOTって?
Chain of Thoughtとは文字通り「思考の連鎖」のことです。まあそれだけ言われてもって感じだと思うので具体例で一つずつ見ていきましょう。実は最初は追加の学習とかではなく、LLMの入力(プロンプト)を改善するためのものを指していました。どういうことかというと、例えば
LLMに対して
「(3+4)✖️5を計算してください。」
とだけ入力するのではなく、
「(3+4)✖️5を計算してください。ただし、計算手順を一つずつ実行し、なぜその手順を最初に実行したかを述べ、最終的な結論を書いてください。」
とすると単純に答えを出力するだけでなく、過程や、より正確な回答を出力してくれるというテクニックのことでした。
COTの学習への活用
このCOTを学習に活用しようというのが、まずo1風のモデルを作るための第一歩です。今までは入力(プロンプト)上でCOTをしていたのですが、学習データにも入れて再学習しようとなりました。具体例を見ましょう。
Question:(3+4)✖️5を計算してください。
Answer:<Thought>()は数学では最初に計算するべきというルールがあります。そのためまず(3+4)を計算します。その結果3+4=7となりました。\n\n続いて、掛け算の処理に移ります。7✖️5=35により答えは35だと思われます。\n\n</Thought><Output>35</Output>
のようなデータを作ります。機械学習はデータを元に学習されるので、この場合ですとQuestionをLLMに入れて、Answerのような文章を作れたらOK、そうでない場合はペナルティを与えるみたいな学習をしていきます。
<Thought>とは何なのでしょうか?突然謎の<>の記号が出てきましたが、これはLLMが「<Thought>で囲まれたところは思考部分である」と自分で理解するため、また私たちが「これはLLMの思考部分」と明示的に理解するためのものです。このような区切りをつけておくことで後々どこまでが思考で、どこが結論なのか扱いやすくなるのです。また\n\nというのも気になりますね。\nは改行コードですので、\n一つで、改行1行を表します。\n\nは2行分ですね。つまり、空の行が1行できるので、これはThoughtの中でも思考のブロックを作り分離しているということになります。学習データにこのように思考をブロック(ステップ)ごとに分けることで、実際に推論(学習した後に使うとき)の際でも同じように思考をステップごとに\n\nで区切ってくれます。
Monte Carlo Tree Search(MCTS)
少し話がずれますが、こちらも重要な話ですので、わからないと言う方は必ず読んでください。
モンテカルロ木探索と日本語だと呼ばれるのかと思います。こちらは何かというと、将棋AIなどで活用される、最適手の探索方法になります。なんで突然将棋?ってなるかと思いますが、これがo1を理解する上で必要なことになるのです。
MCTSの4ステップ
MCTSは4ステップから構成されます
- 選択(Selection)
- 展開(Expansion)
- シミュレーション(Simulation)
- 逆伝播(Back propagation)
ここでは詳しい解説は致しませんが、大雑把な理解で十分だと思います。以下の図を見ながら、説明を読んでください。
選択(Selection)
選択とは、どの手を深く分析するかを選ぶ作業です。将棋ですと、
- 勝率が高そうな手
- まだあまり探索されていない未知の手
を探索していくのが良さそうですよね
展開(Expansion)
展開とは、Selectionで選んだ良い盤面や未知の盤面から、新たに1手を差して変更した盤面を、分析ツリーに追加する作業です。初期状態では、将棋ですと駒を初期配置に並べた盤面のみが分析ツリーに追加されている状態です。
シミュレーション(Simulation)
シミュレーションとは、Selectionで選んだ盤面から「完全にランダムで試合を続ける」作業のことを言います。要するに良い盤面だとそこからランダムに打っても結果は良くなるはずだよね?と言う意味でしょう。
シミュレーションと展開どっちを実行するの?
Selection後には展開とシミュレーションの2択処理があるのですが、これは一定回数シミュレーションがされていたら展開、まだそこまでシミュレーションしていなかったらシミュレーションという風に実行を制御します。
逆伝播(Back Propagation)
逆伝播はシミュレーションの結果を踏まえて、これまでの分析ツリーを更新していく作業で、具体的には、盤面に対して、「勝ち数/試行回数」を記録していく作業になります。
marco-o1の学習について
さて、ここまでがなんとなく分かればmarco-o1も理解できるはずです。
marco-o1って何?
marco-o1はgpt-o1を目指した研究により開発された公開モデルです。公開されているため誰でも利用可能です(gpuなどは必要)。macro-o1はLLMにCOTを用いた追加学習(fine-tuning)を行い、さらに推論時にMCTSの仕組みを用いた推論ステップを踏むことで、o1のようなロジックを用いた深い推論が可能になりました。
論文を見てみよう
論文に描かれていた図を持ってきました。
この図、なんかさっきも似たようなの出てきましたよね?そうですMCTSです。そして、今回はわかりやすく色付けがされていますよね?それぞれのノードが右の四角で囲まれた部分に対応しています。
最初の青で囲まれた
「Prompt: How many 'r' in strawberry?」
というのは、ユーザーからの質問ですね。そのためPromptと書かれているのがわかると思います。
続く、紫、黄色、ピンク、緑の部分が推論のそれぞれのステップです。そして、最後がLLMが出した答えになります。
COT fine-tuning
なんとなくMCTSの適用が見えてきたところで、COTに戻りましょう。MCTSを適用するには、推論のステップと答えが必要でした。何もしていないLLMにはそのような推論ステップを出力することは難しいです。そのため、COTデータを作り、再学習させる(fine-tuning)ことでLLMに対して、推論ステップと答えを出すようにさせます。
本論文では、openaiのcotデータと、自作のcotデータと、instructionデータ(指示文と回答のペア、<Thought>などのタグはない)で再学習しています。これにより、LLMに対して論理的な解答を出力するように矯正していきます。この状態をmarco-cotと呼びましょう。
MCTS+marco-cot
ここでは、MCTSをどのようにmarco-cotに適用していくかを見ていきます。marco-cotはユーザーからの入力に対して、<Thought>タグから始まる自分の思考を考えていきます。一つの思考ステップは\n\nで区切られていたと思います。学習データがそのような構成になっていたので、学習後もmarco-cotはそのように思考を改行二つで区切ってくれるでしょう。すると、私たちは、モデルの出力を1ワード(1トークン)ずつ得られるので、\n\nが来た時点で一度LLMへの入力をストップできます。
これを繰り返すことで、最初のプロンプトからいくつかのステップを作ることができます。
以下のようになるはずです。
では、その次にどこを見ていくかを選択(Selection)していかなければならないのですが、将棋の例ですと、勝率や探索回数などの指標で決めることができました。この論文では、指標として、
「出力トークンの確率値の平均値」
を用いているそうです。どういうことかというと、LLMが1ワードずつ生成し、そこに確率の値がついていることはわかっていると思います。また、ワードを選ぶ際に他にも候補があることもわかっていると思います。
今、LLMがiワード目を考えているとした時、iワード目の候補からtop5を取り出してきて、それぞれの確率値をp1,p2,p3,p4,p5とします。それを用いて以下を計算したものがiワード目の点数と定義しています。
そして、さらにLLMが\n\nを出力するまでの1ステップまでに出力したワード数をnとしたときに、iが1からnまでの平均値を取れば、これが1思考ステップに対してついたスコアになります。つまり以下のようになりますね。
このようにすることで、1思考ステップに対して勝率のような点数がつけれるので、MCTSのSelectionができるというわけです。(探索回数などもスコアに入れていると思いますが、そちらはMCTSを深掘りしてみてください。)
では、このスコアにはどのような意味があるのでしょうか?簡単にいうと、LLMがどれだけ自信を持って答えられているか?というスコアになります。つまりMCTSではLLMがより自信を持って答えている方向に進んでいこうという探索をしているということです。
この操作を続けることで、</Output>が出てくるまで、探索を繰り返し、最終的な答えを出します。
おまけ
論文の本質的なところは概ねここまででお話しできたと思います。おまけでは、\n\n以外のステップの切り方やwait文を入れることでの自己反省について書いていきます。
mini-step
\n\n以外にも出力されたトークン数を記録しておき、一定まで達した時点で1ステップとするみたいなこともしているそうです。こちらは言語などによって性質が違うとのこと
wait文
1ステップの後にLLMに「ちょっと待って!これまでの推論は間違っている可能性があるから考え直そう!」みたいな文章を入れると自分で振り返り推論が正しいか間違っているかを振り返るとのこと。具体的な実装を見ているわけではないので、それ以上はよくわからず。
Discussion