🐷

MixtralSparseMoeBlockを読む

2023/12/21に公開

LLM Advent Calendar 2023の記事です
https://qiita.com/advent-calendar/2023/llm

はじめに

Mixtral 8x7Bが優秀みたいですね。
MoEの実装がgithubにあるので見ていきます

MoEについては以前に記事を書いたのでそちらを参考
https://zenn.dev/if001/articles/40917524959913

foward部分

attentionの出力をgateに入力し各expertを選択、選択されたexpertに対して入力を行う部分
expertの選択にはsoftmaxを使い、expertの出力をまとめるのが主な処理となる。

https://github.com/huggingface/transformers/blob/v4.36.2/src/transformers/models/mixtral/modeling_mixtral.py#L712

 def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        batch_size, sequence_length, hidden_dim = hidden_states.shape
	"""
	入力されるhidden_statesに以下のreshapeが行われる
	(batch_size, sequence_length, hidden_dim)
	=>  (batch_size*sequence_length, hidden_dim)
	
	[[I, am, runing], [I, love, cat]]
	=> [I, am, running, I, love, cat]
	みたいなイメージ、ここではseq_len=3, batch_size=2
	実際には各単語は、attention層のhidden dimのサイズのベクトルとなる
	"""
        hidden_states = hidden_states.view(-1, hidden_dim)
        # router_logits: (batch * sequence_length, n_experts)
	
	"""
	expertを選択するためのgate関数に入力
	入力shape: (batch*sequence_length, hidden_dim)
	出力shape: (batch*sequence_length, expert_num)
	"""
        router_logits = self.gate(hidden_states)

	"""
	softmaxをとる
	例えば、batch_size=1, sequence_length=3, expert_num=2の場合の出力は、
	以下のように、各トークンがどのexpertに割り振られるかの確率になる。

	[[0.9, 0.1], => token0はexpert0に0.9、expert1に0.1
	[0.8, 0.2], => token1はexpert0に0.8、expert1に0.2
	[0.3, 0.7]]
	"""
        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
	
	"""
	sparse moeでは上位k個のexpertを選択する。
	
	routing_wieghtsが以下で、top_k=1のとき、
	[[0.9, 0.1],
	[0.8, 0.2],
	[0.3, 0.7]]
	
	selected_expertsはその確率に応じてexpertが選択される
	[[0], => token0はexpert0を選択
        [0],  => token1はexpert0を選択
        [1]]  => token2はexpert1を選択
	"""
        routing_weights, selected_experts = torch.topk(routing_weights, 
	                        self.top_k, dim=-1)
	
	"""
	選択された要素を正規化
	routing_weights=[0.6, 0.3, 0.1]から[0.6, 0.3]が選択された場合、
	
	[0.66..,0.33..]のようになる
	
	この確率はexpertのMLP層で使われ、MLP層の出力がこの比率によって
	scaleされることになる。
	確率に応じてexpertの出力が優先されるようなイメージ	
	"""
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
	
        # we cast back to the input dtype
        routing_weights = routing_weights.to(hidden_states.dtype)

	"""
	(batch_size * sequence_length, hidden_dim)のzero埋めされたtensorを作成
	このfinal_hidden_statesにexpertの出力を追加していく
	"""
        final_hidden_states = torch.zeros(
            (batch_size * sequence_length, hidden_dim), 
	    dtype=hidden_states.dtype, device=hidden_states.device
        )


	"""
	expert mask用のone_hotを作成
	
	selected_expertが以下で
	[[0],
        [0],
        [1]] 
	expert_num=2のとき
	
	one_hotは
	[[[1,0]], => token0はexpert0
	[[1,0]],  => token1はexpert0
	[[0,1]]]  => token2はexpert1
	となる。shapeは(batch_size * sequence_length, 1, expert_num)
	
	これを最後と最初を次元を入れ替えて、最終的なexpert_maskは以下のようになる
	[[[1,1,0]], => expert0にはtoken0とtoken1が割り振られている
	[[0,0,1]]]  => expert1にはtoken2が割り振られている
	
	結局、selected_expertsからexpert_maskへ次のような変換が行われたことになる
	selected_expertsはトークン毎にどのexpertを選択したか表しており、
	expert_maskはexpert毎にどのトークンが割り振られたかを表している
	"""
        # One hot encode the selected experts to create an expert mask
        # this will be used to easily index which expert is going to be sollicitated
        expert_mask = torch.nn.functional.one_hot(
	                  selected_experts,
	                  num_classes=self.num_experts
			  ).permute(2, 1, 0)

        # Loop over all available experts in the model and perform the computation on each expert
        for expert_idx in range(self.num_experts):
            expert_layer = self.experts[expert_idx]
	    """
	    expert_maskが以下で、1回目のループ(expert0について見ているとき)
	    [[[1,1,0]],
	    [[0,0,1]]]
	    
	    expert_mask[0] = [1, 1, 0]
	    となる。
	    これは、expert0には、token0とtoken1が割り振られていることを表す。
	    
	    whereはexpert_maskの値が1となるindexを取り出す処理
	    expert_mask[expert_idx]=[1, 1, 0]の場合、1番目と2番目が1の値なので、
	    top_x = [0, 1]となる
	    """
            idx, top_x = torch.where(expert_mask[expert_idx])

	    """
	    対象のexpertに1つもtokenが割り振られていない場合は何もしない。
	    """
            if top_x.shape[0] == 0:
                continue

            # in torch it is faster to index using lists than torch tensors
            top_x_list = top_x.tolist()
            idx_list = idx.tolist()

            # Index the correct hidden states and compute the expert hidden state for
            # the current expert. We need to make sure to multiply the output hidden
            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
	    
	    """
	    hidden_statesはmoeに対する入力
	    hidden_states.shape=(batch_size, sequence_length, hidden_dim)
	    
	    top_xは、対象のexpertに割り振られたtokenのindexを表す。
	    対象がexpert0で、top_x = [0, 1]の場合、
	    (expert0にはtoken0とtoken1が割り振られていることを表す)
	    
	    hidden_states[None, top_x_list]で対応するtoken、
	    すなわち1番目と2番目のtokenを取り出す
	    
	    hiddens_states=[
	    token0,
	    token1,
	    ]
	    
	    hidden_states[None, top_x_list]に対しreshapeで以下の変形が行われる
	    (1, batch_size*sequence_length, hidden_dim)
	    => (batch_size*sequence_length, hidden_dim)
	    """
            current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
	    
	    """
	    expertに入力
	    """
            current_hidden_states = expert_layer(
	                              current_state, 
		                      routing_weights[top_x_list, idx_list, None])

            # However `index_add_` only support torch tensors for indexing so we'll use
            # the `top_x` tensor here.
	    """
	    final_hidden_statesに追加
	    """
            final_hidden_states.index_add_(
	                              0,                             
	                              top_x,
	                              current_hidden_states.to(hidden_states.dtype)
				      )
	"""
	(batch * sequence_length, n_experts)
	=> (batch, sequence_length, n_experts)
	"""
        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
	
	"""
	expertの出力をまとめたものと、gateの出力を返す
	gateの出力はロードバランスのためのlossに使われる
	"""
        return final_hidden_states, router_logits

