TanukiモデルのAWQ、GPTQ、GGUF量子化について
GENIAC 松尾研LLM開発プロジェクトメンバーのArataです。
本記事では、Tanuki-8BとTanuki-8x8Bの各種手法による量子化についてまとめます。
はじめに
GENIAC 松尾研LLM開発プロジェクトでは、Tanuki-8BおよびTanuki-8x8Bという2つのモデルを開発しました。以下に概要を書きます。
-
Tanuki-8B
- モデル構造:Llama-3 8Bと同一の構造
- トークナイザー:llm-jp tokenizer ver2.1を参考に作成
-
Tanuki-8x8B
- モデル構造:ほぼMixtral-8x7Bとほぼ同一だが僅かに独自実装のある構造(
TanukiForCausalLM
) - トークナイザー:llm-jp tokenizer ver2.1を参考に作成(Tanuki-8Bのトークナイザーと同一)
- モデル構造:ほぼMixtral-8x7Bとほぼ同一だが僅かに独自実装のある構造(
今回、これらのモデルを元に以下の量子化モデルを作成しました。
- Tanuki-8B
- Tanuki-8x8B
この記事では、これらの量子化モデルの作成方法について解説します。なお、一部解決できていない問題もあり、それらの詳細は余談に記載しています。解決策ご存知の方いればコメント等で教えていただけると嬉しいです。
AWQ量子化
Tanuki-8Bの変換
Tanuki-8BのAWQによる量子化は特にライブラリの改変等なしでそのまま変換できます。
まず、AWQ量子化のためのライブラリであるAutoAWQを通常通りインストールします。
pip install autoawq
その後、以下のようなコードで変換を実行します。ここでは、キャリブレーションデータセットにizumi-lab/wikipedia-ja-20230720を用いています。
import numpy as np
from awq import AutoAWQForCausalLM
from datasets import load_dataset
from transformers import AutoTokenizer
# キャリブレーション用データセットの設定
wiki_dataset = load_dataset("izumi-lab/wikipedia-ja-20230720")
texts = wiki_dataset["train"]["text"]
rng = np.random.default_rng(42)
random_indices = rng.choice(len(texts), size=512, replace=False)
calib_dataset = [texts[i] for i in random_indices]
model_path = "weblab-GENIAC/Tanuki-8B-dpo-v1.0"
quant_path = "./Tanuki-8B-dpo-v1.0-AWQ" # 量子化モデルの出力先
# 量子化の設定
quant_config = {
"zero_point": True,
"q_group_size": 128,
"w_bit": 4,
"version": "GEMM",
}
# モデルのロード
model = AutoAWQForCausalLM.from_pretrained(model_path, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_path)
# 量子化の実行
model.quantize(tokenizer, quant_config=quant_config, calib_data=calib_dataset)
# 量子化モデルを保存
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)
print(f"Model is quantized and saved at '{quant_path}'")
Tanuki-8x8Bの変換
Tanuki-8x8BはTanukiForCausalLM
という独自アーキテクチャなので、AutoAWQライブラリを一部改変して変換に対応させる必要があります。
改変を行いTanuki-8x8Bの変換に対応したAutoAWQをこちらで公開しています。
また、改変内容についてはこちらの差分をご確認ください。
この改変版AutoAWQを使って変換を行います。まず、以下のようにAutoAWQをソースからビルドしてインストールし、flash attentionをインストールします。
git clone https://github.com/team-hatakeyama-phase2/AutoAWQ
cd AutoAWQ
pip install -e .
pip install --no-build-isolation flash_attn
その後、以下のようなコードで変換を実行します。
import numpy as np
from awq import AutoAWQForCausalLM
from datasets import load_dataset
from transformers import AutoTokenizer
# キャリブレーション用データセットの設定
wiki_dataset = load_dataset("izumi-lab/wikipedia-ja-20230720")
texts = wiki_dataset["train"]["text"]
rng = np.random.default_rng(42)
random_indices = rng.choice(len(texts), size=512, replace=False)
calib_dataset = [texts[i] for i in random_indices]
model_path = "weblab-GENIAC/Tanuki-8x8B-dpo-v1.0"
quant_path = "./Tanuki-8x8B-dpo-v1.0-AWQ" # 量子化モデルの出力先
# 量子化の設定
quant_config = {
"zero_point": True,
"q_group_size": 128,
"w_bit": 4,
"version": "GEMM",
}
# モデルのロード
model = AutoAWQForCausalLM.from_pretrained(model_path, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_path)
# 量子化の実行
model.quantize(tokenizer, quant_config=quant_config, calib_data=calib_dataset)
# 量子化モデルを保存
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)
print(f"Model is quantized and saved at '{quant_path}'")
GPTQ量子化
Tanuki-8Bの変換
Tanuki-8BのGPTQによる量子化は特にライブラリの改変等なしでそのまま変換できます。
まず、GPTQ量子化のためのライブラリであるAutoGPTQを通常通りインストールします。
pip install auto-gptq
その後、以下のようなコードで変換を実行します。ここでは、キャリブレーションデータセットにizumi-lab/wikipedia-ja-20230720を用いています。
from transformers import AutoTokenizer
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
from datasets import load_dataset
import numpy as np
pretrained_model_dir = "weblab-GENIAC/Tanuki-8B-dpo-v1.0"
quantized_model_dir = "./Tanuki-8B-dpo-v1.0-GPTQ-4bit" # 量子化モデルの出力先
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir)
# キャリブレーション用データセットの設定
wiki_dataset = load_dataset("izumi-lab/wikipedia-ja-20230720")
texts = wiki_dataset["train"]["text"]
rng = np.random.default_rng(42)
random_indices = rng.choice(len(texts), size=1000, replace=False)
calib_dataset = [texts[i] for i in random_indices]
examples = [
tokenizer(
data,
return_token_type_ids=False,
) for data in calib_dataset
]
quantize_config = BaseQuantizeConfig(
bits=4, # quantize model to 4-bit
group_size=128, # it is recommended to set the value to 128
desc_act=False, # set to False can significantly speed up inference but the perplexity may slightly bad
)
# load un-quantized model, by default, the model will always be loaded into CPU memory
# GPUを使う場合max_memoryを指定しないとエラーになったので指定しておく
model = AutoGPTQForCausalLM.from_pretrained(pretrained_model_dir, quantize_config, max_memory={0: '48GiB', 'cpu': '99GiB'})
# quantize model, the examples should be list of dict whose keys can only be "input_ids" and "attention_mask"
model.quantize(examples)
# save quantized model using safetensors
model.save_quantized(quantized_model_dir, use_safetensors=True)
Tanuki-8x8Bの変換
Tanuki-8x8BはTanukiForCausalLM
という独自アーキテクチャなので、AutoGPTQライブラリを一部改変して変換に対応させる必要があります。
改変を行いTanuki-8x8Bの変換に対応したAutoGPTQをこちらで公開しています。
また、改変内容についてはこちらの差分をご確認ください。
この改変版AutoGPTQを使って変換を行います。まず、以下のようにAutoAWQをソースからビルドしてインストールし、flash attentionをインストールします。
git clone https://github.com/team-hatakeyama-phase2/AutoGPTQ
cd AutoGPTQ
pip install -e .
pip install --no-build-isolation flash_attn
from transformers import AutoTokenizer
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
from datasets import load_dataset
import numpy as np
pretrained_model_dir = "weblab-GENIAC/Tanuki-8x8B-dpo-v1.0"
quantized_model_dir = "./Tanuki-8x8B-dpo-v1.0-GPTQ-4bit" # 量子化モデルの出力先
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir)
# キャリブレーション用データセットの設定
wiki_dataset = load_dataset("izumi-lab/wikipedia-ja-20230720")
texts = wiki_dataset["train"]["text"]
rng = np.random.default_rng(42)
random_indices = rng.choice(len(texts), size=1000, replace=False)
calib_dataset = [texts[i] for i in random_indices]
examples = [
tokenizer(
data,
return_token_type_ids=False,
) for data in calib_dataset
]
quantize_config = BaseQuantizeConfig(
bits=4, # quantize model to 4-bit
group_size=128, # it is recommended to set the value to 128
desc_act=False, # set to False can significantly speed up inference but the perplexity may slightly bad
)
# load un-quantized model, by default, the model will always be loaded into CPU memory
# GPUを使う場合max_memoryを指定しないとエラーになったので指定しておく
model = AutoGPTQForCausalLM.from_pretrained(pretrained_model_dir, quantize_config, max_memory={0: '80GiB', 1: '80GiB', 2: '80GiB', 'cpu': '99GiB'})
# quantize model, the examples should be list of dict whose keys can only be "input_ids" and "attention_mask"
model.quantize(examples)
# save quantized model using safetensors
model.save_quantized(quantized_model_dir, use_safetensors=True)
GGUF量子化
llama.cppの環境準備
まず、変換のためのllama.cppの環境を整備します。今回はWindows 10の環境で、こちらのドキュメント通りに進めます。
-
llama.cppのリポジトリをクローン
git clone https://github.com/ggerganov/llama.cpp cd llama.cpp
-
python環境の作成
python -m venv venv venv\Scripts\activate pip install -r requirements.txt
-
llama.cppのbuild
こちらから、w64devkitをダウンロードして実行llama.cppのフォルダに移動し、
make
コマンドを実行しbuild
これでllama.cppの環境が準備できたので、これを元に変換します。
Tanuki-8Bの変換
まずはそのままTanuki-8Bを変換しようとしてみます。convert_hf_to_gguf.py
を実行することでGGUFに変換が出来ます。
python convert_hf_to_gguf.py local_dir\Tanuki-8B-dpo-v1.0 --outfile .\Tanuki-8B-dpo-v1.0-F16.gguf --outtype f16
これを行うと、以下のWARNINGとエラーが発生しうまく変換できないはずです。これは、Tanukiのtokenizerの変換にllama.cppがデフォルトで対応していないからです。
WARNING:hf-to-gguf:**************************************************************************************
WARNING:hf-to-gguf:** WARNING: The BPE pre-tokenizer was not recognized!
WARNING:hf-to-gguf:** There are 2 possible reasons for this:
WARNING:hf-to-gguf:** - the model has not been added to convert_hf_to_gguf_update.py yet
WARNING:hf-to-gguf:** - the pre-tokenization config has changed upstream
WARNING:hf-to-gguf:** Check your model files and convert_hf_to_gguf_update.py and update them accordingly.
WARNING:hf-to-gguf:** ref: https://github.com/ggerganov/llama.cpp/pull/6920
WARNING:hf-to-gguf:**
WARNING:hf-to-gguf:** chkhsh: a12ac8faf6a5e2ef542d8c05946c7c89443346927f0c04b8d0f285c557864f24
WARNING:hf-to-gguf:**************************************************************************************
WARNING:hf-to-gguf:
NotImplementedError: BPE pre-tokenizer was not recognized - update get_vocab_base_pre()
正攻法ではconvert_hf_to_gguf_update.py
を使ってこれを解決するのですが、余談に記載の通りこの方法では上手く解決できなかったので、別のアプローチをとります。
Phase2のモデルのtokenizerはPhase1のものと同一ですが、phase1のモデルをGGUFに変換していただいている方がいました。このGGUFファイルのメタデータを見るとtokenizer.ggml.model = llama
になっていますが、現在のllama.cppではこのように変換されません。そのため、これがllama
として変換されるように半分無理やり改変します。
llama.cpp/gguf-py/gguf/vocab.py
の中の処理から一部を以下のようにコメントアウトします。
class LlamaHfVocab(Vocab):
tokenizer_model = "llama"
name = "hfft"
def __init__(self, base_path: Path):
(中略)
is_llama3 = (
tokenizer_model['type'] == 'BPE' and tokenizer_model.get('ignore_merges', False)
and not tokenizer_model.get('byte_fallback', True)
)
if is_llama3:
raise TypeError('Llama 3 must be converted with BpeVocab')
# ここをコメントアウト
# if not is_llama3 and (
# tokenizer_model['type'] != 'BPE' or not tokenizer_model.get('byte_fallback', False)
# or tokenizer_json['decoder']['type'] != 'Sequence'
# ):
# raise FileNotFoundError('Cannot find Llama BPE tokenizer')
try:
from transformers import AutoTokenizer
except ImportError as e:
raise ImportError(
"To use LlamaHfVocab, please install the `transformers` package. "
"You can install it with `pip install transformers`."
) from e
(以下略)
この改変をすることで、tokenizer.ggml.model = llama
として変換することが出来ます。この改変をした後改めてGGUFに変換し推論を実行すると、問題なくモデルロードと推論ができていることが分かります。
python convert_hf_to_gguf.py local_dir\Tanuki-8B-dpo-v1.0 --outfile .\Tanuki-8B-dpo-v1.0-F16.gguf --outtype f16
llama-cli -m Tanuki-8B-dpo-v1.0-F16.gguf -p "I believe the meaning of life is" -n 128
> I believe the meaning of life is to find your purpose, and then live your life as that purpose requires. This is a common theme in various philosophical and existential discussions.(以下略)
試しにJapanese MT-Benchにある問題を与えてみると、出力が壊れることなく推論できていることが分かります。
なお、今回のやり方でtokenizerが正しく変換出来ているかは不明です。実際にはうまく変換出来ておらず何らかの形で性能低下が発生している可能性があり、そのためGGUFは他の量子化モデルに対して非推奨としています。
Tanuki-8x8Bの変換
tokenizerは8Bと同じものなので、tokenizerの変換の対応は既に完了しています。一度そのまま8x8Bの方も変換してみます。
python convert_hf_to_gguf.py local_dir\Tanuki-8x8B-dpo-v1.0 --outfile .\Tanuki-8x8B-dpo-v1.0-F16.gguf --outtype f16
これを実行すると以下のようなエラーが出るはずです。これは、Tanuki-8x8BのTanukiForCausalLM
という独自アーキテクチャの変換にデフォルトで対応していない事が原因です。
ERROR:hf-to-gguf:Model TanukiForCausalLM is not supported
TanukiForCausalLM
は重みの構造自体についてはMixtralForCausalLM
と同一なので、GGUFへの変換はこれの変換と全く同じ処理をすれば良いです。具体的には、convert_hf_to_gguf.py
の中で以下のようにMixtralForCausalLM
の後にTanukiForCausalLM
をデコレータに追加すれば、Mixtralと同じ処理での変換が可能になります。
@Model.register("LlamaForCausalLM", "MistralForCausalLM", "MixtralForCausalLM", "TanukiForCausalLM")
class LlamaModel(Model):
model_arch = gguf.MODEL_ARCH.LLAMA
(以下略)
この変換をした後、同様の処理で変換が出来ます。ここではF16ではなくQ8_0に変換しています。
python convert_hf_to_gguf.py local_dir\Tanuki-8x8B-dpo-v1.0 --outfile .\Tanuki-8x8B-dpo-v1.0-Q8_0.gguf --outtype q8_0
このまま推論してみます。
llama-cli -m Tanuki-8x8B-dpo-v1.0-Q8_0.gguf -p "I believe the meaning of life is" -n 128
> I believe the meaning of life is to not believe in God, and if we believe in a higher being then we will believe that our life will have a purpose.(以下略)
8Bの時と同様に、試しにJapanese MT-Benchにある問題を与えてみると、出力が壊れることなく推論できていることが分かります。(画像の例はQ4_K_M)
なお、カスタムモデルの推論に関する実装を特にしていないので、このGGUFの推論時にはTanukiForCausalLM
ではなくMixtralForCausalLM
として推論されていますが、出力は崩壊していません。これは、TanukiForCausalLM
がMixtralForCausalLM
に加えている変更が比較的小さなものであるからだと推測できます。ただし、Mixtralとして推論されることによってJapanese MT-Benchで-0.5点程度の性能低下が確認されており、GGUF版は非推奨としております。
TanukiForCausalLM
として推論するためには、llama.cppの推論部分に変更を加える必要があると考えられます。こちらについては実装を試行してみましたが、現状上手く行っていません。詳細は余談に書いていますので、もし詳しい方いればコメントいただけると助かります。
まとめ
本記事では、Tanuki-8BとTanuki-8x8Bの各種手法による量子化についてまとめました。今後企業や研究機関等で独自アーキテクチャのモデルを開発されることがあると思いますが、そのモデルの量子化を行う際に少しでも参考になれば幸いです。また、まだ上手く変換出来ていない部分もあるので、知見のある方はコメント等でご教示いただけますと幸いです。
余談
ここでは、試したが上手く行かなかったことを書いています。記事を読んだ方の中で詳しい方がいれば何かコメントいただけると助かります。
pre-tokenizationの設定をしてTanukiをGGUF変換してみる
上述したように、TanukiモデルをそのままGGUF変換しようとすると以下のような警告とエラーが出ます。
WARNING:hf-to-gguf:**************************************************************************************
WARNING:hf-to-gguf:** WARNING: The BPE pre-tokenizer was not recognized!
WARNING:hf-to-gguf:** There are 2 possible reasons for this:
WARNING:hf-to-gguf:** - the model has not been added to convert_hf_to_gguf_update.py yet
WARNING:hf-to-gguf:** - the pre-tokenization config has changed upstream
WARNING:hf-to-gguf:** Check your model files and convert_hf_to_gguf_update.py and update them accordingly.
WARNING:hf-to-gguf:** ref: https://github.com/ggerganov/llama.cpp/pull/6920
WARNING:hf-to-gguf:**
WARNING:hf-to-gguf:** chkhsh: a12ac8faf6a5e2ef542d8c05946c7c89443346927f0c04b8d0f285c557864f24
WARNING:hf-to-gguf:**************************************************************************************
WARNING:hf-to-gguf:
NotImplementedError: BPE pre-tokenizer was not recognized - update get_vocab_base_pre()
上の解説ではtokenizerのtypeを変えて解決しましたが、基本的にこのエラーの解決は警告にもあるようにconvert_hf_to_gguf_update.py
を使って行います。ただ、Tanukiはこの方法では推論時にエラーが出てしまい上手く変換できませんでした。ここではその変換処理について書きます。
convert_hf_to_gguf_update.py
を開き、models
に以下のようにTanukiのモデルを追加します。
models = [
...
{'name': "gpt3-finnish", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/TurkuNLP/gpt3-finnish-small", },
{"name": "exaone", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct", },
{"name": "default", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/weblab-GENIAC/Tanuki-8B-dpo-v1.0", },
]
*本来はここでnameにモデル名等を記載しますが、新規で追加する場合それに対応するpre-tokenizationの処理の追加がllama.cpp側に必要となります。今回は一旦これをスキップし、defaultとしておきます。
その後、以下のように実行します。
python convert_hf_to_gguf_update.py <huggingface_token>
このconvert_hf_to_gguf_update.py
では、様々な記号や言語が混ざった乱雑なテキストを各モデルのtokenizerでencodeし、その結果からSHA256を使ってハッシュ値を取得し、それを元に各tokenizer用の設定をconvert_hf_to_gguf.py
に追加しているようです。
これにより、convert_hf_to_gguf.py
のget_vocab_base_pre()
関数の中に以下のようなif文が自動で追加されます。
if chkhsh == "002fd46bc80da6b186f8e6cf447170310e239de9233f081a9495150c0d0a8e42":
# ref: https://huggingface.co/weblab-GENIAC/Tanuki-8B-dpo-v1.0
res = "default"
ここで、もう一度convert_hf_to_gguf.py
を実行してみると、無事に変換は出来るはずです。
python convert_hf_to_gguf.py local_dir\Tanuki-8B-dpo-v1.0 --outfile .\Tanuki-8B-dpo-v1.0-F16.gguf --outtype f16
変換が無事に出来たので、推論しようとしてみます。
llama-cli -m Tanuki-8B-dpo-v1.0-F16.gguf -p "I believe the meaning of life is" -n 128
すると、以下のようなエラーが出てしまいモデルがロード出来ません。
llama_model_load: error loading model: error loading model vocabulary: cannot find tokenizer merges in model file
llama_load_model_from_file: failed to load model
llama_init_from_gpt_params: error: failed to load model 'Tanuki-8B-dpo-v1.0-F16.gguf'
この問題について少し調べましたが、私の知識では原因がよく分かりませんでした。もし詳しい方いればコメント等でおしえていただきたいです。
llama.cppの推論側をTanuki-8x8Bのアーキテクチャに対応させる
Tanuki-8x8BはTanukiForCausalLM
という独自アーキテクチャですが、これはMixtralForCausalLM
とほぼ同一の構造です。改変部分はSparse MoE層においてrouter_logitsの正規化を行っている部分だけで、コード的には以下の3行の改変のみがされています。
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
""" """
batch_size, sequence_length, hidden_dim = hidden_states.shape
if self.training and self.jitter_noise > 0:
hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (batch * sequence_length, n_experts)
router_logits = self.gate(hidden_states)
# 改変部分
mean = router_logits.mean(dim=-1, keepdim=True)
std = router_logits.std(dim=-1, keepdim=True)
router_logits = (router_logits - mean) / (std + 1e-5)
# 改変部分ここまで
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
(以下略)
非常に軽微な変更なので、この変更部分をllama.cppのMixtralの推論部分に取り込めばTanukiのアーキテクチャに対応した推論が可能になるはずです。
この考えのもと、上記のrouter_logitsの正規化に値する処理をllama.cpp側に取り入れようとしてみました。具体的には、以下のような変更を加えてみています。
この変更を加えたllama.cppで、Tanuki-8x8BのGGUF版のCPU推論が問題なく行えることは確認済みです。ただ、GPU推論をすると以下のようなCUDAエラーが出てしまい、上手く行っていません。
ggml_cuda_compute_forward: ARGSORT failed
CUDA error: an illegal memory access was encountered
current device: 0, in function ggml_cuda_compute_forward at /tmp/pip-req-build-lf0oo0av/vendor/llama.cpp/ggml/src/ggml-cuda.cu:2326
err
/tmp/pip-req-build-lf0oo0av/vendor/llama.cpp/ggml/src/ggml-cuda.cu:102: CUDA error
Aborted (core dumped)
エラーを見るとARGSORTが失敗しているようですが、原因が分かっておらず解決できていません。もし詳しい方で何か対処方法ご存知の方いればコメント等で教えていただけると非常に助かります。
東京大学 松尾・岩澤研究室が運営する松尾研LLMコミュニティのLLM開発プロジェクト[GENIAC] の開発記録、情報発信になります。 各種リンクはこちら linktr.ee/matsuolab_community
Discussion