👻

学習済みの LLM を束ねて Mixture of Experts を作るテク

2024/01/22に公開

導入

Twitter でこんな投稿を見かけました。

https://x.com/maximelabonne/status/1744867841436700850?s=20

「Phi-2 ベースのモデルをいくつか使って Mixture of Experts (MoE) を作ったら単体よりも良い性能が達成できました」という話です。学習済み LLM をマージするテクに関しては最近時々話題に上がっているのを見かけますが、MoE には Gating 部分で追加のパラメータが必要なはずで、そこはどうやっているんだろうと気になりました。中身を見てみたところ、Few-shot で Gating のパラメータを決める手法が使われていて面白かったので、それについて書いてみます。

Sparse Mixture of Experts (Sparse MoE) の推論時の処理

Phixtral は名前やワードアートからも分かる通り、Mixtral の Sparse MoE を踏襲しているので、まずその推論時の処理について書きます。

Transformer の基本的なアーキテクチャは、すごくざっくり描くと下記のような感じで、この内 DecoderLayer 内の MLP の部分を MoE Layer に置き換えます。

  • Embedding
  • DecoderLayer x N (e.g. Mistral 7B なら N = 32)
    • LayerNorm
    • Attention
    • LayerNorm
    • MLP
    • Layer の入力と MLP の出力の Add
  • タスクに応じた出力 Layer

MoE Layer については実装を見るとわかりやすくて、

class MoE(nn.Module):
    def __init__(
        self,
        config: PretrainedConfig,
    ):
        super().__init__()
        self.mlp = nn.ModuleList([MLP(config) for i in range(config.num_local_experts)])
        self.gate = nn.Linear(config.n_embd, config.num_local_experts, bias=False)
        self.num_experts_per_tok = config.num_experts_per_tok

    def forward(self, x):
        orig_shape = x.shape
        x = x.view(-1, x.shape[-1])

        scores = self.gate(x)
        expert_weights, expert_indices = torch.topk(scores, self.num_experts_per_tok, dim=-1)
        expert_weights = expert_weights.softmax(dim=-1)
        flat_expert_indices = expert_indices.view(-1)

        x = x.repeat_interleave(self.num_experts_per_tok, dim=0)
        y = torch.empty_like(x)
        for i, expert in enumerate(self.mlp):
            y[flat_expert_indices == i] = expert(x[flat_expert_indices == i])
        y = (y.view(*expert_weights.shape, -1) * expert_weights.unsqueeze(-1)).sum(dim=1)
        return y.view(*orig_shape)

https://huggingface.co/mlabonne/phixtral-2x2_8/blob/7744a977d83f132ae5808d8c3b70157031f7de44/modeling_phi.py#L293-L317

  • expert の数だけ MLP を用意する
  • 入力を bias なしの Linear で変換してどの Expert に割り当てるかのスコアを計算する
  • スコアの大きい K (= num_experts_per_tok) 個の Expert で推論をして、スコアを softmax して作った重みで重み付き和を取る

というのを Token ごとに独立に 行います。全 Expert について推論は実行されますが、各 Expert は Token の Subset について推論することになるので、MoE Layer の全体の計算量は Gating のオーバーヘッドを除けば変更前の単一の MLP の K 倍になります。Sequence 全体で見ると全 Expert にアクセスできるのに、計算量は K 倍で済むというのがおもしろポイントだと思います。

学習無しで Gating のパラメータを決める手法

Phixtral の話に戻ります。Phixtral (mlabonne/phixtral-4x2_8) では、Phi-2 ベースの 4 つの異なるモデルから MoE モデルを作成しています。

Gating の話を忘れれば「ベースのモデルを決めて MLP 以外のパラメータは全部ベースモデルのものを、MLP は MoE Layer に置き換えて各モデルの MLP のパラメータを使う」という方法で MoE モデルが作れそうです[1]。では、Gating のパラメータはどのように決めているのでしょうか。

真面目にデータセットを用意して Gating のパラメータだけ Fine-Tuning するというのが素朴な方法になりそうですが、下記の記事では Few-shot で Gating のパラメータを決める手法が提案されています。

https://goddard.blog/posts/clown-moe/

この手法は mergekit の mixtral branch の mixtral_moe.py に実装されていて、Phixtral はこれを使って作成されているそうです。

アイデアは以下のような感じです

  • スコアは bias なしの Linear で計算されているので、Gating のパラメータは各 Expert e に対応する vector w_e を stack したもので、MLP への入力 x との内積 x \cdot w_e が大きい Expert が選ばれると見なせる。
  • 内積は x \cdot w_e = |x||w_e| \cos \theta で、スコアを他の Expert と比較するという観点では |x| は共通で、|w_e| もどれかの Expert を常に有利にしたいとかがなければ共通にするのが自然なので、スコアの大小は cosine 類似度の大小になる。
  • cosine 類似度が大きくなる vector というのは自分自身なので、その Expert を使うと有利になりそうな入力 x を集めてきて、適当に集約すれば w_e としてそこそこ良い vector が得られそう。
  • 各 Expert について、その Expert を使うと有利になりそうな Prompt (例えば Code で Fine-Tuning された Expert なら Code の Prompt) をいくつか用意して、その Prompt を forward したときの hidden_state を使って w_e を作ろう!

