Mixture of expertsの簡単な実装をしてみる
LLM Advent Calendar 2023の記事です
mixture of expertsの調査と簡単な実装を行ってMoEを理解していきます。
実装については並列化や計算効率の向上などの部分は複雑なので、それら取り除いた簡単なもの
Mixture of expertsについて
特定のタスクに特化したexpertを複数用意し、入力に対してexpertを切り替えることで性能を上げる手法。expertを選択するgate部分とexpert部分からなる。
decoder型のtransformerの場合、Mixture of expertsはattention層のあとのFFNに対して適応される。
(論文の図を参照)
sparse MoE
gate関数が上位k個のexpertを選択する。選択するexpert以外の部分は0になるようなスパース性を持つ。 0が含まれることで選択されないexpertでは計算する必要がないため、計算コストが削減される
通常のMoEでは、特定のexpertが多く選択されてしまうことがあり、これにより学習が非効率なものとなってしまう。これを軽減するため、auxiliary lossを導入している
Nはexpertの総数
mixtralのlossの実装はこんな感じ
ちなみに、mixtralではexpertの部分にMLPを使っている
ST-MoE
Stable and Transferable Mixture-of-Expertsの提案
(設計ガイドとして良さそうなのであとでちゃんと読む)
auxiliary lossに加え、z-lossを導入している。
Bはトークン数、Nはexpertの総数、xはgateの出力
あるexpertに割り振られる値が大きくなりすぎるのを防ぐことを期待するlossとなっている。
実装は以下で公開されている
実装
今回はsparseでなく、もっと単純なもので動きを確認する。
import torch
import torch.nn as nn
import torch.nn.functional as F
class MoE(nn.Module):
def __init__(self,
embed_size=128,
num_experts=4,
expert_hidden_size=256,
gate_hidden_size=256,
device="auto",
):
super(MoE, self).__init__()
self.embed_size = embed_size
self.num_experts = num_experts
expert = nn.Sequential(
nn.Linear(embed_size, expert_hidden_size),
nn.ReLU(),
nn.Linear(expert_hidden_size, embed_size)
)
self.experts = nn.ModuleList([expert for _ in range(num_experts)])
self.gate = nn.Sequential(
nn.Linear(embed_size, gate_hidden_size),
nn.ELU(),
nn.Linear(gate_hidden_size, num_experts),
)
def forward(self, x):
"""
Forward pass for MoE
:param x: Input tensor of shape (batch_size, seq_len, embed_size)
:return: Output tensor. shape (batch_size, seq_len, embed_size)
"""
# batch_size, seq_len, _ = x.size()
gating_scores = self.gate(x)
gating_weights = F.softmax(gating_scores, dim=2) # (batch_size, seq_len, num_experts)
print('gating_weights', gating_weights.shape, gating_weights)
expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=2) # (batch_size, seq_len, num_experts, embed_size)
## gating_weights.unsqueeze(-1) # (batch_size, seq_len, num_experts, 1)
output = torch.sum(gating_weights.unsqueeze(-1) * expert_outputs, dim=2) # (batch_size, seq_len, embed_size)
return output
実装自体はシンプルで、gate関数の出力をsoftmaxを取り、その比率を各expertの出力に対して適応する。gateによってexpertが優遇されるイメージ
例えば、以下のような例を入力とする
sequence_length=3
埋め込みサイズ=4
[I am running]
=> [
[0.4081, 0.7244, 0.0221, 0.1677], => I
[0.2508, 0.8103, 0.3783, 0.0399], => am
[0.6584, 0.4243, 0.5941, 0.6102] => running
]
(実際には単語の埋め込み次元ではなく、attention weightのhiddenになるので注意)
expert=2とすると、
入力は、shape=(batch_size,sequence_length, 埋め込みサイズ)
gating_weightsは、sequence_length毎にどのexpertに割り振られるかの確率となっている。
出力は、入力と同様のshape
input: torch.Size([1, 3, 4])
tensor([[[0.4081, 0.7244, 0.0221, 0.1677],
[0.2508, 0.8103, 0.3783, 0.0399],
[0.6584, 0.4243, 0.5941, 0.6102]]])
gating_weights torch.Size([1, 3, 2])
tensor([[[0.5903, 0.4097],
[0.6000, 0.4000],
[0.6028, 0.3972]]], grad_fn=<SoftmaxBackward0>)
output: torch.Size([1, 3, 4])
tensor([[[ 0.5622, -0.0714, 0.0197, -0.2871],
[ 0.6389, -0.1067, -0.0396, -0.2993],
[ 0.4812, -0.0484, 0.0448, -0.4018]]], grad_fn=<SumBackward1>)
所感
簡単な実装を行ってMoEについて理解が深まった
transformerに組み込まれるMoEの場合、学習時にattention層とexpertやgate部分は同時に学習が行われる。専門家を用意して割り振りを学習するようなイメージと異なるので、専門家感があまりない気がするというか、一応loss関数であるexpertだけに集中しないようになどの工夫はあるものの、ちゃんとタスク毎や特徴に分かれるようにexpertが生まれるのだろうか...
mixtralの実装が公開されているので次はこのあたりを読む
参考
MoEについて
moe実装例
Discussion