🦦

Team「たぬき」開発振り返りメモ3: 10bクラスの大規模言語モデルを実際に開発して想定外だったこと5選

2024/05/28に公開

(本記事は、技術メモをもとに、Claude 3で一般向けに自動作文したものです)

はじめに

我々は最近、大規模な言語モデルの開発プロジェクトに取り組みました。言語モデルとは、大量のテキストデータを学習することで、人間のような自然な文章を生成できるAIシステムのことです。プロジェクトでは、250回以上の試行錯誤を繰り返しながら、モデルの性能を少しずつ向上させていきました。ここでは、その過程で直面した問題と解決策について報告したいと思います。

1. GPUの性能を全く引き出せなくて焦る

プロジェクトでは、NVIDIA社の最新鋭GPU「H100」を24枚も使用することになりました。H100の価格は1枚あたり約500万円と非常に高価ですが、前世代のA100と比べて約3倍の性能を発揮するとの噂でした。そこで、H100の高い性能を前提に計算量を設定し、モデルの学習を開始しました。
まず、8枚のH100を1つのサーバー(ノード)に搭載し、練習として学習を行ってみました。すると、280TFLOPS程度の計算速度が得られました。TFLOPSとは、1秒間に何兆回の浮動小数点演算ができるかを表す指標です。別に試したA100を8枚搭載した条件では、100-150TFLOPS程度でしたので、3倍とは行きませんでしたが、それでも及第点と言えるパフォーマンスでした。

本番の学習では、H100を24枚搭載した3台のサーバーを使用することになりました。ここで問題が発生しました。なんと、100TFLOPSを切るという絶望的な結果が出てしまったのです。これは、前世代のA100よりも低い計算速度です。
我々は、計算速度を向上させるため、pipeline parallel、model parallel、data parallel、マイクロバッチサイズ、ZEROなど、様々なハイパーパラメータ(モデルの学習を制御するための設定値)について勉強し直し、最適化を試みました。しかし、それでも120TFLOPS程度の速度しか出ませんでした。
このとき、我々はかなりのプレッシャーを感じました。最高峰の計算資源を使わせてもらっているのに、それを全然活用できていない自分たちが情けなく思えました。問題が解決できないまま時間だけが過ぎていく中で、周りから非難の声が聞こえてくるような錯覚さえしました。

問題解決のための試行錯誤

我々は、計算速度が上がらない原因を必死に考えました。3台のサーバーを使った分散学習では、サーバー間のデータのやり取りがボトルネックになっている可能性が高いと推測しました。
そこで、サーバー間通信を高速化する技術である「DirectGPU TCPX」が正しく動作しているか確認してみましたが、特に問題は見当たりませんでした。
窮余の策として、モデルを3つに分割し、それぞれを別々のサーバーで学習させる「Branch Train Merge」という手法の採用も検討しました。これは、分散学習における通信のオーバーヘッドを減らすための、少し変わったアプローチです。しかし、そのためには新たなデータセットを用意する必要があり、かなりの手間がかかります。効果も不透明だったので、これは見送ることにしました。

突破口はgradient accumulationにあり

行き詰まっていた我々に、松尾研究室のメンバーからアドバイスをいただきました。それは、「Global batchを1536、microを1に設定してみてはどうか」というものでした。
バッチサイズとは、一度に処理するデータの数のことです。Global batchは、複数のサーバーを合わせた全体でのバッチサイズを指します。一方、microは、各サーバーが一回で処理するデータの数です。
それまで我々は、GPUのメモリ容量の制限から、Global batchを100以下に設定していました。1536などという大きな値は、メモリオーバーを引き起こして計算できないだろうと考えていたのです。
ここで登場したのが「gradient accumulation」という技術でした。これは、勾配(モデルの重みを更新するための値)をバッチ全体で一気に計算するのではなく、小さな単位に分けて少しずつ計算し、最後に平均を取る手法です。
通常、gradient accumulationは、バッチサイズを大きくしたい場合に、メモリ使用量を抑えるために用いられます。一方で、複数回に分けて勾配を計算するため、トータルの計算時間は長くなってしまいます。
ですので、計算速度の向上を目指していた我々には、これを使うという発想はありませんでした。ところが、いざ試してみると、驚くべき結果が得られたのです。
FLOPSが一気に300を超え、最適化を進めた結果、最終的には450以上にまで到達しました。
恐らく、gradient accumulationによってサーバー間通信の頻度が減り、ボトルネックが解消されたのだと思います。
(蛇足ながら、GPT-2系のアーキテクチャよりも、LLaMa系の方がFLOPSが高くなる傾向も見られました。)
こうして、我々には到底思いつかなかったパラメータ設定により、計算速度の問題を解決することができました。アドバイスをくださった松尾研究室の方々には心から感謝しています。同時に、分散学習のプロフェッショナルの技術力の高さを思い知らされた出来事でもありました。

2. nanエラーによる学習の強制終了

