🕌

MoEモデルのアクティブパラメータ数の厳密な計算方法について

2024/09/04に公開

GENIAC 松尾研LLM開発プロジェクトメンバーのArataです。
この記事は、MoEモデルのアクティブパラメータ数の厳密な計算方法について解説する備忘録的な記事です。

はじめに

Tanuki-8x8BはMoEモデルであり、総パラメータ数とアクティブパラメータ数が異なります。モデルカード等に情報を記載する際、総パラメータ数だけでなくアクティブパラメータ数も記載する必要があり、厳密なアクティブパラメータ数を算出する必要がありました。その際に利用した算出方法について本記事では解説します。

MoEモデルの構造について

今回我々のチームで開発されたTanuki-8x8Bは構造上ではMixtral-8x7Bと同一です。

Tanuki-8x8Bという名称を聞くと、パラメータ数について「8x8Bだからだいたい64Bくらいだろう」と思われる方もいるかもしれませんが、実際の総パラメータ数は約47B程度です。

これはなぜかというと、実際には8Bモデルを8個そのまま並列に並べているのではなく、モデルのFFN層のみを8個Expertsとして並べSparse MoE層としているからです。その他のEmbedding層やAttention層は共有されているため、総パラメータ数は単純な掛け算よりも小さくなります。

モデルの厳密な総パラメータ数は簡単に取得できます。
まずは環境を準備します。

pip install transformers accelerate
pip install --no-build-isolation flash_attn # Tanuki-8x8Bで行う場合

その後、以下のようなコードを実行します。

from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

model_name = "weblab-GENIAC/Tanuki-8x8B-dpo-v1.0"

config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)

with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
print("総パラメータ数: " + str(model.num_parameters()))

これを実行すると総パラメータ数が分かります。Tanuki-8x8Bの場合は46973325312(約47B) でした。

アクティブパラメータ数の算出方法

MoEモデルは推論の際、一般的に各トークンごとに一部のExpertsのみが内部的に利用されます。例えば、Mixtral-8x7BやTanuki-8x8Bでは同時に使われるExpertsの数は2つです。
この値はモデルのconfig.jsonの中のnum_experts_per_tok の値で確認できます。(余談ですが、configのこの値を変えることで推論時に同時にアクティブになるExpertsの数を実際に変更できます)

つまり、推論時実際には全てのパラメータがアクティブになるわけではなく、一部のパラメータしか使われません。この場合、2つのExperts+その他の共有部分が使われることになります。そのため、実際の推論コストは47BのDenseモデルよりもかなり小さいです。
この「推論時実際に同時に使われるパラメータ数」がアクティブパラメータ数と呼ばれる値です。

このアクティブパラメータ数の算出自体もそこまで難しくありません。
先ほども言ったようにこの例では「2つのExperts+その他の共有部分」がアクティブになるので、これらのパラメータ数を合算すれば良いです。
これは逆に言うと、総パラメータ数から推論時に使われない6つのExpertsのパラメータ数を引けば良いとも言えます。

この考えを元に実際に算出してみます。以下のようなコードを使います。

from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

model_name = "weblab-GENIAC/Tanuki-8x8B-dpo-v1.0"

config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)

with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)

# 総パラメータ数を算出
total_params = model.num_parameters()
print("総パラメータ数: " + str(total_params))

# 1つのExpertのパラメータ数を算出
single_layer_expert = model.model.layers[0].block_sparse_moe.experts[0]
single_layer_expert_params = sum(p.numel() for p in single_layer_expert.parameters())
total_single_expert_params = single_layer_expert_params * config.num_hidden_layers

# アクティブパラメータ数を算出
# アクティブパラメータ数 = 総パラメータ数 - (Expertsの総数 - 同時に使われるExpertsの数) * 1つのExpertのパラメータ数
active_params = (
    total_params
    - (config.num_local_experts - config.num_experts_per_tok)
    * total_single_expert_params
)
print("アクティブパラメータ数: " + str(active_params))

これを実行するとアクティブパラメータ数が分かります。Tanuki-8x8Bの場合は13150457856(約13B) でした。

まとめ

この記事では、Tanuki-8x8Bを実際の例にしてMoEモデルのアクティブパラメータ数の厳密な求め方について解説しました。今後MoEモデルを開発される際の参考になれば幸いです。

東大松尾・岩澤研究室 | LLM開発 プロジェクト[GENIAC]

Discussion