🧩

【AWS Trainium 50本ノック #6】Llama3をTrainiumに移植しなおす編

に公開

第 6 章 Llama3をTrainiumに移植しなおす編

本章では以下を仮定します。

  • 第 4 章「NxD対応済モデルの学習編」を実施済みであること
  • 分散学習の基礎知識(第 5 章の内容)

問題 (38-50)

前々章では、tanuki-8b の学習を行いました。tanuki-8b のモデルアーキテクチャは、config.json にあるように「LlamaForCausalLM」です。このアーキテクチャは、 transformers ライブラリの modeling_llama.py 中に定義されています。

一方、NxDT ライブラリ中にも modeling_llama.py が存在します。こちらは、上記 transformers ライブラリの実装を Neuron チップ上での分散学習が可能な形に「移植」したものです。

世の中では日々新しいモデルアーキテクチャが登場しています。transformers ライブラリの models ディレクトリの中には、上記の LlamaForCausalLM 以外にも様々なアーキテクチャの定義コードが格納されています。しかしながら、これらの多くは、まだ NxDT 公式で移植版実装が提供されていません。

今回は NxDT 版の modeling_llama.py が公式提供されていましたが、これが仮に提供されていなかった場合を想定して、モデルの「移植」の練習をしてみましょう。


NxDT 版 modeling_llama.py の中身を見るとわかりますが、基本的にはオリジナル(以下 HF 版と呼びます)の modeling_llama.py で定義されている層のクラスたちを継承し、必要な部分だけオーバーライド(上書き)している構成となっています。⚠️注意:このような構成のため、NxDTライブラリの挙動はtransformersライブラリのバージョンにかなり敏感になってしまっています。NxDTライブラリを使用する際は、transformersライブラリのバージョンは厳格に固定することをお勧めします。

  1. HF 版のmodeling_llama.py の中身を十分に理解してください(注意:HF版も日々更新されていくため、GitHubではなく、現在のPython環境にインストールされているtransformersライブラリの中にあるソースを確認するようにしてください)。どの層がどの層をどの順で呼び出すのか、そのツリー構造を正確に把握してください。Claude, Gemini, ChatGPT等のLLMに modeling_llama.py の全体をコンテキストとして与えた上で、「どのレイヤーがどのレイヤーを呼び出しているのか、全体の処理の流れがどうなっているのか、テキストで簡単なツリー図にまとめてもらえますか?」等と尋ねるのがお勧めです。

    • LLMからの返答例

      LlamaForCausalLM
      ├── LlamaModel
      │   ├── embed_tokens(埋め込み層)
      │   ├── Rotary Position Embedding(RoPE)
      │   ├── LlamaDecoderLayer(デコーダ層)× N層
      │   │   ├── LlamaRMSNorm(正規化)
      │   │   ├── LlamaAttention(自己注意)
      │   │   │   ├── query, key, value 計算
      │   │   │   ├── Rotary Position Embedding 適用
      │   │   │   ├── アテンション重み計算(Scaled Dot-Product Attention)
      │   │   │   └── アテンション出力計算
      │   │   ├── 残差接続(Residual Connection)
      │   │   ├── LlamaRMSNorm(正規化)
      │   │   ├── LlamaMLP(全結合層)
      │   │   │   ├── gate_proj & up_proj(次元変換)
      │   │   │   ├── 活性化関数(Activation Function)
      │   │   │   ├── down_proj(次元縮小)
      │   │   │   └── 出力
      │   │   └── 残差接続(Residual Connection)
      │   ├── LlamaRMSNorm(正規化)
      │   └── 隠れ状態(Hidden States)を返す
      └── lm_head(ロジット計算)
      
  2. train.shの階層に models ディレクトリを新規作成し、この中にmodeling_my_llama.py というファイルを作成します。ここに移植コードを開発していきます。まずは、HF版の modeling_llama.py のクラスを継承して全くそのままとするところから始めましょう。内容を以下としてください。

    from transformers.models.llama.modeling_llama import (
        LlamaForCausalLM as LlamaForCausalLMHF,
        LlamaRotaryEmbedding as LlamaRotaryEmbeddingHF,
        LlamaDecoderLayer as LlamaDecoderLayerHF,
        LlamaAttention as LlamaAttentionHF,
        LlamaRMSNorm as LlamaRMSNormHF,
        LlamaMLP as LlamaMLPHF
    )
    
    class LlamaForCausalLM(LlamaForCausalLMHF):
        pass
    
    class LlamaRotaryEmbedding(LlamaRotaryEmbeddingHF):
        pass
    
    class LlamaDecoderLayer(LlamaDecoderLayerHF):
        pass
    
    class LlamaAttention(LlamaAttentionHF):
        pass
    
    class LlamaRMSNorm(LlamaRMSNormHF):
        pass
    
    class LlamaMLP(LlamaMLPHF):
        pass
    
    
  3. 現時点で何がダメかを確認するため、とりあえず動かしてください。

    • 訓練コードから、突貫で作成した上記のモデル定義を呼び出すように変更します。(参考:公式Docs
      1. training.py を以下のように修正します:

        • 変更前

          from neuronx_distributed_training.lightning_modules.model.hf_models.llama_model import (
              HFLLamaModule,
          )
          ...(中略)...
          model = HFLLamaModule(cfg, trainer)
          
        • 変更後

          from models.my_llama_model import HFMyLLamaModule
          ...(中略)...
          model = HFMyLLamaModule(cfg, trainer)
          
      2. modelsフォルダ内にファイル my_llama_model.py を新規作成し、その中で、上記でインポートする HFMyLLamaModule を以下のように定義します。

        my_llama_model.py
        # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
        # SPDX-License-Identifier: Apache-2.0
        
        import os
        import neuronx_distributed as nxd
        import torch
        from transformers import LlamaConfig
        import sys
        from neuronx_distributed.utils.utils import hardware
        from neuronx_distributed_training.utils import get_dtype, get_attribute_from_cfg
        from torch_neuronx.utils import get_platform_target
        from models.modeling_my_llama import (
            LlamaAttention as CoreAttention,
            LlamaDecoderLayer,
            LlamaForCausalLM,
            LlamaRMSNorm,
            LlamaMLP,
            LlamaRotaryEmbedding
        )
        
        from neuronx_distributed_training.lightning_modules.model.hf_models.base_model import BaseHfModel
        
        class HFMyLLamaModule(BaseHfModel):
            def _get_model(self):
                config = LlamaConfig.from_pretrained(self.config.model.model_config)
                config.use_cache = False
                config.return_dict = False
                config.sequence_parallel_enabled = self.config.distributed_strategy.get("sequence_parallel", False)
                config.qkv_linear = self.config.model.get("qkv_linear", False)
                config.fuse_qkv = self.config.model.get("fuse_qkv", True)
                config.kv_shared_group_size = self.config.distributed_strategy.get("kv_replicator", 1)
                config.max_position_embeddings = self.config.model.get("max_position_embeddings", config.max_position_embeddings)
                config.use_flash_attention = self.config.model.fusions.flash_attention
                config.use_ring_attention = get_attribute_from_cfg(self.config, 'ring_attention', False)
                hardware_type = hardware(get_platform_target())
                if hardware_type==hardware.TRN1:
                    config.lnc = self.config.trainer.get("lnc", 1)
                if hardware_type==hardware.TRN2:
                    config.lnc = self.config.trainer.get("lnc", 2)
                if self.config.model.get('num_layers', -1) != -1:
                    config.num_hidden_layers = self.config.model.get('num_layers')
                if self.config.model.get('hidden_size', -1) != -1:
                    config.hidden_size = self.config.model.get('hidden_size')
                if self.config.model.get('rope_theta', -1) != -1:
                    config.rope_theta = self.config.model.get('rope_theta')
                config.head_dim = get_attribute_from_cfg(self.config, 'hidden_size', config.hidden_size) // config.num_attention_heads # overriding head_dim value, which was set in transformers code
                config.transpose_nki_inputs = self.config.model.get('transpose_nki_inputs', True) # transpose_nki_inputs by default  
        
                if get_attribute_from_cfg(self.config, "peft", False):
                    lora_config = nxd.modules.lora.LoraConfig(
                        lora_rank=get_attribute_from_cfg(self.config, 'lora_rank', 16),
                        lora_alpha=get_attribute_from_cfg(self.config, 'lora_alpha', 32),
                        lora_dropout=get_attribute_from_cfg(self.config, 'lora_dropout', 0.05),
                        bias=get_attribute_from_cfg(self.config, 'lora_bias', "none"),
                        lora_verbose=get_attribute_from_cfg(self.config, 'lora_verbose', True),
                        target_modules=get_attribute_from_cfg(self.config, 'target_modules', ["qkv_proj"]),
                        load_lora_from_ckpt=get_attribute_from_cfg(self.config, 'load_lora_from_ckpt', False),
                        save_lora_base=get_attribute_from_cfg(self.config, 'save_lora_base', False),
                        merge_lora=get_attribute_from_cfg(self.config, 'merge_lora', False),
                        save_lora_config_adapter=get_attribute_from_cfg(self.config, 'save_lora_config_adapter', True),
                        merge_sharded_lora=get_attribute_from_cfg(self.config, 'merge_sharded_lora', False),
                    )
                    self.nxd_config["lora_config"] = lora_config
        
                if self.config.precision.type == "fp32":
                    config.reduce_dtype = get_dtype(self.config.precision.get('parallel_layers_reduce_dtype', 'fp32')) # RS would be in fp32 as there is no implicit downcasting
                    config.torch_dtype = torch.float32
                else:
                    config.reduce_dtype = torch.bfloat16 # default RS type, this wont get downcasted to anything else, so RS will happen at bf16
                    if get_dtype(self.config.precision.get('parallel_layers_reduce_dtype', 'bf16')) == torch.float32:
                        config.reduce_dtype = torch.float64
                    config.torch_dtype = torch.bfloat16
                   
                leaf_module_cls = [LlamaRMSNorm.__name__, LlamaRotaryEmbedding.__name__]
                activation_recompute_modules = []
                recompute_modules = self.config.model.get("activations_checkpoint_recompute", [])
                granularity = self.config.model.get("activations_checkpoint_granularity", None)
        
                if granularity == "selective":
                    for module in recompute_modules:
                        module_obj = getattr(sys.modules[__name__], module, None)
                        if module_obj is not None:
                            activation_recompute_modules.append(module_obj)
                elif granularity == "full":
                    activation_recompute_modules = "full"
                elif not self.config.model.fusions.get("flash_attention", False):
                    activation_recompute_modules.append(CoreAttention) # do CoreAttention checkpointing if flash_attention is off
                else:
                    activation_recompute_modules = None
        
                self.nxd_config["activation_checkpoint_config"] = activation_recompute_modules
                self.nxd_config["pipeline_config"].update(
                    {
                        "transformer_layer_cls": LlamaDecoderLayer,
                        "output_loss_value_spec": (True, False),
                        "input_names": ["input_ids", "attention_mask", "labels"],
                        "leaf_module_cls": leaf_module_cls,
                    }
                )
                include_buffers = True
                return nxd.initialize_parallel_model(self.nxd_config, self.model_provider_func, include_buffers, config)
        
            def model_provider_func(self, config):
                model = LlamaForCausalLM(config)
                # Here we make sure we use the same sine and cosine matrices for all layers.
                # Making use of same tensors would make the CSE algorithm eliminate the lookup call
                # from layers, keeping only lookup from first layer.
                # with torch.no_grad():
                #     cos, sin = self.get_sin_cos_matrix(config)
                #     for layer in model.model.layers:
                #         layer.self_attn.rotary_emb.cos_cached = cos
                #         layer.self_attn.rotary_emb.sin_cached = sin
        
                if os.environ.get("XLA_DOWNCAST_BF16", None) == "0" and config.torch_dtype == torch.bfloat16:
                    model = model.to(torch.bfloat16)
                    
                return model
        
            def get_sin_cos_matrix(self, config):
                head_dim = config.hidden_size // config.num_attention_heads
                base = config.rope_theta
                inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
                t = torch.arange(config.max_position_embeddings, dtype=inv_freq.dtype)
                freqs = torch.einsum("i,j->ij", t, inv_freq)
                # Different from paper, but it uses a different permutation in order to obtain the same calculation
                emb = torch.cat((freqs, freqs), dim=-1)
                return emb.cos()[None, None, :, :].to(torch.float32), emb.sin()[None, None, :, :].to(torch.float32)
        
            def init_weights(self, module, device):
                """
                Re-init weights after partition
                Referred from HF transformers https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L690
                """
                # Last else should always call super().init_weights() to allow initializing
                # pre-defined layers.
                for key, nested_module in module._modules.items():
                    if isinstance(nested_module, LlamaRotaryEmbedding):
                        module._modules[key] = LlamaRotaryEmbedding(nested_module.config, device)
                if isinstance(module, LlamaRMSNorm):
                    module.weight.data.fill_(1.0)
                else:
                    super().init_weights(module, device)
        
        
  4. AOTコンパイルの sbatch コマンドを実施してください。ただし、チェックポイントの読み込みは行わないことにして、フルスクラッチ重みを利用する設定に変更して実施してください。(resume_from_checkpoint: null)

    データセットの前処理ループが終了後、直ちにOOMでプロセスが終了すると予想されます。並列化を行わないとメモリがデバイスに乗り切りません。並列化を行うためには、以下のようにモデル定義に修正を施す必要があります。

  5. モデルに含まれる「パラメータ」をすべて列挙してください。また、それぞれの層のパラメータのサイズを確認してください。

    解説
    • 例えば以下のような方法で確認できます:

      import torch
      from transformers import AutoModel
      
      model = AutoModel.from_pretrained("/fsx/models/Tanuki-8B-dpo-v1.0/")
      
      def print_param_shapes(model: torch.nn.Module):
          total_params = 0
          for name, param in model.named_parameters():
              print(f"{name:<60} {tuple(param.shape)}")
              total_params += param.numel()
          print(f"\nTotal parameters: {total_params:,}")
      
      # 実行
      print_param_shapes(model)
      
    • 結果をまとめると以下のようになります(上記の出力を要約したもの)

      model.embed_tokens.weight                        (65024, 4096)
      model.layers.*.self_attn.q_proj.weight           (4096, 4096)
      model.layers.*.self_attn.k_proj.weight           (1024, 4096)
      model.layers.*.self_attn.v_proj.weight           (1024, 4096)
      model.layers.*.self_attn.o_proj.weight           (4096, 4096)
      model.layers.*.mlp.gate_proj.weight              (14336, 4096)
      model.layers.*.mlp.up_proj.weight                (14336, 4096)
      model.layers.*.mlp.down_proj.weight              (4096, 14336)
      model.layers.*.input_layernorm.weight            (4096,)
      model.layers.*.post_attention_layernorm.weight   (4096,)
      model.norm.weight                                (4096,)
      lm_head.weight                                   (65024, 4096)
      

    上記の中で、パラメータサイズが大きい embed_tokens, self_attn.[qkvo]_proj, mlp.(gate|up|down)_proj, lm_head は、テンソル並列可能な形に変更する必要があります。

  6. 最初に、トークンIDをベクトルに変換する「埋め込み層」をテンソル並列可能な形に変更してください。これには ParallelEmbedding 層を使用します。

    • HF 版のソースを確認すると、LlamaModel__init__ 中で self.embed_tokens というnn.Embedding 層がインスタンス化されています。ここをParallelEmbedding 層で置き換えます。modeling_my_llama.py の中で、LlamaModel__init__ 関数を、以下の変更を適用した版でオーバーライドしてください(同時に、LlamaModel の呼び出し元である LlamaForCausalLM__init__ も、新版の LlamaModel をインスタンス化するためにオーバーライドが必要になります。合わせて、各 __init__ 冒頭の super().__init__ の修正も適宜必要なのでご注意ください):
      • 変更前

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        
      • 変更後(from neuronx_distributed.parallel_layers.layers import ParallelEmbedding が必要)

        self.embed_tokens = ParallelEmbedding(config.vocab_size, config.hidden_size, self.padding_idx, sequence_parallel_enabled=self.config.sequence_parallel_enabled)
        
      • 解説

        • ParallelEmbedding 層(公式Docs)(ソース)は、torch.nn.Embedding と同じ役割の層ですが、そのパラメータテンソルは「語彙数の次元」でTPサイズに分割されて各デバイスに分散して保持されます。
        • sequence_parallel_enabled は、シーケンス並列の使用有無を設定するオプションです。
          • False(デフォルト)の場合:シーケンス並列を使用しません。TP個すべてのデバイスに、本来のシェイプのテンソルが返されます。
          • True の場合:シーケンス並列を使用します。この場合、この層からの出力は「シーケンス並列モード」(前述)で返ってきます。すなわち、TP個それぞれのデバイスには、TP分割後のテンソルが返されます(一般にそれらは互いに異なります)。それらのシーケンス方向の次元は、本来の1/TP倍です。また、通常の (バッチサイズ, シーケンス長, 隠れ次元) という軸順と異なり、0-dimと1-dimが転置された (シーケンス長, バッチサイズ, 隠れ次元) の軸順のテンソルが返ります。
  7. MLP層 LlamaMLP をテンソル並列可能な形に変更してください。これには ColumnParallelLinear 層と RowParallelLinear 層を使用します。

    • HF 版のソースを確認すると、LlamaMLP__init__self.gate_proj, self.up_proj, self.down_proj という3つの nn.Linear 層がインスタンス化されていますが、ここを以下で置き換えます。前項同様、LlamaMLP__init__ 関数を修正版でオーバーライドしてください(また、呼び出し元の LlamaDecoderLayer__init__およびその呼び出し元の LlamaForCausalLM__init__ もオーバーライドが必要です。以降、この括弧内と同様の注意書きは省略します):
      • 変更前

        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
        
      • 変更後(from neuronx_distributed.parallel_layers.layers import ColumnParallelLinear, RowParallelLinear が必要)

        self.gate_proj = ColumnParallelLinear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias, gather_output=False, sequence_parallel_enabled=self.config.sequence_parallel_enabled)
        self.up_proj = ColumnParallelLinear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias, gather_output=False, sequence_parallel_enabled=self.config.sequence_parallel_enabled)
        self.down_proj = RowParallelLinear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias, input_is_parallel=True, sequence_parallel_enabled=self.config.sequence_parallel_enabled)
        
      • 解説

        • ColumnParallelLinear 層(公式Docs)(ソース)は、torch.nn.Linearと同様の線形層ですが、そのパラメータテンソルは「出力次元」でTPサイズに分割されて各デバイスに分散して保持されます。
          • gather_output は、この層からの出力をTP個のデバイスで分散保持するか否かを設定するパラメータです。
            • False の場合:分散保持されます。すなわち、TP個それぞれのデバイスには、TP分割後のテンソルが返されます(一般にそれらは互いに異なります)。その隠れ次元は、本来の1/TP倍です。
            • True(デフォルト)の場合:分散保持はされません。TP個すべてのデバイスに、本来の出力次元サイズの(分割されていない)テンソルが返されます。
          • sequence_parallel_enabled は、シーケンス並列の使用有無を設定するオプションです。
            • False(デフォルト)の場合:シーケンス並列を使用しません。
            • True の場合:シーケンス並列を使用します。この層に入力されるテンソルは「シーケンス並列モード」で保持されていると仮定され、またこの層からの出力テンソルは「シーケンス並列モード」で返ってきます。
        • RowParallelLinear 層(公式Docs)(ソース)は、torch.nn.Linearと同様の線形層ですが、そのパラメータテンソルは「入力次元」でTPサイズに分割されて各デバイスに分散して保持されます。
          • input_is_parallel は、この層への入力がデバイス間で分散保持されているか否かを伝えるパラメータです。
            • False (デフォルト)の場合:分散保持されていないと仮定されます。
            • True の場合:分散保持されていると仮定されます。入力するテンソルのサイズは、TP分割後のサイズである必要があります。
        • 一般に、線形層が 2つ直列に続く場合、1つ目を ColumnParallelLinear(gather_output=False) とし、2つ目を RowParallelLinear(input_is_parallel=True) とすることで、効率的な計算が可能です。これが「基本パターン」となります。以降の変更箇所でも、このパターンが現れます。
  8. 最後の全結合層 lm_head をテンソル並列可能な形に変更してください。

    • HF 版のソースを確認すると、LlamaForCausalLM__init__self.lm_head という nn.Linear 層がインスタンス化されていますが、これを ColumnParallelLinear 層で置き換えます。

      • 変更前

        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        
      • 変更後

        self.lm_head = ColumnParallelLinear(config.hidden_size, config.vocab_size, bias=False, gather_output=False, sequence_parallel_enabled=self.config.sequence_parallel_enabled)
        
    • 今回は、TP分割が「合流」するポイントはクロスエントロピーを計算する直前となります。「TPを合流した後ソフトマックスクロスエントロピーを計算する関数」として、NxDライブラリにparallel_cross_entropy が用意されています。LlamaForCausalLMforward の最後に loss を計算する部分がありますが、そこを以下のように修正します:

      • 変更前

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
        
      • 変更後(from neuronx_distributed.parallel_layers.loss_functions import parallel_cross_entropyが必要)

        if self.config.sequence_parallel_enabled:
            logits = logits.transpose(0, 1).contiguous()
        
        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].clone().contiguous()
            shift_labels = labels[..., 1:].contiguous()
            shift_logits = shift_logits.view(-1, shift_logits.size(-1))
            shift_labels = shift_labels.view(-1)
            shift_labels = shift_labels.to(shift_logits.device)
            loss = parallel_cross_entropy(shift_logits, shift_labels)
            loss = torch.mean(loss)
        
      • 解説

        • parallel_cross_entropy 関数(公式Docs)(ソース)は、「TP分割されている logits」と「(TP分割されていない)labels」からソフトマックスクロスエントロピーロスを計算します。
        • logits.shape == (バッチサイズ, シーケンス長, 語彙サイズ/TP) , labels.shape == (バッチサイズ, シーケンス長) 、あるいは logits.shape == (バッチサイズ * シーケンス長, 語彙サイズ/TP) , labels.shape == (バッチサイズ * シーケンス長) である必要があります。また、labels[..., i]labels[..., i+1] ではなく)に対する予測が logits[..., i, :] であると解釈されます。
        • シーケンス並列が有効の場合には、0-dim(バッチ次元)と 1-dim(シーケンス次元)が入れ替わって logits に保持されているため、最初に transpose を実施しています。
        • 変更前→変更後で「インデクスずらし」のロジックが追加されているように見えますが、変更前のソースでは self.loss_function の内部でインデクスずらしが実行されているため、実質は同じです。
  9. セルフアテンション層 LlamaAttention をテンソル並列可能な形に変更してください。

    • LlamaAttention 層についての基礎知識

      • セルフアテンション層の基本構造についての知識は仮定しますが、その上で、LlamaAttention の特徴を補足します:
        • LlamaAttention では Grouped Query Attention (GQA) がサポートされています。GQA とは、複数の Query ヘッドに対して Key/Value ヘッドをまとめて共有させることで、計算量やメモリ使用量を削減する仕組みです。通常の Multi-Head Attention では Q/K/V のヘッド数が同じですが、GQA では Q の表現力を保ちながら K/V の数を減らします。これにより、推論時の KV キャッシュの負担を大幅に軽減しつつ、性能をほぼ維持できます。
          • 例えば Qヘッド数が 32、KVヘッド数が 8 の場合、Q0, ..., Q31K0, ..., K7V0, ..., V7 が計算されますが、(Q0, K0, V0), (Q1, K0, V0), (Q2, K0, V0), (Q3, K0, V0), (Q4, K1, V1), (Q5, K1, V1), …, (Q30, K7, V7), (Q31, K7, V7) という32組からそれぞれの O が計算されます。
    • LlamaAttention は、大きく分けると以下の4パートに分かれます:

      1. 入力 (hidden_states) から Q, K, V を計算するパート
      2. Q, K に RoPE (Rotary Position Embedding) を適用するパート
      3. Q, K, V から O を計算するパート
      4. O から出力(次の hidden_states)を計算するパート

      変更の必要な点が非常に多いですが、1つずつ解説します。

      • (i) 入力 (hidden_states) から Q, K, V を計算するパート
        • Q, K, V を計算する 3 つの線形層の代わりに、GQAQKVColumnParallelLinear 層を使用します。
          • __init__
            • 変更前

              self.q_proj = nn.Linear(
                  config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
              )
              self.k_proj = nn.Linear(
                  config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
              )
              self.v_proj = nn.Linear(
                  config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
              )
              
            • 変更後(from neuronx_distributed.modules.qkv_linear import GQAQKVColumnParallelLinear が必要)

              self.qkv_proj = GQAQKVColumnParallelLinear(
                  config.hidden_size,
                  [config.num_attention_heads * self.head_dim, config.num_key_value_heads * self.head_dim],
                  bias=config.attention_bias,
                  gather_output=False,
                  kv_size_multiplier=self.config.kv_shared_group_size,
                  fuse_qkv=self.config.fuse_qkv,
                  sequence_parallel_enabled=self.config.sequence_parallel_enabled
              )
              
            • 解説

              • GQAQKVColumnParallelLinear 層(公式Docs)(ソース)は、入力から Q, K, V を線形変換により計算しますが、そのパラメータテンソルは「出力次元」でTPサイズに分割されて各デバイスに分散して保持されます。
                • 第二引数には [q_projの出力次元, k_proj(v_proj)の出力次元] を指定します。
                • 出力は (query_states, key_states, value_states) の tuple で返ってきます。それぞれのテンソルの軸順は [バッチサイズ, シーケンス長, ヘッド数*ヘッド次元] です。
                • kv_size_multiplier に、先述のKV_REPLICATORの値(i.e. KVヘッドの重みを何倍に複製して保持するか)を指定します。
                • fuse_qkv (デフォルト: True) が Trueの場合、Q, K, V を計算する線形変換のパラメータは結合された状態で保持されます(パラメータ名は(weight|bias)_qkvとなります)。Falseの場合は、Q, K, Vごとに別々のパラメータとして保持されます(パラメータ名は(weight|bias)_(q|k|v)。前者の方が計算効率が良いです。
          • __init__ 内、変数定義変更
            • 変更前

              self.num_heads = config.num_attention_heads
              self.num_key_value_heads = config.num_key_value_heads
              self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
              
            • 変更後(要 import neuronx_distributed.parallel_layers.utils as neuronx_dist_utils

              self.num_heads = neuronx_dist_utils.divide(config.num_attention_heads, get_tensor_model_parallel_size())
              self.num_key_value_heads = neuronx_dist_utils.divide(
                  config.num_key_value_heads * self.config.kv_shared_group_size, get_tensor_model_parallel_size()
              )
              self.num_key_value_groups = self.num_heads // self.num_key_value_heads
              
            • 解説

              • self.num_(heads|key_value_heads|key_value_groups) はそれぞれ「注意ヘッド数(Qヘッド数)」「KVヘッド数」「KVグループ数」ですが、それぞれを「TP で割った後の値」に変更しています。これらは forward の処理で参照します。
                • ヘルパー関数 neuronx_dist_utils.divide(x, y)ソース)は、x // y と等価ですが、x % y != 0 の場合に例外を送出する点のみ異なります。
          • forward
            • 変更前

              query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
              key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
              value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
              
            • 変更後

              bsz, q_len, _ = hidden_states.size()
              if self.config.sequence_parallel_enabled:
                  q_len, bsz, _ = hidden_states.size()
                  q_len = q_len * get_tensor_model_parallel_size()
              
              query_states, key_states, value_states = self.qkv_proj(hidden_states)
              query_states, key_states, value_states, seq_len_dim_index = self.permute_qkv_for_attn(
                  query_states, key_states, value_states, bsz, q_len, self.num_heads,
                  self.num_key_value_heads, self.head_dim, self.config
              )
              

              さらに以下ヘルパー関数を LlamaAttention 内に定義します

                def reshape_and_permute_states_for_fa(self, states, bsz, q_len, num_heads, head_dim, use_sequence_parallel):
                    if use_sequence_parallel:
                        return states.view(q_len, bsz, num_heads, head_dim).permute(1, 2, 3, 0)
                    else:
                        return states.view(bsz, q_len, num_heads, head_dim).permute(0, 2, 3, 1)
              
                def permute_qkv_for_attn(
                        self, query_states, key_states, value_states, bsz, q_len, num_heads, num_key_value_heads, head_dim, config
                    ):
              
                    if config.transpose_nki_inputs and config.use_flash_attention:
                        query_states = self.reshape_and_permute_states_for_fa(query_states, bsz, q_len, num_heads, head_dim, config.sequence_parallel_enabled)
                        key_states = self.reshape_and_permute_states_for_fa(key_states, bsz, q_len, num_key_value_heads, head_dim, config.sequence_parallel_enabled)
                        value_states = self.reshape_and_permute_states_for_fa(value_states, bsz, q_len, num_key_value_heads, head_dim, config.sequence_parallel_enabled)
                        dim_index = -1
                    elif config.sequence_parallel_enabled:
                        query_states = query_states.view(q_len, bsz, num_heads, head_dim).permute(1, 2, 0, 3)
                        key_states = key_states.view(q_len, bsz, num_key_value_heads, head_dim).permute(1, 2, 0, 3)
                        value_states = value_states.view(q_len, bsz, num_key_value_heads, head_dim).permute(1, 2, 0, 3)
                        dim_index = -2
                    else:
                        query_states = query_states.view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
                        key_states = key_states.view(bsz, q_len, num_key_value_heads, head_dim).transpose(1, 2)
                        value_states = value_states.view(bsz, q_len, num_key_value_heads, head_dim).transpose(1, 2)
                        dim_index = -2
                    
                    return query_states, key_states, value_states, dim_index
              
            • 解説

              • GQAQKVColumnParallelLinear 層から出力された (query|key|value)_states の軸順を、Attention のコア部分(Q, K, V から O を計算する部分)を計算しやすい順番に変更しています。具体的には、以下の軸順に変更しています:
                • config のフラグ transpose_nki_inputsuse_flash_attention の両方が True である場合:[バッチサイズ, ヘッド数, ヘッド次元, シーケンス長]dim_index = -2
                • 上記以外:[バッチサイズ, ヘッド数, シーケンス長, ヘッド次元]dim_index = -1
      • (ii) Q, K に RoPE (Rotary Position Embedding) を適用するパート
        • インポート

          • 変更前

            from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
            
          • 変更後

            from neuronx_distributed.overrides.transformer_overrides import apply_rotary_pos_emb
            
        • forward

          • 変更前

            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
            
          • 変更後

            query_states, key_states = apply_rotary_pos_emb(
                query_states, key_states, cos, sin, None, 
                self.config.use_flash_attention, self.config.transpose_nki_inputs
            )
            
        • 解説

          • apply_rotary_pos_embはRoPEを計算する関数ですが、これをHF版からNxD版に差し替えています。NxD版の実装は、use_flash_attention==True かつ transpose_nki_inputs==Trueであるケース(query_state, key_states の軸順が前述の通り入れ替わっている)に対応している点以外はHF版と同じです。
      • (iii) Q, K, V から O を計算するパート
        • まず K, V について、KV_REPLICATOR を反映する(すなわち、複製を行って「Q ヘッド数」に揃える)必要があります。
          • 追加

            # repeat k/v heads if n_kv_heads < n_heads
            key_states = repeat_kv(key_states, self.num_key_value_groups)
            value_states = repeat_kv(value_states, self.num_key_value_groups)
            
          • 解説

            • repeat_kv 関数は、シェイプが (バッチサイズ, KVヘッド数, シーケンス長, ヘッド次元) のテンソルを受け取り、(バッチサイズ, KVヘッド数 * 複製数, シーケンス長, ヘッド次元)にして返します(1-dim が元々 [A, B, C, D] となっている場合、[A, A, A, B, B, B, C, C, C, D, D, D] となって返ってきます)。
            • (なお、HF版ではこのKV複製処理は次項の「変更前」コードの eager_attention_forward 関数内で行われています)
        • いよいよコア部分である Q, K, V → O の計算です。これを GPU で実施する際には、GPUの特性に依存した Flash Attention という手法がデファクトで使用され、素朴に計算するよりもメモリ使用量を著しく削減できます。これの Neuron チップ版が nki_flash_attn_func として実装されています。
          • 変更前

            attention_interface: Callable = eager_attention_forward
            if self.config._attn_implementation != "eager":
                if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
                    logger.warning_once(
                        "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
                        'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
                    )
                else:
                    attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
            
            attn_output, attn_weights = attention_interface(
                self,
                query_states,
                key_states,
                value_states,
                attention_mask,
                dropout=0.0 if not self.training else self.attention_dropout,
                scaling=self.scaling,
                **kwargs,
            )
            
          • 変更後(要 from neuronx_distributed.kernels.flash_attn import nki_flash_attn_func

            attn_output = nki_flash_attn_func(query_states, key_states, value_states, self.config.lnc, transpose_nki_inputs=self.config.transpose_nki_inputs)
            if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
                raise ValueError(
                    f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                    f" {attn_output.size()}"
                )
            
          • 解説

            • nki_flash_attn_funcソース)は、Neuron チップ向けの Flash Attention 実装です。Neuron チップの低レイヤーな制御が可能な、NKI (Neuron Kernel Interface) によって実装されています。
            • 入力される (query|key|value)_states の軸順は以下である必要があります:
              • transpose_nki_inputs==True の場合: [バッチサイズ, ヘッド数, ヘッド次元, シーケンス長]
              • transpose_nki_inputs==False の場合: [バッチサイズ, ヘッド数, シーケンス長, ヘッド次元]
            • 出力される attn_output の軸順は、いつでも [バッチサイズ, ヘッド数, シーケンス長, ヘッド次元] となります。
      • (iv) O から出力(次の hidden_states)を計算するパート
        • __init__

          • 変更前

            self.o_proj = nn.Linear(
                config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
            )
            
          • 変更後

            self.o_proj = RowParallelLinear(
                config.num_attention_heads * self.head_dim,
                config.hidden_size,
                bias=config.attention_bias,
                input_is_parallel=True,
                sequence_parallel_enabled=self.config.sequence_parallel_enabled
            )
            
        • forward 内、o_proj層に通す直前の reshape

          • 変更前

            attn_output = attn_output.reshape(*input_shape, -1).contiguous()
            
          • 変更後

            if self.config.sequence_parallel_enabled:
                attn_output = attn_output.permute(2, 0, 1, 3)
                attn_output = attn_output.reshape(q_len, bsz, self.hidden_size // get_tensor_model_parallel_size())
            else:
                attn_output = attn_output.transpose(1, 2).contiguous()
                attn_output = attn_output.reshape(bsz, q_len, self.hidden_size // get_tensor_model_parallel_size())
            
          • 解説

            • nki_flash_attn_funcからの出力テンソルを、通常のアクティベーションの軸順(シーケンス並列有りの場合は [シーケンス長, バッチサイズ, ヘッド数*ヘッド次元]、無しの場合は [バッチサイズ, シーケンス長, ヘッド数*ヘッド次元])になるようにtransposeとreshapeを実施しています。(TP並列はまだ合流していないため、隠れ次元が TP サイズで割られていることに注意してください。この後のo_projで合流します)
  10. 「AOTコンパイル」を再実行してください。

    • もしエラーが出た場合、デバッグが必要です。一般的なアドバイスを以下に記します。
      • デバッグのTips
        • 冗長なログファイルからエラーの根本原因を見つけるコツ
          • ログはしばしば数千行〜数万行となり、さらに並列で走る多数のプロセスからのログが混じり合うため、ログ行の順番もしばしば乱れます。エラーで落ちた場合に、エラーの原因を表すログ行をスムーズに特定することが重要となります。
            • 最も冒頭に登場する error という文字列を検索:非常にシンプルな方法ですが、多くの場合これでエラーの原因に到達できます。
              • 上記ではっきりしない場合、冒頭に現れる Error:Error |を検索すると多くのケースでエラーの原因に到達することができます。(前者は Python Error を、後者は Neuron コンパイラからのエラーを引っ掛けられます)
            • ログ全体をLLMに読ませてエラーの原因を調べさせる:非常に有効です。全体だと行数が多すぎて読ませられない場合は、上記方法で行範囲を絞ってからLLMに読ませると良いでしょう。
        • printの代わりにxm.master_print
          • print関数でデバッグ情報を出力させようとすると、プログラムの並列実行数だけ出力が重複してしまい、視認性が悪くなります。xm.master_print (要 import torch_xla.core.xla_model as xm)を使用すると、マスターのプロセスからのみ出力が行われるため、デバッグログがすっきりします。
        • 「問題のある層」の特定方法
          • コンパイラが出すエラーからは、どの層に問題があるかは直ちにはわかりません。
          • 層と層の間にprint関数を挟みまくるような方法では、問題のある層を特定できません。
          • 問題のある層を特定するには「層の短絡」がしばしば有効です。例えば、LlamaModelの中でLlamaDecoderLayerを繰り返すforループを丸ごとスキップするように改変してみます。これでエラーが解消されれば、原因はLlamaDecoderLayerの中にあることがわかりますし、エラーが解消されない場合は、最初のエンベディング、あるいは最後の線形層やloss計算に問題があることがわかります。
        • テンソルを覗き見する際の注意
          • 変数に格納されているテンソルの値を確認するために、デバッグ目的で print(あるいは xm.master_print)すると、第1章で解説した通り、その時点で遅延評価が走ってしまいます。デバッグコードの挿入前と挿入後とでコンパイルのタイミングが変化することで、エラー内容が変化してしまう場合があります。
          • 一方、テンソルの .shape, .dtype, .device 等だけの確認であれば、遅延評価は走らず、上記のような問題は生じません。
      • よくあるコンパイラのエラーメッセージ
        • Estimated peak HBM usage (18.179819) exceeds 16GB. Neff won't be able to load on chip
          • 「それぞれのNeuronコアにモデルが載りきらない」ことを表しています。メモリ使用量を減らす(最大シーケンス長やバッチサイズを小さく設定する、同アーキテクチャのより小規模なモデルに変更する)、モデル並列度を高める(TPやPPを大きくする)等を行う必要があります。
        • Couldn't color the DRAM even with 100GB of DRAM space assumption, model needs too much HBM memory !
          • 上記同様、モデルが載りきらないことを表しています。
        • DRAM usage for Internal DRAM tensor exceeds 16GB of device space limit, cannot fit into device, model requires too much HBM memory !
          • 上記同様、モデルが載りきらないことを表しています。
        • Internal tensorizer error: VectorizeDMA:Illegal after shrink dst!
          • 特定の変数に格納されるテンソルのシェイプが、処理のたびに可変になっているようなケースで、上記エラーが生じる場合があります。
        • RuntimeError: (1, 32768, 1, 80) and (1, 32768, 1, 80)
          • シェイプの等しいテンソルAとBについて A * B を計算する際に、それらの shape が揃っておらず計算不可能なとき、上記のようなエラーメッセージが表示されます。ただし、このエラーメッセージは誤っており、本来ならば「Aのshape」と「Bのshape」を並べるべきところに、「Aのshape」を2回並べてしまっていることに留意する必要があります。
        • CCOM WARN No transport found between devices 8 and 7. Possible replica group misconfiguration
          • コア間の通信に問題がある場合に表示されるエラーです。Neuronチップ内のコア間が全結合ではないことに起因して、TPサイズ・PPサイズ・KV_REPLICATORの組み合わせによっては、本エラーが生じる場合があります。
        • ERROR: Unsupported operation: mhlo.set_dimension_size
        • Number of instructions (8146371) is over the threshold (5000000). - Compile under --optlevel=1 to create smaller subgraphs or use pipeline parallelism.
          • サイズの巨大なテンソルを載せようとすると上記エラーが出る場合があります。
        • RuntimeError: Bad StatusOr access: RESOURCE_EXHAUSTED: LoadCollectives: error condition NRT_RESOURCE == rt_status:
          • Neuron SDKのバージョンをアップデートすることで、本エラーが解消する場合があります。
  11. 「学習」を再実行してください。

    最後に、チェックポイントの読み込みに対応しましょう。

  12. チェックポイントの読み込みを行う設定に戻して、再度 AOT コンパイル・学習を実行してください。

    • 実はこのままでは、チェックポイント読み込み時に以下のPythonエラーが生じます:

      RuntimeError: Missing keys when loading state dictionary: model.layers.0.mlp.gate_proj.weight, model.layers.0.mlp.up_proj.weight, model.layers.1.mlp.gate_proj.weight, ...(中略)... model.layers.1.mlp.up_proj.weight,model.layers.30.mlp.gate_proj.weight, model.layers.30.mlp.up_proj.weight, model.layers.31.mlp.gate_proj.weight, model.layers.31.mlp.up_proj.weight
      
    • これは、HF→NxDのチェックポイント変換処理が「MLP 層内の gate_proj 線形層と up_proj 線形層はまとめて gate_up_proj 線形層とする」仕様となっているにも関わらず、上記で移植実装を行った時にそれが考慮されていないためです。その結果として、モデル定義に gate_projup_proj 層が存在するにも関わらず、読み込もうとしているチェックポイントファイルにはそれらのデータが存在しないため、上記エラー文言となります)。

    • 以下のように修正を行います:

      • LlamaMLP__init__ 内を以下のように変更します:

        • 変更前

          self.gate_proj = ColumnParallelLinear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias, gather_output=False, sequence_parallel_enabled=self.config.sequence_parallel_enabled)
          self.up_proj = ColumnParallelLinear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias, gather_output=False, sequence_parallel_enabled=self.config.sequence_parallel_enabled)
          
          self.act_fn = ACT2FN[config.hidden_act]
          
        • 変更後

          self.gate_up_proj = ColumnParallelLinear(
              self.hidden_size,
              2 * self.intermediate_size,
              stride=2,
              bias=config.mlp_bias,
              gather_output=False,
              sequence_parallel_enabled=self.config.sequence_parallel_enabled,
          )
          
          self.activation_multiply = ActivationMultiplyMLP(config)
          
        • 解説

          • ColumnParallelLinearstride パラメータ(デフォルト: 1)が設定されていると、パラメータテンソルのTP分割のされ方が変化します。具体的には、分割方向の次元全体を「stride 個の区間」に等分した後、それぞれを TP 個に等分します(結果として、それぞれのデバイスが担当するインデクスは stride個の区間に分かれます)。
          • stride=2 と設定したことにより、各デバイスに格納される出テンソルは、その前半が gate_proj の計算結果となり、その後半は up_proj の計算結果となります。
      • 追加で以下を定義します:

        class ActivationMultiplyMLP(torch.nn.Module):
            def __init__(self, config):
                nn.Module.__init__(self)
                self.act_fn = ACT2FN[config.hidden_act]
                self.split_size = config.intermediate_size // get_tensor_model_parallel_size()
            
            def forward(self, x):
                gate_proj, up_proj = x.split(self.split_size, dim=2)
                intermediate_states = self.act_fn(gate_proj) * up_proj
                return intermediate_states
        
      • LlamaMLPforward 内を以下のように修正します:

        • 変更前

          down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
          
        • 変更後

          intermediate_states = self.activation_multiply(self.gate_up_proj(x))
          down_proj = self.down_proj(intermediate_states)
          
  13. 最後に、移植が正常に行えているかどうかを確認するため、テスト用の入力系列を1つ固定して、それを「移植前のモデル」と「移植後のモデル」の forward に通してみて、同じ logits が得られることを確認してください。

    • evalモードで評価することに留意してください。
    • 数値誤差があるため、完全に一致するとは限りません。ただし、明らかに異なる logits が出てしまう場合は、途中のどこかの層の移植に問題があると考えられます。途中の層時点での hidden_states を確認するなどして、デバッグを実施してください。

ここまで実行できれば、モデルの基本的な移植方法は一通り抑えられているはずです。

脚注
  1. 本50本ノックの内容を監修くださった AWS の常世様に、感謝を申し上げます。 ↩︎

KARAKURI Techblog

Discussion