bitsandbytes で NF4 量子化の保存・読み込みを実装する
はじめに
近年、機械学習モデルはスケールする傾向があり、それらの学習コストも同時に大きくなっています。大きいモデルを学習・推論するためにはより多くの VRAM が要求されることになるので、効率的に処理を行うためにさまざまな研究が行われています。
効率化の方法の一つとして、重みの量子化をすることでメモリ消費を節約し、学習・推論の効率化 をするものがあります。
今回は bitsandbytes ライブラリを用いて、モデルを NF4 量子化し、その重みを safetensors 形式で保存、また、保存した重みを読み込む方法を紹介します。
bitsandbytes とは
重みを量子化する方法は AutoGPTQ、optimum-quanto、torchao がありますが、今回は bitsandbytes を使います。
bitsandbytes はいつの間にか HuggingFace の傘下になっていたようで、現状は内容が少し薄いですがドキュメントが huggingface の元で公開されています。
bitsandbytes は他のライブラリと同様、8 bit の量子化や 4 bit の量子化を行うことができます。特に、QLoRA[1] で導入された NormalFloat4 (NF4) を手軽に扱うことができます。詳しくないので細かいことはわからないのですが、NF4 は正規分布の特性を利用して誤差を減らし、通常の 4 bit 浮動小数展で学習した時よりも性能が向上するそうです。
bitsandbytes 自体は pip install bitsandbytes
でインストール可能ですが、今回行う NF4 量子化は現時点(2024/12/28)で CUDA のみのサポートになるため、Colab T4 等の CUDA が使える環境が必要に なります。
環境の用意
CUDA が使える環境で必要な以下をインストールします。
- torch
- bitsandbytes
- safetensors
pip install torch bitsandbytes safetensors
モデル定義
サンプルとして、非常に簡単な構造をしたモデルを定義します。
import torch
import torch.nn as nn
class SmallModel(nn.Module):
def __init__(self):
super().__init__()
# 量子化対象の層
self.linear = nn.Linear(128, 256, bias=True, dtype=torch.float16)
# 全然関係ない層
self.other = nn.Parameter(torch.randn(64, 256, dtype=torch.float16))
# 適当に初期化
nn.init.xavier_uniform_(self.linear.weight)
nn.init.xavier_uniform_(self.other)
def forward(self, x):
return self.linear(x) + self.other
model = SmallModel()
model
は以下のようになります。
SmallModel(
(linear): Linear(in_features=128, out_features=256, bias=True)
)
今回は線形層である model.linear
を NF4 量子化します。
通常の重みの保存
後で比較する用に通常の重みを保存しておきます。
from safetensors.torch import save_file
save_file(model.state_dict(), "fp16.safetensors")
NF4 量子化をする
今回はインスタンス化されたモデルではなく、その state_dict
を元に NF4 量子化を行います。この方法であれば、いちいちインスタンス化しなくても .safetensors
ファイル等の 重みのファイルさえあれば量子化が行えます 。
また、今回は .weights
を持つ線形層のみを対象に行います。
from bitsandbytes.functional import quantize_nf4
for key, value in state_dict.copy().items():
if key.endswith(".weight"):
q_tensor, state = quantize_nf4(state_dict[key].to("cuda:0"))
state_dict[key] = q_tensor.to("cpu")
for k, v in state.as_dict(packed=True).items():
state_dict[key + "." + k] = v
bitsandbytes.functional
から quantize_nf4
をインポートして、キー名が .weight
で終わる value
を量子化しています。その際、ライブラリの関係で .to("cuda:0")
で CUDA に転送してから渡す 必要があります。
quantize_nf4
は Tuple[torch.Tensor, QuantState]
を返します。torch.Tensor
の方は量子化された重みそのものですが、QuantState
は量子化方法に関する情報 を持っており、インスタンス化や量子化を元に戻すときなどに必要になります。
そのため、量子化されたテンソル q_tensor
に加え、state
も追加で state_dict
に保存しています。この保存方法は bitsandbytes の Linear4bit
クラス等で行われるもの と同じです。
量子化された重みの保存
量子化処理が終わったら、float16
のときと同様に保存できます。
save_file(state_dict, "nf4.safetensors")
レイヤー情報の確認
さて、量子化された重みがどのように保存されているのかを確認してみます。safemetadata を使って nf4.safetensors
のレイヤー情報を表示します。
❯ safemetadata layers nf4.safetensors
╭───────────────────────────────────────────────┬───────────┬─────────────╮
│ Parameter Name │ DType │ Shape │
├───────────────────────────────────────────────┼───────────┼─────────────┤
│ linear.bias │ float16 │ [256] │
│ linear.weight │ uint8 │ [16384, 1] │
│ linear.weight.absmax │ float32 │ [512] │
│ linear.weight.quant_map │ float32 │ [16] │
│ linear.weight.quant_state.bitsandbytes__nf4 │ uint8 │ [79] │
│ other │ float16 │ [64, 256] │
╰───────────────────────────────────────────────┴───────────┴─────────────╯
量子化した linear.weight
は uin8
型で [16384, 1]
の形状で保存されているようです。型が NF4 じゃないし、形状も元モデルと全然違う!?って思ってしまいそうですが問題はありません。 NF4 は通常の型ではないので、バイナリとして情報を保存しているのだと思われます。[2]
また、absmax
や quant_map
などの情報も保存されています。quant_state.bitsandbytes__nf4
から NF4 で量子化したという情報も確認できますね。これらは重みを読み込むときに使うことになります。
バイアスの linear.bias
や線形層ではない other
レイヤーはそのまま float16
で保存されています。
通常の重みの読み込みの比較
比較用に、通常の重みの読み込みをするコードも載せておきます。
from safetensors.torch import load_file
model = SmallModel()
model.load_state_dict(load_file("fp16.safetensors"), assign=True)
量子化された重みの読み込み
こちらは少し面倒な処理が入ります。
bnb.nn.LinearNF4 の置き換え
まず、量子化された線形層を扱うには bitsandbytes
に収録されている LinearNF4
を使う必要があります。しかし、現在定義している SmallModel
では PyTorch の nn.Linear
を使っているので、そのままでは互換性がなく、読み込むことができません。そこで、最初にモデルの nn.Linear
を LinearNF4
に置き換えることを行う必要があります:
import torch.nn as nn
import bitsandbytes as bnb
def replace_quantized_linear(
model: nn.Module,
state_dict: dict[str, torch.Tensor],
parent_param_name: str = ""
):
for name, module in model.named_children():
if isinstance(module, nn.Linear):
q_layer = bnb.nn.LinearNF4(
module.in_features,
module.out_features,
bias=(module.bias is not None)
)
param_name = parent_param_name + name + ".weight"
weight = state_dict[param_name]
stats = {
k[len(param_name + "."):]: v
for k, v in q_state_dict.items()
if k.startswith(param_name + ".")
}
quant_type = [
k
for k in q_state_dict.keys()
if k.startswith(param_name + ".quant_state.")
][0][len(param_name + ".quant_state.bitsandbytes__"):]
q_layer.weight = bnb.nn.Params4bit.from_prequantized(
data=weight,
quantized_stats=stats,
quant_type=quant_type,
)
setattr(model, name, q_layer)
else:
replace_quantized_linear(module, state_dict, parent_param_name + name + ".")
replace_quantized_linear()
では、渡された model
の named_children
(持ってるレイヤー) をチェックし、そのモジュールが nn.Linear
であれば、bnb.nn.LinearNF4
に置き換え、そうでなかったらそのモジュールの中も再帰的にチェックし置き換える、というのを行っています。
bnb.nn.LinearNF4
は元の nn.Linear
と同じ形状になるようにし、q_linear
という名前で定義しています。読み込む際に、量子化情報も取得する必要 があるため、その辺りを stats
、quant_type
で取得しています。今回は NF4 ですが、先ほど確認したレイヤー情報から量子化手法も判定できるので、quant_type
で分岐すれば bitsandbytes でサポートされている他のレイヤーを使うこともできます。
bnb.nn.LinearNF4
の weight
は通常の nn.Parameters
ではなく bnb.nn.Params4bit
で保存するため、事前量子化済みの重みを読み込むための関数である from_prequantized
に重みと量子化情報を渡して、重みを読み込んでいます。
今回のようなシンプルな構造では直接 model.linear
だけを置き換えればいいかもしれませんが、これを大規模なモデルに適用したい時を考えると、特定のモデル構造に依存しないほうが汎用性があると思います。
他の重みも読み込み
線形層 weight
だけ読み込みましたが、今回の線形層には bias
があるほか、other
レイヤーをまだ読み込んでいません。
量子化を行なっていないレイヤーの重みを読み込む処理を行う必要があります。
def load_normal_weights(model: nn.Module, state_dict: dict[str, torch.Tensor]):
has_quant_state_keys = [
k.split(".quant_state.")[0]
for k in state_dict.keys()
if ".quant_state." in k
]
excluded_quant_state = {
k: v
for k, v in state_dict.items()
for quant_key in has_quant_state_keys
if not k.startswith(quant_key)
}
model.load_state_dict(excluded_quant_state, assign=True, strict=False)
量子化情報を持つレイヤーのキーを除外してから load_state_dict
で読み込みます。その際、キーが不足することになるため strict=False
を指定しています。
読み込み実行
重みの読み込みのための関数を実装できたので、実際に読み込みます:
q_model = SmallModel()
q_state_dict = load_file("nf4.safetensors")
replace_quantized_linear(q_model, q_state_dict) # レイヤー置き換え
load_normal_weights(q_model, q_state_dict) # 他の重み読み込み
# 量子化してないレイヤーは一致するはず
assert torch.allclose(model.other, q_model.other)
linear
が LinearNF4
になっていることが確認できると思います。
print(q_model)
---
SmallModel(
(linear): LinearNF4(in_features=128, out_features=256, bias=True)
)
実行
何か適当なテンソルを通して出力結果を比較してみます。NF4 量子化は CUDA で実行する必要があります。また、型キャストを自動でやってもらうために torch.autocast
しています。
import torch.nn.functional as F
inputs = torch.randn(1, 128, dtype=torch.float16, device="cuda")
# 量子化してない方
model = model.to("cuda")
with torch.no_grad():
with torch.autocast(device_type="cuda", dtype=torch.float16):
print(model(inputs)[0, :10])
# tensor([-0.4531, -0.5933, 0.2136, 0.3015, -0.1589, 0.1306, 1.3281, -0.2203,
# -0.0833, -0.1528], device='cuda:0', dtype=torch.float16)
# 量子化した方
q_model = q_model.cuda()
with torch.no_grad():
with torch.autocast(device_type="cuda", dtype=torch.float16):
print(q_model(inputs)[0, :10])
# tensor([-0.4873, -0.6221, 0.2023, 0.4419, -0.2456, 0.1458, 1.2314, -0.2571,
# 0.0250, 0.0178], device='cuda:0', dtype=torch.float16)
# 平均二乗誤差
with torch.no_grad():
with torch.autocast(device_type="cuda", dtype=torch.float16):
print(F.mse_loss(model(inputs), q_model(inputs)))
# tensor(0.0052, device='cuda:0')
完全一致はしてませんが、大体近い値になっているのがわかります。bnb.nn.LinearNF4
が NF4 の計算をよしなにやってくれるため、実際のレイヤーを意識せずに透過的に扱うことができます。
torch.compile
triton
が必要になりますが、 torch.compile
にも対応しています。
pip install triton
q_model = torch.compile(q_model, mode="max-autotune")
with torch.no_grad():
with torch.autocast(device_type="cuda", dtype=torch.float16):
print(q_model(inputs)[0, :10])
# tensor([-0.4873, -0.6221, 0.2023, 0.4419, -0.2456, 0.1458, 1.2314, -0.2571,
# 0.0250, 0.0178], device='cuda:0', dtype=torch.float16)
終わり
bitsandbytes を使った NF4 量子化の方法と、量子化した重みの保存・読み込み方法を紹介しました。
巨大モデルを NF4 量子化して効率的に動かしましょう!!
余談
ライブラリの実装や公式ドキュメントを読むよりも、実際の使用例を見たほうが理解しやすいことがあると思いますが、そのような時は grep.app を使うと便利です。grep.app は GitHub で公開されているコードを爆速で検索できます。
Params4bit
が使われている例の検索結果:
Discussion