expert

https://github.com/huggingface/transformers/blob/v4.36.2/src/transformers/models/mixtral/modeling_mixtral.py#L664

デフォルトの活性化関数はsilu
llama2のMLPと同じSwiGLUとなっている

最後のreturnするときに、routing_weightsの要素ごとの積を取っている。
ここでのrouting_weightsはsequenceがexpertを選択した確率を表し、expertの出力を確率に従ってスケールするようなイメージ。

https://github.com/huggingface/transformers/blob/v4.36.2/src/transformers/models/mixtral/modeling_mixtral.py#L679

loss

expertの選択が特定のexpertのみに偏ってしまう場合があり、これは学習が非効率になってしまう。これを軽減するため、均等に割り振られることを期待するlossを提案している。

論文中のlossに従い、load_balancing_lossを実装している


https://arxiv.org/abs/2101.03961

f_iexpert_iにトークンが割り振られた割合
P_iexpert_iにトークンが割り振られる可能性の確率

追記
動作確認してたら値がおかしかったので確認すると、issueがすでにあがってました。
https://github.com/huggingface/transformers/issues/28093

v4.36.2には取り込まれてないので注意
mainには取り込まれており、以下の最新commitの時点では、取り込まれているので以下の内容で修正してます。

https://github.com/huggingface/transformers/blob/3b7675b2b844b02d4821b827871a21ad16dd446c/src/transformers/models/mixtral/modeling_mixtral.py#L77

    """
    gate_logitsはlayer数ごとのtuple
    
    (layer0のgate_logit, layer1のgate_logit, ....)
    みたいな感じ
    
    さらにlayer1のgate_logitは
    [
      [0.1, 0.2, 0.3], => I 
      [0.2, 0.3, 0.4] => am
      [0.3, 0.4, 0.5] => running
      ...
    ]
    みたいなイメージ
    """
    
    if gate_logits is None or not isinstance(gate_logits, tuple):
        return 0

    if isinstance(gate_logits, tuple):
        compute_device = gate_logits[0].device
        concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)

    """
    layerごとのtokenがどのexpertに割り振られるかの確率をsoftmaxにより算出
    
    [
      [0.6645, 0.3097, 0.0258], => layer0のtoken0がexpert1,2,3に割り振られる確率
      [0.4228, 0.3043, 0.2729], => layer0のtoken1がexpert1,2,3に割り振られる確率
      [0.4055, 0.4224, 0.1722], => layer1のtoken0がexpert1,2,3に割り振られる確率
      [0.3974, 0.2495, 0.3531], => layer1のtoken1がexpert1,2,3に割り振られる確率
    ]
    """
    routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)

    """
    top_kで上位kを選択
    [
     [2, 1],=> layer0のtoken0はexpert2, 1に割り振られる
     [0, 1],=> layer0のtoken1はexpert0, 1に割り振られる
     [2, 0],=> layer1のtoken0はexpert2, 0に割り振られる
     [2, 1] => layer1のtoken1はexpert2, 1に割り振られる
    ]
    """
    _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)

    # treat `top_k` as tokens (shape is `top_k X [batch_size X sequence_length]`)
    selected_experts = selected_experts.reshape(-1)

    """
    selected_expertsから以下のone_hotを作成
    [
     [0, 0, 1],=> layer0のtoken0がexpert2に割り振られる
     [0, 1, 0],=> layer0のtoken0がexpert1に割り振られる
     [1, 0, 0],=> layer0のtoken1がexpert0に割り振られる
     [0, 1, 0],=> layer0のtoken1がexpert1に割り振られる
     
     [0, 0, 1],=> layer1のtoken0がexpert2に割り振られる
     [1, 0, 0],=> layer1のtoken0がexpert0に割り振られる
     [0, 0, 1],=> layer1のtoken1がexpert2に割り振られる
     [0, 1, 0] => layer1のtoken1がexpert1に割り振られる
    ]
    """
    expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
    
    """
    expert_mask=[
     [0, 0, 1],=> layer0のtoken0がexpert2に割り振られる
     [1, 0, 0],=> layer1のtoken0がexpert0に割り振られる
    ]
    に対し、各次元での最大を求める。
    maxの結果、expert_mask=[1, 0, 1]となる
    これは、layer、tokenの全体を通して、expert0とexpert2が選ばれていることを表す。
    
    このあとのexpert_maskに対するtorch.meanにより
    各エキスパートに割り当てられたトークンの平均数を計算する
    
    例1
    expert_mask=[1, 0, 1]、mean=2/3=0.666..
    3つのexpertのうち、2つにtokenが割り振られることを表す
    
    例2
    expert_mask=[1, 1, 1]、mean=3/3=1.0
    3つのexpertのうち、3つにtokenが割り振られることを表す
    
    例3
    expert_mask=[0, 0, 0, 1]、mean=1/4=0.25
    4つのexpertのうち、1つにtokenが割り振られることを表す
    
    すなわち、tokens_per_expertは
    各エキスパートに割り当てられたトークンの平均数を表す
    """
    expert_mask = torch.max(expert_mask, dim=-2).values

    # Compute the percentage of tokens routed to each experts
    tokens_per_expert = torch.mean(expert_mask.float(), dim=0)


    """
    router_prob_per_group_and_expertは、
    各エキスパートにtokenが割り当てられる平均確率を表す
    
    routing_weightsはgate_logitsにsoftmaxをとったもの。
    routing_weightsが以下の場合、
    [
    [0.2353, 0.4763, 0.2884], 
     [0.0707, 0.6192, 0.3101]
    ]
    torch.mean,dim=0によって、縦の軸で平均をとることになり、
    router_prob_per_expert = [0.1530, 0.5477, 0.2993]となる。
    expert0にtokenが割り振られる確率の平均は0.1530
    expert1にtokenが割り振られる確率の平均は0.5477
    である。
    
    すなわち、router_prob_per_expertは
    各エキスパートにtokenが割り当てられる平均確率を表す
    """
    # Compute the average probability of routing to these experts
    router_prob_per_expert = torch.mean(routing_weights, dim=0)
    
    """
    論文中の数式との対応は
    N=num_experts
    f=tokens_per_expert
    P=router_prob_per_expert
    """
    overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(-1))
    return overall_loss * num_experts

所感

処理を追いかけてみると意外とシンプルでわかりやすかった

より効率の良い計算を行うMoEもあるみたいなので、こっちも参考になりそう
https://hungyuling.com/blog/fast-mixture-of-experts-in-pytorch/

Discussion