大規模言語モデルの学習では、「loss spike」と「nan error」が大きな問題として知られています。
loss spikeとは、学習が順調に進んでいると思われた矢先に、突然lossの値が急上昇し、下がらなくなる現象です。lossは、モデルの予測と正解の間の誤差を表す指標で、通常は学習が進むにつれて徐々に減少していきます。
一方、nan errorは、勾配の値が発散して「nan」(not a number、数値ではない)になり、学習の継続が不可能になる現象です。
我々はこれらの現象に、実に10回ほど遭遇しました。夜中や明け方に起きることもしばしばで、放置するとGPUリソースを無駄遣いしてしまいます。
そのため学習中は、「トラブルが起きませんように」と祈りつつ、24時間体制でモデルを監視し、異常があればすぐに対処する必要がありました。

精度と速度のトレードオフ

loss spikeやnan errorの原因は複合的ですが、計算速度の向上のために用いられる「fp16」というデータ形式が問題視されることが多いです。
コンピュータは、数値を2進数で扱います。よく使われるデータ形式としては、32ビットの「fp32」と、16ビットの「fp16」があります。ビット数が少ない方が、メモリ使用量や計算時間を節約できます。
ただし、fp16は精度が低く、指数部のレンジが狭いため、値が無限大に発散してnan errorが起きやすいのです。
対策としては、fp16よりもレンジの広いbf16やfp32を使うことが有効とされています。しかし、我々の場合はfp16を使わざるを得ない事情がいくつかありました。
例えば、学習に使ったフレームワーク(Megatron-DeepSpeed)との相性が悪かったり、バージョンによって、fp16でないと動かない機能があったりしたのです。fp32を使えば確実に安定しますが、計算効率が落ちてしまうというジレンマもありました。
開発にあたっては、ライブラリの仕様まで精通している必要があることを実感しました。

3. 学習データと実力のギャップ

我々は英語学習データとして、学術論文、コード、Wikipediaなど、質の高い文章を大量に集めました。特に数学の論文もたくさん含まれていたので、ある程度は数式が扱えるモデルが出来上がるだろうと期待していました。
ところが蓋を開けてみると、モデルは1+1はできるものの、1+2+3などの簡単な足し算すら満足にこなせませんでした。この結果には、さすがに落胆を禁じ得ませんでした。
どうやら、論文を読ませるだけでは数学的な能力は身につかないようです。四則演算をマスターさせるには、もっと特化した演習用のデータセットを用意する必要があるのかもしれません。

4. トークナイザーの罠

モデルの出力に改行が一切含まれていないことに気づいたのは、学習が終わった後のことでした。もしかしたら、改行コードを無視した状態で学習を進めてしまったのでは?そう考えただけで、頭が真っ白になりそうでした。
チームで冷静に対処できたことは、本当に幸いでした。一人だったら、後悔と自責の念に囚われて身動きが取れなくなっていたかもしれません。
色々と調べた結果、原因はHugging Faceというライブラリに付属する、トークナイザーと呼ばれるプログラムの不具合だとわかりました。
大規模言語モデルでは、単語をそのままの形で扱うのではなく、部分文字列に分割した「トークン」に変換して処理します。この変換を担うのがトークナイザーです。
我々が使った古いバージョンのトークナイザーには、改行コードを正しく処理できないバグがあったようなのです。
結局、llm-jpで公開されていた新しいトークナイザーに乗り換えることで、問題を解消することができました。

[参考画像]

5. ファインチューニング中に想定外のエラー

事前学習とは異なり、ファインチューニングではHuggingFaceの一般的なライブラリを使用しました。これらのライブラリは、私たち自身が日常的に利用しており、ユーザー数も多く、完成度が高いため、大きな問題は発生しないと予想していました。ところが、実際にファインチューニングを行ってみると、「ディスク容量が不足しています(no space left on device)」というエラーが頻発したのです。ディスク容量自体は数TB以上の余裕があるはずなのに、なぜかこのエラーが発生してしまいました。
調査の結果、このエラーはプログラムの実行中に大量の一時ファイルが生成されることが原因であることがわかりました。定期的にこれらのファイルを削除する必要がありますが、削除の方法やファイルの権限についても工夫が必要でした。以下のようなコマンドを使って、一時ファイルを削除することで、問題を解決することができました。

find /var/tmp -maxdepth 1 -user (ユーザー名) -print0 | xargs -0 rm -r

このようなエラーは、大規模な言語モデルの学習においては珍しくありません。想定外の問題が発生することを前提に、柔軟に対応することが重要だと実感しました。

おわりに

大規模言語モデルの学習は、技術的にもメンタル的にもかなりの難易度でした。計算リソースの無駄遣いを恐れ、睡眠時間を削ってモデルの面倒を見続けた日々は、正直言って辛かったです。
それでも、壁を一つ一つ乗り越えていくたびに、大きな喜びと達成感を味わうことができました。GPUの性能を最大限に引き出すための知恵も、少しずつ身についてきたように感じています。
この経験を通して、我々はAI開発の面白さに気づくとともに、最先端の技術に携わるエンジニアの尊さを実感しました。道のりは平坦ではありませんが、これからも新しいモデルの可能性を追求し続けていきたいと思います。

この成果は、NEDO(国立研究開発法人新エネルギー・産業技術総合開発機構)の助成事業「ポスト5G情報通信システム基盤強化研究開発事業」(JPNP20017)の結果得られたものです。

東大松尾・岩澤研究室 | LLM開発 プロジェクト[GENIAC]

Discussion