📘

Mixture of expertsの簡単な実装をしてみる

2023/12/19に公開

LLM Advent Calendar 2023の記事です
https://qiita.com/advent-calendar/2023/llm

mixture of expertsの調査と簡単な実装を行ってMoEを理解していきます。
実装については並列化や計算効率の向上などの部分は複雑なので、それら取り除いた簡単なもの

Mixture of expertsについて

特定のタスクに特化したexpertを複数用意し、入力に対してexpertを切り替えることで性能を上げる手法。expertを選択するgate部分とexpert部分からなる。

decoder型のtransformerの場合、Mixture of expertsはattention層のあとのFFNに対して適応される。

(論文の図を参照)
https://arxiv.org/abs/2101.03961

sparse MoE

https://arxiv.org/abs/2101.03961

gate関数が上位k個のexpertを選択する。選択するexpert以外の部分は0になるようなスパース性を持つ。 0が含まれることで選択されないexpertでは計算する必要がないため、計算コストが削減される

通常のMoEでは、特定のexpertが多く選択されてしまうことがあり、これにより学習が非効率なものとなってしまう。これを軽減するため、auxiliary lossを導入している

Nはexpertの総数
alphaはハイパーパラメタ

f_iexpert_iにトークンが割り振られた割合
P_iexpert_iにトークンが割り振られる可能性の確率

f_iP_iが均等であれば、ルーティングも均等になるというような、トークンがexpertに対して均等に割り振られることを期待するlossとなっている。

mixtralのlossの実装はこんな感じ
https://github.com/huggingface/transformers/blob/v4.36.2/src/transformers/models/mixtral/modeling_mixtral.py#L76

ちなみに、mixtralではexpertの部分にMLPを使っている
https://github.com/huggingface/transformers/blob/v4.36.2/src/transformers/models/mixtral/modeling_mixtral.py#L664

ST-MoE

Stable and Transferable Mixture-of-Expertsの提案
(設計ガイドとして良さそうなのであとでちゃんと読む)

https://arxiv.org/abs/2202.08906

auxiliary lossに加え、z-lossを導入している。

Bはトークン数、Nはexpertの総数、xはgateの出力

あるexpertに割り振られる値が大きくなりすぎるのを防ぐことを期待するlossとなっている。

実装は以下で公開されている
https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/moe.py

実装

今回はsparseでなく、もっと単純なもので動きを確認する。

import torch
import torch.nn as nn
import torch.nn.functional as F

class MoE(nn.Module):
    def __init__(self, 
                 embed_size=128, 
                 num_experts=4, 
                 expert_hidden_size=256, 
                 gate_hidden_size=256,
                 device="auto",
                 ):
        super(MoE, self).__init__()
        self.embed_size = embed_size
        self.num_experts = num_experts

        expert = nn.Sequential(
                nn.Linear(embed_size, expert_hidden_size),
                nn.ReLU(),
                nn.Linear(expert_hidden_size, embed_size)
            )
        
        self.experts = nn.ModuleList([expert for _ in range(num_experts)])

        self.gate =  nn.Sequential(
            nn.Linear(embed_size, gate_hidden_size),
            nn.ELU(),
            nn.Linear(gate_hidden_size, num_experts),
        )

    def forward(self, x):
        """
        Forward pass for MoE
        :param x: Input tensor of shape (batch_size, seq_len, embed_size)
        :return: Output tensor. shape (batch_size, seq_len, embed_size)
        """
        # batch_size, seq_len, _ = x.size()
        
        gating_scores = self.gate(x)
        gating_weights = F.softmax(gating_scores, dim=2) # (batch_size, seq_len, num_experts)
        print('gating_weights', gating_weights.shape, gating_weights)
        
        expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=2) # (batch_size, seq_len, num_experts, embed_size)
        ## gating_weights.unsqueeze(-1) # (batch_size, seq_len, num_experts, 1)
        output = torch.sum(gating_weights.unsqueeze(-1) * expert_outputs, dim=2) #  (batch_size, seq_len, embed_size)        
        return output

実装自体はシンプルで、gate関数の出力をsoftmaxを取り、その比率を各expertの出力に対して適応する。gateによってexpertが優遇されるイメージ

例えば、以下のような例を入力とする

sequence_length=3
埋め込みサイズ=4

[I am running] 
=> [
[0.4081, 0.7244, 0.0221, 0.1677], => I
[0.2508, 0.8103, 0.3783, 0.0399], => am
[0.6584, 0.4243, 0.5941, 0.6102]  => running
]

(実際には単語の埋め込み次元ではなく、attention weightのhiddenになるので注意)

expert=2とすると、

入力は、shape=(batch_size,sequence_length, 埋め込みサイズ)
gating_weightsは、sequence_length毎にどのexpertに割り振られるかの確率となっている。
出力は、入力と同様のshape

input: torch.Size([1, 3, 4]) 
tensor([[[0.4081, 0.7244, 0.0221, 0.1677],
         [0.2508, 0.8103, 0.3783, 0.0399],
         [0.6584, 0.4243, 0.5941, 0.6102]]])

gating_weights torch.Size([1, 3, 2]) 
tensor([[[0.5903, 0.4097],
         [0.6000, 0.4000],
         [0.6028, 0.3972]]], grad_fn=<SoftmaxBackward0>)

output: torch.Size([1, 3, 4]) 
tensor([[[ 0.5622, -0.0714,  0.0197, -0.2871],
         [ 0.6389, -0.1067, -0.0396, -0.2993],
         [ 0.4812, -0.0484,  0.0448, -0.4018]]], grad_fn=<SumBackward1>)

https://github.com/if001/lit-llama-ja/blob/moe/lit_llama/moe_module.py

所感

簡単な実装を行ってMoEについて理解が深まった

transformerに組み込まれるMoEの場合、学習時にattention層とexpertやgate部分は同時に学習が行われる。専門家を用意して割り振りを学習するようなイメージと異なるので、専門家感があまりない気がするというか、一応loss関数であるexpertだけに集中しないようになどの工夫はあるものの、ちゃんとタスク毎や特徴に分かれるようにexpertが生まれるのだろうか...

mixtralの実装が公開されているので次はこのあたりを読む
https://github.com/huggingface/transformers/blob/238d2e3c44366aba9dc5c770c95475765a6725cb/src/transformers/models/mixtral/modeling_mixtral.py#L688

参考

MoEについて

https://huggingface.co/blog/moe

https://deeplearning.hatenablog.com/entry/moe

moe実装例

https://github.com/huggingface/transformers/blob/238d2e3c44366aba9dc5c770c95475765a6725cb/src/transformers/models/mixtral/modeling_mixtral.py#L688

https://hungyuling.com/blog/fast-mixture-of-experts-in-pytorch/

https://github.com/laekov/fastmoe/blob/master/README.md

https://github.com/lucidrains/mixture-of-experts/blob/master/mixture_of_experts/mixture_of_experts.py

Discussion