Domain ごとに Expert を使い分けてくれることを期待する感じですね。Mixtral の MoE は Domain ごとに使い分けているわけではなさそうらしい[2]ので、Mixtral とは結構違う MoE の使い方になりそうです。

mergekit ではいくつか実装のバリエーションが提供されていて config の gate_mode で変更できるようになっています。
https://github.com/cg123/mergekit/blob/9bbee12fe7cc569c97f4d9b0ddb4cccff4081d92/mergekit/scripts/mixtral_moe.py#L47

gate_mode == "hidden" が上で説明したようなアイデアを実装したものになっています。
https://github.com/cg123/mergekit/blob/9bbee12fe7cc569c97f4d9b0ddb4cccff4081d92/mergekit/scripts/mixtral_moe.py#L55-L73

これと "random" の他に "cheap_embed" というのも実装されていて、こちらは transformer の Embedding 層だけ使って計算する計算コストの軽い方法になっています。全部の MoE Layer で同じ Gating を行うことになりますね。

https://github.com/cg123/mergekit/blob/9bbee12fe7cc569c97f4d9b0ddb4cccff4081d92/mergekit/scripts/mixtral_moe.py#L76-L94

また、Negative Prompt が用意できる場合は、その結果を引くような処理も実装されていました。
https://github.com/cg123/mergekit/blob/9bbee12fe7cc569c97f4d9b0ddb4cccff4081d92/mergekit/scripts/mixtral_moe.py#L159-L163

めっちゃ Practical な感じで面白いですね。

hidden_state は MLP に入力される直前のものをもってこなくて大丈夫なのかとか、forward するのは base_model でいいのかとか、softmax に突っ込む前に scaling しなくても大丈夫か、みたいな細かいところが気になったりもしますが、とりあえず一つうまくいくセッティングを示してくれているのは凄いことだと思います。

Phixtral ではどうなっているか

Phixtral では具体的にどのような prompts を指定しているのか、config を確認してみましょう

base_model: cognitivecomputations/dolphin-2_6-phi-2
gate_mode: cheap_embed
experts:
  - source_model: cognitivecomputations/dolphin-2_6-phi-2
    positive_prompts: [""]
  - source_model: lxuechen/phi-2-dpo
    positive_prompts: [""]
  - source_model: Yhyu13/phi-2-sft-dpo-gpt4_en-ep1
    positive_prompts: [""]
  - source_model: mrm8488/phi-2-coder
    positive_prompts: [""]

https://huggingface.co/mlabonne/phixtral-4x2_8/blob/e9dad464394da163595176f6897c2a4f88761c63/mergekit_moe_config.yml

ん????!!!!??!?!?

全モデル同じ空文字列なので、全 Expert の Gating パラメータが同じになって常に最初の2個が選択されるようになっていますね... 😇

(Discussion でもそのことについて触れられていました)
https://huggingface.co/mlabonne/phixtral-4x2_8/discussions/6#65a05f17fbad78ab68e64fc7

Evaluation の結果がいい感じだったので、てっきり 4 モデルをいい感じに使い分けられているのかと思ったんですが、常に同じ 2 モデルを使うのでもすでに単体よりも良いということなんですね...

(https://huggingface.co/mlabonne/phixtral-4x2_8#🏆-evaluation)

まとめ・所感

Phixtral で使われている、LLMs を Sparse MoE としてマージする際に Gating のパラメータを決める手法について調べました。LLM をマージする hueristic な手法として、mergekit の main では SLERP (球面線形補間) や DARE, Passthrough などが実装されていますが、こんな方法も提案されているんだなぁという感じです。違うドメインで訓練された Expert モデルを MoE で使い分けるというのは (何なら Mixtral の MoE よりも) Mixture of Experts という言葉のイメージと合っている印象で面白いなぁと思います。

phixtral-4x2_8 は (現時点では) 常に同じ2個しか使っていなかったというオチでしたが、常に同じ 2 つを選ぶような MoE でも性能の向上が見られるということは、MoE でアンサンブル的な効果が得られているということで、それはそれで興味深い結果だと思います。たくさんのモデルをアンサンブルしたいけど推論時間は足りないので Sparse MoE をよしなに使って計算時間を抑えながら効果的なアンサンブルをする、みたいな方法が Kaggle の解法で使われる日が来ないかなーみたいなことを妄想しました。

追記

かなり詳しい write up (?) が公開されていることを教えていただきました。
https://x.com/iwiwi/status/1749372742380683599?s=20

このタイプの MoE はしばしば FrankenMoE と呼ばれているそうです。ドキュメントにはマージがうまくいく/いかないモデルに関する考察や、Positive Prompt を生成するのにモデル自身を使うと良いことなどが書かれていて非常に参考になる内容でした。
https://docs.google.com/document/d/1_vOftBnrk9NRk5h10UqrfJ5CDih9KBKL61yvrZtVWPE/edit

脚注
  1. これは雑なことを言っていて、MLP の出力同士を足し合わせるのは Skip Connection があるのでギリわかるとしても、マージしたい対象の Fine-Tuning されたモデルでは MLP 以外のパラメータも Tuning されて変わっていると思うので、MLP だけ持ってきてうまくいくというのはかなり非自明だと思います... ↩︎

  2. https://arxiv.org/abs/2401.04088 ↩︎

Discussion