🐷
MixtralSparseMoeBlockを読む
LLM Advent Calendar 2023の記事です
はじめに
Mixtral 8x7Bが優秀みたいですね。
MoEの実装がgithubにあるので見ていきます
MoEについては以前に記事を書いたのでそちらを参考
foward部分
attentionの出力をgateに入力し各expertを選択、選択されたexpertに対して入力を行う部分
expertの選択にはsoftmaxを使い、expertの出力をまとめるのが主な処理となる。
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, sequence_length, hidden_dim = hidden_states.shape
"""
入力されるhidden_statesに以下のreshapeが行われる
(batch_size, sequence_length, hidden_dim)
=> (batch_size*sequence_length, hidden_dim)
[[I, am, runing], [I, love, cat]]
=> [I, am, running, I, love, cat]
みたいなイメージ、ここではseq_len=3, batch_size=2
実際には各単語は、attention層のhidden dimのサイズのベクトルとなる
"""
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (batch * sequence_length, n_experts)
"""
expertを選択するためのgate関数に入力
入力shape: (batch*sequence_length, hidden_dim)
出力shape: (batch*sequence_length, expert_num)
"""
router_logits = self.gate(hidden_states)
"""
softmaxをとる
例えば、batch_size=1, sequence_length=3, expert_num=2の場合の出力は、
以下のように、各トークンがどのexpertに割り振られるかの確率になる。
[[0.9, 0.1], => token0はexpert0に0.9、expert1に0.1
[0.8, 0.2], => token1はexpert0に0.8、expert1に0.2
[0.3, 0.7]]
"""
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
"""
sparse moeでは上位k個のexpertを選択する。
routing_wieghtsが以下で、top_k=1のとき、
[[0.9, 0.1],
[0.8, 0.2],
[0.3, 0.7]]
selected_expertsはその確率に応じてexpertが選択される
[[0], => token0はexpert0を選択
[0], => token1はexpert0を選択
[1]] => token2はexpert1を選択
"""
routing_weights, selected_experts = torch.topk(routing_weights,
self.top_k, dim=-1)
"""
選択された要素を正規化
routing_weights=[0.6, 0.3, 0.1]から[0.6, 0.3]が選択された場合、
[0.66..,0.33..]のようになる
この確率はexpertのMLP層で使われ、MLP層の出力がこの比率によって
scaleされることになる。
確率に応じてexpertの出力が優先されるようなイメージ
"""
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
# we cast back to the input dtype
routing_weights = routing_weights.to(hidden_states.dtype)
"""
(batch_size * sequence_length, hidden_dim)のzero埋めされたtensorを作成
このfinal_hidden_statesにexpertの出力を追加していく
"""
final_hidden_states = torch.zeros(
(batch_size * sequence_length, hidden_dim),
dtype=hidden_states.dtype, device=hidden_states.device
)
"""
expert mask用のone_hotを作成
selected_expertが以下で
[[0],
[0],
[1]]
expert_num=2のとき
one_hotは
[[[1,0]], => token0はexpert0
[[1,0]], => token1はexpert0
[[0,1]]] => token2はexpert1
となる。shapeは(batch_size * sequence_length, 1, expert_num)
これを最後と最初を次元を入れ替えて、最終的なexpert_maskは以下のようになる
[[[1,1,0]], => expert0にはtoken0とtoken1が割り振られている
[[0,0,1]]] => expert1にはtoken2が割り振られている
結局、selected_expertsからexpert_maskへ次のような変換が行われたことになる
selected_expertsはトークン毎にどのexpertを選択したか表しており、
expert_maskはexpert毎にどのトークンが割り振られたかを表している
"""
# One hot encode the selected experts to create an expert mask
# this will be used to easily index which expert is going to be sollicitated
expert_mask = torch.nn.functional.one_hot(
selected_experts,
num_classes=self.num_experts
).permute(2, 1, 0)
# Loop over all available experts in the model and perform the computation on each expert
for expert_idx in range(self.num_experts):
expert_layer = self.experts[expert_idx]
"""
expert_maskが以下で、1回目のループ(expert0について見ているとき)
[[[1,1,0]],
[[0,0,1]]]
expert_mask[0] = [1, 1, 0]
となる。
これは、expert0には、token0とtoken1が割り振られていることを表す。
whereはexpert_maskの値が1となるindexを取り出す処理
expert_mask[expert_idx]=[1, 1, 0]の場合、1番目と2番目が1の値なので、
top_x = [0, 1]となる
"""
idx, top_x = torch.where(expert_mask[expert_idx])
"""
対象のexpertに1つもtokenが割り振られていない場合は何もしない。
"""
if top_x.shape[0] == 0:
continue
# in torch it is faster to index using lists than torch tensors
top_x_list = top_x.tolist()
idx_list = idx.tolist()
# Index the correct hidden states and compute the expert hidden state for
# the current expert. We need to make sure to multiply the output hidden
# states by `routing_weights` on the corresponding tokens (top-1 and top-2)
"""
hidden_statesはmoeに対する入力
hidden_states.shape=(batch_size, sequence_length, hidden_dim)
top_xは、対象のexpertに割り振られたtokenのindexを表す。
対象がexpert0で、top_x = [0, 1]の場合、
(expert0にはtoken0とtoken1が割り振られていることを表す)
hidden_states[None, top_x_list]で対応するtoken、
すなわち1番目と2番目のtokenを取り出す
hiddens_states=[
token0,
token1,
]
hidden_states[None, top_x_list]に対しreshapeで以下の変形が行われる
(1, batch_size*sequence_length, hidden_dim)
=> (batch_size*sequence_length, hidden_dim)
"""
current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
"""
expertに入力
"""
current_hidden_states = expert_layer(
current_state,
routing_weights[top_x_list, idx_list, None])
# However `index_add_` only support torch tensors for indexing so we'll use
# the `top_x` tensor here.
"""
final_hidden_statesに追加
"""
final_hidden_states.index_add_(
0,
top_x,
current_hidden_states.to(hidden_states.dtype)
)
"""
(batch * sequence_length, n_experts)
=> (batch, sequence_length, n_experts)
"""
final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
"""
expertの出力をまとめたものと、gateの出力を返す
gateの出力はロードバランスのためのlossに使われる
"""
return final_hidden_states, router_logits
expert
デフォルトの活性化関数はsilu
llama2のMLPと同じSwiGLUとなっている
最後のreturnするときに、routing_weightsの要素ごとの積を取っている。
ここでのrouting_weightsはsequenceがexpertを選択した確率を表し、expertの出力を確率に従ってスケールするようなイメージ。
loss
expertの選択が特定のexpertのみに偏ってしまう場合があり、これは学習が非効率になってしまう。これを軽減するため、均等に割り振られることを期待するlossを提案している。
論文中のlossに従い、load_balancing_lossを実装している
追記
動作確認してたら値がおかしかったので確認すると、issueがすでにあがってました。
v4.36.2には取り込まれてないので注意
mainには取り込まれており、以下の最新commitの時点では、取り込まれているので以下の内容で修正してます。
"""
gate_logitsはlayer数ごとのtuple
(layer0のgate_logit, layer1のgate_logit, ....)
みたいな感じ
さらにlayer1のgate_logitは
[
[0.1, 0.2, 0.3], => I
[0.2, 0.3, 0.4] => am
[0.3, 0.4, 0.5] => running
...
]
みたいなイメージ
"""
if gate_logits is None or not isinstance(gate_logits, tuple):
return 0
if isinstance(gate_logits, tuple):
compute_device = gate_logits[0].device
concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
"""
layerごとのtokenがどのexpertに割り振られるかの確率をsoftmaxにより算出
[
[0.6645, 0.3097, 0.0258], => layer0のtoken0がexpert1,2,3に割り振られる確率
[0.4228, 0.3043, 0.2729], => layer0のtoken1がexpert1,2,3に割り振られる確率
[0.4055, 0.4224, 0.1722], => layer1のtoken0がexpert1,2,3に割り振られる確率
[0.3974, 0.2495, 0.3531], => layer1のtoken1がexpert1,2,3に割り振られる確率
]
"""
routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
"""
top_kで上位kを選択
[
[2, 1],=> layer0のtoken0はexpert2, 1に割り振られる
[0, 1],=> layer0のtoken1はexpert0, 1に割り振られる
[2, 0],=> layer1のtoken0はexpert2, 0に割り振られる
[2, 1] => layer1のtoken1はexpert2, 1に割り振られる
]
"""
_, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
# treat `top_k` as tokens (shape is `top_k X [batch_size X sequence_length]`)
selected_experts = selected_experts.reshape(-1)
"""
selected_expertsから以下のone_hotを作成
[
[0, 0, 1],=> layer0のtoken0がexpert2に割り振られる
[0, 1, 0],=> layer0のtoken0がexpert1に割り振られる
[1, 0, 0],=> layer0のtoken1がexpert0に割り振られる
[0, 1, 0],=> layer0のtoken1がexpert1に割り振られる
[0, 0, 1],=> layer1のtoken0がexpert2に割り振られる
[1, 0, 0],=> layer1のtoken0がexpert0に割り振られる
[0, 0, 1],=> layer1のtoken1がexpert2に割り振られる
[0, 1, 0] => layer1のtoken1がexpert1に割り振られる
]
"""
expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
"""
expert_mask=[
[0, 0, 1],=> layer0のtoken0がexpert2に割り振られる
[1, 0, 0],=> layer1のtoken0がexpert0に割り振られる
]
に対し、各次元での最大を求める。
maxの結果、expert_mask=[1, 0, 1]となる
これは、layer、tokenの全体を通して、expert0とexpert2が選ばれていることを表す。
このあとのexpert_maskに対するtorch.meanにより
各エキスパートに割り当てられたトークンの平均数を計算する
例1
expert_mask=[1, 0, 1]、mean=2/3=0.666..
3つのexpertのうち、2つにtokenが割り振られることを表す
例2
expert_mask=[1, 1, 1]、mean=3/3=1.0
3つのexpertのうち、3つにtokenが割り振られることを表す
例3
expert_mask=[0, 0, 0, 1]、mean=1/4=0.25
4つのexpertのうち、1つにtokenが割り振られることを表す
すなわち、tokens_per_expertは
各エキスパートに割り当てられたトークンの平均数を表す
"""
expert_mask = torch.max(expert_mask, dim=-2).values
# Compute the percentage of tokens routed to each experts
tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
"""
router_prob_per_group_and_expertは、
各エキスパートにtokenが割り当てられる平均確率を表す
routing_weightsはgate_logitsにsoftmaxをとったもの。
routing_weightsが以下の場合、
[
[0.2353, 0.4763, 0.2884],
[0.0707, 0.6192, 0.3101]
]
torch.mean,dim=0によって、縦の軸で平均をとることになり、
router_prob_per_expert = [0.1530, 0.5477, 0.2993]となる。
expert0にtokenが割り振られる確率の平均は0.1530
expert1にtokenが割り振られる確率の平均は0.5477
である。
すなわち、router_prob_per_expertは
各エキスパートにtokenが割り当てられる平均確率を表す
"""
# Compute the average probability of routing to these experts
router_prob_per_expert = torch.mean(routing_weights, dim=0)
"""
論文中の数式との対応は
N=num_experts
f=tokens_per_expert
P=router_prob_per_expert
"""
overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(-1))
return overall_loss * num_experts
所感
処理を追いかけてみると意外とシンプルでわかりやすかった
より効率の良い計算を行うMoEもあるみたいなので、こっちも参考になりそう
Discussion