BitNetの計算を検討する
記事作成日: 2024年3月21日
はじめに
こんにちは。松尾研 GENIAC LLM開発プロジェクト、Team JINIAC の佐野敏幸です。
このプロジェクトが始まる直前にBitNetというアーキテクチャが発表されました。(https://arxiv.org/abs/2402.17764 )。この記事では、このモデルを採用するか判断する際に考えたことについて書いています。
はじめに
2024年1月27日に発表された『The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits』(https://arxiv.org/pdf/2402.17764) は衝撃的でした。これまで8bitなどに量子化することはよくありました。8bitの場合、ほぼ連続といっていい数値を2の8乗、つまり256段階に割り振り粗くします(このように数値を段階に割り振ることを「量子化」と言います)。量子化をすることで数値の精度は落ちますが、その分メモリの使用領域が少なくて済むトレードオフがあります。
この論文で提案されている手法は、その量子化をさらに極端に行い、-1,0,1の3段階のみで数値を表します。
このように3段階のみに量子化をすることで、計算上、大きなメリットが生まれます。
通常、行列×ベクトルでは個々の要素は乗算と加減算で計算される。1(.58)-bitの行列を用いると、加減算のみで構成することが可能となります。
https://arxiv.org/pdf/2402.17764 より
通常、行列とベクトルの積を計算するとき、乗算と加減算を行います。
上の図の上段(FP16)の計算だと、Output Yの1つの要素につき、
0.2961 × x0 - 0.0495 × x1 - 0.0924 × x2 - 0.4765 × x3
のような計算をすることになり、これだけ乗算が4回、加減算が3回必要です。
ところが、BitNetの-1,0,1の3つで表す手法では、上の図の下段(1(.58)-bit)に示されている通り、
x0 - x1 - x2 - x3
のような計算になり、乗算がなくなり、加減算を3回行うのみとなっています。
コンピューター内では、加減算よりも乗算の方が負荷が大きいこともあり、BitNetの手法では大幅な計算量削減になります。また、行列のメモリサイズも小さくできるため、より大きなモデルをメモリに載せることが可能になり、モデルの大規模化を助けることにもなります。
なお、今のところ論文で用いたBitNetの実装コードは公開されていない状況です。有志の方々が論文をなぞりながら実装したものをいくつか見ると、BitNetで確かにLossが減少して学習が進んでおり、BitNetでも確かに学習ができることがわかります。
なお、有志による実装でも、フィードフォワードの計算ではtorchなどの通常のライブラリを用いていることから、以下のような計算になっていると考えられます。
1 × x0 + (-1) × x1 + (-1) × x2 + (-1) × x3
この計算では、BitNetでも学習できることの判定はできますが、計算量の削減等のメリットは享受できないものになります。
計算をどうやるか
とはいえ、現実的にどうやるのか、が悩ましいとことでもあります。
Numpyなどの計算ライブラリでは、行列計算は内部で計算方法をかなり工夫してあり、高速に演算できるようになっているらしいです(さらに内部ではC言語で作成されたプログラムを利用しているとも聞きます)。ディープラーニングでよく使用されるpytorch等でも内部の計算はかなり工夫して高速化されているはずです。
条件分岐による計算
自前で計算方法を組み立てるのに、例えば、if 文を用いて、「-1のときは引き算をする、0のときは何もしない、1のときは足し算をする」をすることも可能といえば可能ですが、計算速度はNumpyなどの既存のライブラリによる普通の行列演算よりは相当遅くなると考えられます。
ビット演算子による計算
ビット演算は、2進数に対して操作を加える演算方法です。ビットに対して直接操作を加えることで、通常の加減乗除よりも高速な処理をできる特徴があります。
ビット演算子の「~」はビットの0と1を反転させる演算子です。大雑把にいうと ~x は -(x+1)になります。
x=200
print(~x)
この結果は、-201になります。
なお、BitNetでの計算に用いるには、結局のところ条件分岐による処理が必要になりそうです。「-1のときはビット演算『~』をする」のようになることから、結局処理が重たくなりそうです。
抽出による処理
抽出による処理方法についても検討してみました。
以下のような流れです。
1.58-bitの行列の行から、-1がある位置を抽出し、その位置と対応するベクトルの要素のみを抽出したベクトルを作成し、そのベクトルの要素の和を計算する(A)。次に、1がある位置を抽出し、同様な処理を行う(B)。A-Bを計算結果とする。
なんとなくいけそうな気もしますが、抽出の部分でやはり処理が重たそうな気もします…
まとめ
BitNetの計算量削減のメリットが得られるような計算方法を考えたが、意外と大変そうな印象が得られた。
『The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits』では、BitNetを新たな計算パラダイムの新時代の幕開けとしています。そして、GPUではなく、BitNetの計算に適した新たなハードウェアの登場を示唆しています。
新たなハードウェアを提案するのは、今回検討したように、BitNetのメリットを受けられる実装の難しさにあるのかもしれません。
東京大学 松尾・岩澤研究室が運営する松尾研LLMコミュニティのLLM開発プロジェクト[GENIAC] の開発記録、情報発信になります。 各種リンクはこちら linktr.ee/matsuolab_community
Discussion