MoEモデルのアクティブパラメータ数の厳密な計算方法について
GENIAC 松尾研LLM開発プロジェクトメンバーのArataです。
この記事は、MoEモデルのアクティブパラメータ数の厳密な計算方法について解説する備忘録的な記事です。
はじめに
Tanuki-8x8BはMoEモデルであり、総パラメータ数とアクティブパラメータ数が異なります。モデルカード等に情報を記載する際、総パラメータ数だけでなくアクティブパラメータ数も記載する必要があり、厳密なアクティブパラメータ数を算出する必要がありました。その際に利用した算出方法について本記事では解説します。
MoEモデルの構造について
今回我々のチームで開発されたTanuki-8x8Bは構造上ではMixtral-8x7Bと同一です。
Tanuki-8x8Bという名称を聞くと、パラメータ数について「8x8Bだからだいたい64Bくらいだろう」と思われる方もいるかもしれませんが、実際の総パラメータ数は約47B程度です。
これはなぜかというと、実際には8Bモデルを8個そのまま並列に並べているのではなく、モデルのFFN層のみを8個Expertsとして並べSparse MoE層としているからです。その他のEmbedding層やAttention層は共有されているため、総パラメータ数は単純な掛け算よりも小さくなります。
モデルの厳密な総パラメータ数は簡単に取得できます。
まずは環境を準備します。
pip install transformers accelerate
pip install --no-build-isolation flash_attn # Tanuki-8x8Bで行う場合
その後、以下のようなコードを実行します。
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM
model_name = "weblab-GENIAC/Tanuki-8x8B-dpo-v1.0"
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
with init_empty_weights():
model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
print("総パラメータ数: " + str(model.num_parameters()))
これを実行すると総パラメータ数が分かります。Tanuki-8x8Bの場合は46973325312(約47B) でした。
アクティブパラメータ数の算出方法
MoEモデルは推論の際、一般的に各トークンごとに一部のExpertsのみが内部的に利用されます。例えば、Mixtral-8x7BやTanuki-8x8Bでは同時に使われるExpertsの数は2つです。
この値はモデルのconfig.jsonの中のnum_experts_per_tok
の値で確認できます。(余談ですが、configのこの値を変えることで推論時に同時にアクティブになるExpertsの数を実際に変更できます)
つまり、推論時実際には全てのパラメータがアクティブになるわけではなく、一部のパラメータしか使われません。この場合、2つのExperts+その他の共有部分が使われることになります。そのため、実際の推論コストは47BのDenseモデルよりもかなり小さいです。
この「推論時実際に同時に使われるパラメータ数」がアクティブパラメータ数と呼ばれる値です。
このアクティブパラメータ数の算出自体もそこまで難しくありません。
先ほども言ったようにこの例では「2つのExperts+その他の共有部分」がアクティブになるので、これらのパラメータ数を合算すれば良いです。
これは逆に言うと、総パラメータ数から推論時に使われない6つのExpertsのパラメータ数を引けば良いとも言えます。
この考えを元に実際に算出してみます。以下のようなコードを使います。
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM
model_name = "weblab-GENIAC/Tanuki-8x8B-dpo-v1.0"
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
with init_empty_weights():
model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
# 総パラメータ数を算出
total_params = model.num_parameters()
print("総パラメータ数: " + str(total_params))
# 1つのExpertのパラメータ数を算出
single_layer_expert = model.model.layers[0].block_sparse_moe.experts[0]
single_layer_expert_params = sum(p.numel() for p in single_layer_expert.parameters())
total_single_expert_params = single_layer_expert_params * config.num_hidden_layers
# アクティブパラメータ数を算出
# アクティブパラメータ数 = 総パラメータ数 - (Expertsの総数 - 同時に使われるExpertsの数) * 1つのExpertのパラメータ数
active_params = (
total_params
- (config.num_local_experts - config.num_experts_per_tok)
* total_single_expert_params
)
print("アクティブパラメータ数: " + str(active_params))
これを実行するとアクティブパラメータ数が分かります。Tanuki-8x8Bの場合は13150457856(約13B) でした。
まとめ
この記事では、Tanuki-8x8Bを実際の例にしてMoEモデルのアクティブパラメータ数の厳密な求め方について解説しました。今後MoEモデルを開発される際の参考になれば幸いです。
東京大学 松尾・岩澤研究室が運営する松尾研LLMコミュニティのLLM開発プロジェクト[GENIAC] の開発記録、情報発信になります。 各種リンクはこちら linktr.ee/matsuolab_community
Discussion