
AWS Trainium 50 Exercises #6: Re-porting Llama 3 to Trainium


Chapter 6 — Re-porting Llama 3 to Trainium

This chapter assumes the following:

  • You have completed Chapter 4 “Training NxD-ready models”.
  • You have basic knowledge of distributed training (the contents of Chapter 5).

Problems (38–50)

In Chapter 4 we trained tanuki-8b. The model architecture of tanuki-8b is LlamaForCausalLM, as shown in its config.json. This architecture is defined in modeling_llama.py in the transformers library.

On the other hand, there is also a modeling_llama.py in the NxDT (NeuronX Distributed Training) library. That file is the version “ported” from the transformers implementation so that the model can be trained in a distributed way on Neuron chips.

New model architectures appear every day. The models directory of the transformers library contains many architecture implementations in addition to LlamaForCausalLM. However, most of those implementations do not yet have official NxDT-ported versions.

For Llama an official NxDT modeling_llama.py does exist, but in this chapter we assume the case where no such port is provided and practice porting a model yourself.


If you inspect the NxDT modeling_llama.py, you will see that it basically inherits the layer classes defined in the original (hereafter “HF version”) modeling_llama.py and overrides only the necessary parts. ⚠️Note: because of this inheritance/override structure, NxDT's behavior is highly sensitive to the transformers library version. When using NxDT, we strongly recommend pinning the transformers version strictly.

  1. Study the HF version of modeling_llama.py thoroughly. (Note: HF sources are updated frequently — check the source code in the transformers package installed in your current Python environment rather than relying on GitHub.) Precisely understand which layer calls which layer and in what order — the full tree structure. It is recommended to feed the entire modeling_llama.py source as context to an LLM such as Claude, Gemini, or ChatGPT, and ask something like: “Can you draw a simple textual tree diagram showing which layer calls which layer and the overall flow of computation?”

    • Example answer from an LLM:

      LlamaForCausalLM
      ├── LlamaModel
      │   ├── embed_tokens (embedding layer)
      │   ├── Rotary Position Embedding (RoPE)
      │   ├── LlamaDecoderLayer (decoder layer) × N
      │   │   ├── LlamaRMSNorm (normalization)
      │   │   ├── LlamaAttention (self-attention)
      │   │   │   ├── query, key, value computation
      │   │   │   ├── apply Rotary Position Embedding
      │   │   │   ├── attention weight computation (Scaled Dot-Product Attention)
      │   │   │   └── attention output computation
      │   │   ├── residual connection
      │   │   ├── LlamaRMSNorm (normalization)
      │   │   ├── LlamaMLP (feed-forward)
      │   │   │   ├── gate_proj & up_proj (dim transforms)
      │   │   │   ├── activation function
      │   │   │   ├── down_proj (dim reduction)
      │   │   │   └── output
      │   │   └── residual connection
      │   ├── LlamaRMSNorm (normalization)
      │   └── return hidden states
      └── lm_head (logit computation)
      
  2. Create a new models directory under the train.sh hierarchy and inside it create a file named modeling_my_llama.py. You will develop the porting code here. Start by inheriting the HF version classes and making them identical (no overrides yet). Use the following content as a starting point:

    from transformers.models.llama.modeling_llama import (
        LlamaForCausalLM as LlamaForCausalLMHF,
        LlamaRotaryEmbedding as LlamaRotaryEmbeddingHF,
        LlamaDecoderLayer as LlamaDecoderLayerHF,
        LlamaAttention as LlamaAttentionHF,
        LlamaRMSNorm as LlamaRMSNormHF,
        LlamaMLP as LlamaMLPHF
    )
    
    class LlamaForCausalLM(LlamaForCausalLMHF):
        pass
    
    class LlamaRotaryEmbedding(LlamaRotaryEmbeddingHF):
        pass
    
    class LlamaDecoderLayer(LlamaDecoderLayerHF):
        pass
    
    class LlamaAttention(LlamaAttentionHF):
        pass
    
    class LlamaRMSNorm(LlamaRMSNormHF):
        pass
    
    class LlamaMLP(LlamaMLPHF):
        pass
    
    
  3. Run it as-is to see what currently fails.

    • Modify your training code so it calls the hastily created model definition above. (Reference: official Docs)

      1. Modify training.py like this:

        • Before:

          from neuronx_distributed_training.lightning_modules.model.hf_models.llama_model import (
              HFLLamaModule,
          )
          ...(omitted)...
          model = HFLLamaModule(cfg, trainer)
          
        • After:

          from models.my_llama_model import HFMyLLamaModule
          ...(omitted)...
          model = HFMyLLamaModule(cfg, trainer)
          
      2. Create a new file my_llama_model.py inside the models folder and define the HFMyLLamaModule that the above import expects. Example:

        my_llama_model.py
        # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
        # SPDX-License-Identifier: Apache-2.0
        
        import os
        import neuronx_distributed as nxd
        import torch
        from transformers import LlamaConfig
        import sys
        from neuronx_distributed.utils.utils import hardware
        from neuronx_distributed_training.utils import get_dtype, get_attribute_from_cfg
        from torch_neuronx.utils import get_platform_target
        from models.modeling_my_llama import (
            LlamaAttention as CoreAttention,
            LlamaDecoderLayer,
            LlamaForCausalLM,
            LlamaRMSNorm,
            LlamaMLP,
            LlamaRotaryEmbedding
        )
        
        from neuronx_distributed_training.lightning_modules.model.hf_models.base_model import BaseHfModel
        
        class HFMyLLamaModule(BaseHfModel):
            def _get_model(self):
                config = LlamaConfig.from_pretrained(self.config.model.model_config)
                config.use_cache = False
                config.return_dict = False
                config.sequence_parallel_enabled = self.config.distributed_strategy.get("sequence_parallel", False)
                config.qkv_linear = self.config.model.get("qkv_linear", False)
                config.fuse_qkv = self.config.model.get("fuse_qkv", True)
                config.kv_shared_group_size = self.config.distributed_strategy.get("kv_replicator", 1)
                config.max_position_embeddings = self.config.model.get("max_position_embeddings", config.max_position_embeddings)
                config.use_flash_attention = self.config.model.fusions.flash_attention
                config.use_ring_attention = get_attribute_from_cfg(self.config, 'ring_attention', False)
                hardware_type = hardware(get_platform_target())
                if hardware_type==hardware.TRN1:
                    config.lnc = self.config.trainer.get("lnc", 1)
                if hardware_type==hardware.TRN2:
                    config.lnc = self.config.trainer.get("lnc", 2)
                if self.config.model.get('num_layers', -1) != -1:
                    config.num_hidden_layers = self.config.model.get('num_layers')
                if self.config.model.get('hidden_size', -1) != -1:
                    config.hidden_size = self.config.model.get('hidden_size')
                if self.config.model.get('rope_theta', -1) != -1:
                    config.rope_theta = self.config.model.get('rope_theta')
                config.head_dim = get_attribute_from_cfg(self.config, 'hidden_size', config.hidden_size) // config.num_attention_heads # overriding head_dim value, which was set in transformers code
                config.transpose_nki_inputs = self.config.model.get('transpose_nki_inputs', True) # transpose_nki_inputs by default  
        
                if get_attribute_from_cfg(self.config, "peft", False):
                    lora_config = nxd.modules.lora.LoraConfig(
                        lora_rank=get_attribute_from_cfg(self.config, 'lora_rank', 16),
                        lora_alpha=get_attribute_from_cfg(self.config, 'lora_alpha', 32),
                        lora_dropout=get_attribute_from_cfg(self.config, 'lora_dropout', 0.05),
                        bias=get_attribute_from_cfg(self.config, 'lora_bias', "none"),
                        lora_verbose=get_attribute_from_cfg(self.config, 'lora_verbose', True),
                        target_modules=get_attribute_from_cfg(self.config, 'target_modules', ["qkv_proj"]),
                        load_lora_from_ckpt=get_attribute_from_cfg(self.config, 'load_lora_from_ckpt', False),
                        save_lora_base=get_attribute_from_cfg(self.config, 'save_lora_base', False),
                        merge_lora=get_attribute_from_cfg(self.config, 'merge_lora', False),
                        save_lora_config_adapter=get_attribute_from_cfg(self.config, 'save_lora_config_adapter', True),
                        merge_sharded_lora=get_attribute_from_cfg(self.config, 'merge_sharded_lora', False),
                    )
                    self.nxd_config["lora_config"] = lora_config
        
                if self.config.precision.type == "fp32":
                    config.reduce_dtype = get_dtype(self.config.precision.get('parallel_layers_reduce_dtype', 'fp32')) # RS would be in fp32 as there is no implicit downcasting
                    config.torch_dtype = torch.float32
                else:
                    config.reduce_dtype = torch.bfloat16 # default RS type, this wont get downcasted to anything else, so RS will happen at bf16
                    if get_dtype(self.config.precision.get('parallel_layers_reduce_dtype', 'bf16')) == torch.float32:
                        config.reduce_dtype = torch.float64
                    config.torch_dtype = torch.bfloat16
                   
                leaf_module_cls = [LlamaRMSNorm.__name__, LlamaRotaryEmbedding.__name__]
                activation_recompute_modules = []
                recompute_modules = self.config.model.get("activations_checkpoint_recompute", [])
                granularity = self.config.model.get("activations_checkpoint_granularity", None)
        
                if granularity == "selective":
                    for module in recompute_modules:
                        module_obj = getattr(sys.modules[__name__], module, None)
                        if module_obj is not None:
                            activation_recompute_modules.append(module_obj)
                elif granularity == "full":
                    activation_recompute_modules = "full"
                elif not self.config.model.fusions.get("flash_attention", False):
                    activation_recompute_modules.append(CoreAttention) # do CoreAttention checkpointing if flash_attention is off
                else:
                    activation_recompute_modules = None
        
                self.nxd_config["activation_checkpoint_config"] = activation_recompute_modules
                self.nxd_config["pipeline_config"].update(
                    {
                        "transformer_layer_cls": LlamaDecoderLayer,
                        "output_loss_value_spec": (True, False),
                        "input_names": ["input_ids", "attention_mask", "labels"],
                        "leaf_module_cls": leaf_module_cls,
                    }
                )
                include_buffers = True
                return nxd.initialize_parallel_model(self.nxd_config, self.model_provider_func, include_buffers, config)
        
            def model_provider_func(self, config):
                model = LlamaForCausalLM(config)
                # Here we make sure we use the same sine and cosine matrices for all layers.
                # Making use of same tensors would make the CSE algorithm eliminate the lookup call
                # from layers, keeping only lookup from first layer.
                # with torch.no_grad():
                #     cos, sin = self.get_sin_cos_matrix(config)
                #     for layer in model.model.layers:
                #         layer.self_attn.rotary_emb.cos_cached = cos
                #         layer.self_attn.rotary_emb.sin_cached = sin
        
                if os.environ.get("XLA_DOWNCAST_BF16", None) == "0" and config.torch_dtype == torch.bfloat16:
                    model = model.to(torch.bfloat16)
                    
                return model
        
            def get_sin_cos_matrix(self, config):
                head_dim = config.hidden_size // config.num_attention_heads
                base = config.rope_theta
                inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
                t = torch.arange(config.max_position_embeddings, dtype=inv_freq.dtype)
                freqs = torch.einsum("i,j->ij", t, inv_freq)
                # Different from paper, but it uses a different permutation in order to obtain the same calculation
                emb = torch.cat((freqs, freqs), dim=-1)
                return emb.cos()[None, None, :, :].to(torch.float32), emb.sin()[None, None, :, :].to(torch.float32)
        
            def init_weights(self, module, device):
                """
                Re-init weights after partition
                Referred from HF transformers https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L690
                """
                # Last else should always call super().init_weights() to allow initializing
                # pre-defined layers.
                for key, nested_module in module._modules.items():
                    if isinstance(nested_module, LlamaRotaryEmbedding):
                        module._modules[key] = LlamaRotaryEmbedding(nested_module.config, device)
                if isinstance(module, LlamaRMSNorm):
                    module.weight.data.fill_(1.0)
                else:
                    super().init_weights(module, device)
        
        
  4. Run the AOT compilation sbatch command. However, configure it not to load a checkpoint and to initialize weights from scratch (i.e., resume_from_checkpoint: null).

    After the dataset-preprocessing loop finishes, the process will most likely OOM and die immediately: without parallelization the model does not fit in device memory. To enable parallelism you must modify the model definition as described below.

  5. Enumerate all the parameters in the model and check the parameter sizes in each layer.

    Explanation
    • For example, you can check like this:

      import torch
      from transformers import AutoModel
      
      model = AutoModel.from_pretrained("/fsx/models/Tanuki-8B-dpo-v1.0/")
      
      def print_param_shapes(model: torch.nn.Module):
          total_params = 0
          for name, param in model.named_parameters():
              print(f"{name:<60} {tuple(param.shape)}")
              total_params += param.numel()
          print(f"\nTotal parameters: {total_params:,}")
      
      # Run
      print_param_shapes(model)
      
    • Summarizing the results yields something like:

      model.embed_tokens.weight                        (65024, 4096)
      model.layers.*.self_attn.q_proj.weight           (4096, 4096)
      model.layers.*.self_attn.k_proj.weight           (1024, 4096)
      model.layers.*.self_attn.v_proj.weight           (1024, 4096)
      model.layers.*.self_attn.o_proj.weight           (4096, 4096)
      model.layers.*.mlp.gate_proj.weight              (14336, 4096)
      model.layers.*.mlp.up_proj.weight                (14336, 4096)
      model.layers.*.mlp.down_proj.weight              (4096, 14336)
      model.layers.*.input_layernorm.weight            (4096,)
      model.layers.*.post_attention_layernorm.weight   (4096,)
      model.norm.weight                                (4096,)
      lm_head.weight                                   (65024, 4096)
      

    Among the above, large parameter blocks such as embed_tokens, self_attn.[qkvo]_proj, mlp.(gate|up|down)_proj, and lm_head must be converted into tensor-parallelizable forms.
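The steps that follow split exactly these matrices across devices. As a rough illustration (plain Python, not NxD code; TP=8 is an assumed tensor-parallel degree), here is how the per-device shard shapes work out for a column-wise versus row-wise split of the weights listed above:

```python
# Illustrative sketch (plain Python, not NxD code): per-device shard shapes for
# the large tanuki-8b matrices above under an assumed tensor-parallel degree TP=8.
TP = 8

def column_shard(out_dim, in_dim, tp):
    # ColumnParallelLinear-style split: the output dimension is divided across devices
    assert out_dim % tp == 0
    return (out_dim // tp, in_dim)

def row_shard(out_dim, in_dim, tp):
    # RowParallelLinear-style split: the input dimension is divided across devices
    assert in_dim % tp == 0
    return (out_dim, in_dim // tp)

print(column_shard(14336, 4096, TP))  # gate_proj/up_proj per-device shard: (1792, 4096)
print(row_shard(4096, 14336, TP))     # down_proj per-device shard: (4096, 1792)
print(column_shard(65024, 4096, TP))  # lm_head (and embed_tokens) shard: (8128, 4096)
```

Each device then holds only 1/8 of every large matrix, which is what makes the model fit in device memory.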

  6. First, make the token-ID→vector embedding layer tensor-parallelizable by using a ParallelEmbedding layer.

    • In the HF source, LlamaModel.__init__ instantiates self.embed_tokens as an nn.Embedding. Replace it with ParallelEmbedding by overriding LlamaModel.__init__ in modeling_my_llama.py. Note that LlamaForCausalLM.__init__, which instantiates LlamaModel, must also be overridden so that it constructs the new LlamaModel, and the super().__init__ calls at the top of each __init__ must be adjusted accordingly:

      • Before:

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        
      • After (requires from neuronx_distributed.parallel_layers.layers import ParallelEmbedding):

        self.embed_tokens = ParallelEmbedding(config.vocab_size, config.hidden_size, self.padding_idx, sequence_parallel_enabled=self.config.sequence_parallel_enabled)
        
      • Explanation

        • ParallelEmbedding (official Docs) (source) has the same role as torch.nn.Embedding, but the parameter tensor is split across the vocabulary dimension among TP devices.

        • sequence_parallel_enabled toggles sequence parallel usage:

          • False (default): sequence parallel is not used. Each TP device will return the full original-shape tensor.
          • True: sequence parallel is used. Outputs are returned in sequence-parallel form: each device holds its TP-split slice, and the sequence and batch axes may be ordered differently (see the docs).
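To see why a vocabulary-sharded embedding still produces the correct lookup, here is a small numpy sketch (an illustration of the idea only, not the ParallelEmbedding implementation): token IDs falling outside a device's vocabulary shard contribute zeros, and summing the per-device partial results (the role of the all-reduce) recovers the full embedding.

```python
import numpy as np

# Illustrative sketch (not the ParallelEmbedding implementation): a lookup
# table split along the vocabulary dimension across tp "devices".
vocab, hidden, tp = 8, 4, 2
rng = np.random.default_rng(0)
full_table = rng.standard_normal((vocab, hidden))
shards = np.split(full_table, tp, axis=0)   # each device owns vocab/tp rows
ids = np.array([0, 5, 3])

partials = []
for rank, shard in enumerate(shards):
    lo = rank * (vocab // tp)
    local = ids - lo
    mask = (local >= 0) & (local < vocab // tp)
    out = np.zeros((len(ids), hidden))
    out[mask] = shard[local[mask]]          # out-of-shard ids contribute zeros
    partials.append(out)

result = sum(partials)                      # stands in for the all-reduce
assert np.allclose(result, full_table[ids])
```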
  7. Make the MLP layer LlamaMLP tensor-parallelizable using ColumnParallelLinear and RowParallelLinear.

    • In HF, LlamaMLP.__init__ instantiates self.gate_proj, self.up_proj, and self.down_proj as nn.Linear. Replace them and override LlamaMLP.__init__ accordingly (note: as before, LlamaDecoderLayer.__init__ and LlamaForCausalLM.__init__ that call these constructors need corresponding overrides):

      • Before:

        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
        
      • After (requires from neuronx_distributed.parallel_layers.layers import ColumnParallelLinear, RowParallelLinear):

        self.gate_proj = ColumnParallelLinear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias, gather_output=False, sequence_parallel_enabled=self.config.sequence_parallel_enabled)
        self.up_proj = ColumnParallelLinear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias, gather_output=False, sequence_parallel_enabled=self.config.sequence_parallel_enabled)
        self.down_proj = RowParallelLinear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias, input_is_parallel=True, sequence_parallel_enabled=self.config.sequence_parallel_enabled)
        
      • Explanation

        • ColumnParallelLinear (Docs) (source) behaves like torch.nn.Linear but splits the parameter across the output dimension among TP devices.

          • gather_output controls whether the output is gathered across TP devices:

            • False: output remains sharded (each device holds its TP slice).
            • True (default): outputs are gathered and each device gets full output shape.
          • sequence_parallel_enabled toggles sequence parallel behavior as before.

        • RowParallelLinear (Docs) (source) splits parameters across the input dimension among TP devices.

          • input_is_parallel indicates whether the input is already sharded across TP devices.
        • When two linear layers are chained, the common efficient pattern is to make the first ColumnParallelLinear(gather_output=False) and the next RowParallelLinear(input_is_parallel=True).
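The chained column-then-row pattern can be sketched in numpy (an illustration of the idea only, not the NxD implementation). The key point is that no communication is needed between the two matmuls: one sum over devices (the all-reduce) at the end recovers the unsharded result. ReLU stands in for the element-wise activation, which is shard-local and therefore also needs no gathering.

```python
import numpy as np

# Illustrative sketch: ColumnParallelLinear(gather_output=False) feeding
# RowParallelLinear(input_is_parallel=True). Shapes mirror gate_proj -> down_proj.
hidden, inter, tp = 4, 8, 2
rng = np.random.default_rng(0)
x = rng.standard_normal((3, hidden))
w1 = rng.standard_normal((hidden, inter))   # first linear: split output dim
w2 = rng.standard_normal((inter, hidden))   # second linear: split input dim

w1_shards = np.split(w1, tp, axis=1)        # column-parallel shards
w2_shards = np.split(w2, tp, axis=0)        # row-parallel shards
partials = [np.maximum(x @ a, 0) @ b for a, b in zip(w1_shards, w2_shards)]

reference = np.maximum(x @ w1, 0) @ w2      # unsharded computation
assert np.allclose(sum(partials), reference)  # all-reduce of partials matches
```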

  8. Convert the final fully connected lm_head to a tensor-parallel form.

    • In HF, LlamaForCausalLM.__init__ instantiates self.lm_head as nn.Linear. Replace it with ColumnParallelLinear:

      • Before:

        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        
      • After:

        self.lm_head = ColumnParallelLinear(config.hidden_size, config.vocab_size, bias=False, gather_output=False, sequence_parallel_enabled=self.config.sequence_parallel_enabled)
        
    • In this design, the TP merge point is just before cross-entropy calculation. NxD provides parallel_cross_entropy to compute softmax cross-entropy after TP-splitting. Modify the loss computation at the end of LlamaForCausalLM.forward as follows:

      • Before:

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
        
      • After (requires from neuronx_distributed.parallel_layers.loss_functions import parallel_cross_entropy):

        if self.config.sequence_parallel_enabled:
            logits = logits.transpose(0, 1).contiguous()
        
        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].clone().contiguous()
            shift_labels = labels[..., 1:].contiguous()
            shift_logits = shift_logits.view(-1, shift_logits.size(-1))
            shift_labels = shift_labels.view(-1)
            shift_labels = shift_labels.to(shift_logits.device)
            loss = parallel_cross_entropy(shift_logits, shift_labels)
            loss = torch.mean(loss)
        
      • Explanation

        • parallel_cross_entropy (Docs) (source) computes cross-entropy from TP-sharded logits and (non-sharded) labels.
        • Expected shapes: logits.shape == (batch, seq_len, vocab_size/TP) and labels.shape == (batch, seq_len) (or flattened equivalents).
        • If sequence parallel is enabled, batch and sequence dims may be transposed, hence the initial transpose.
        • The index-shifting logic is explicitly added here because in the HF version it was handled inside self.loss_function.
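The idea behind computing cross-entropy from vocab-sharded logits can be sketched in numpy (an illustration of the concept only, not the parallel_cross_entropy implementation): each device contributes a local max, a local sum of exponentials, and the target logit if the label falls inside its shard; combining these (the communication step) yields the same loss as the unsharded softmax cross-entropy.

```python
import numpy as np

# Illustrative sketch: softmax cross-entropy over logits sharded along vocab.
tp, tokens, vocab = 2, 3, 8
rng = np.random.default_rng(0)
logits = rng.standard_normal((tokens, vocab))
labels = np.array([1, 6, 3])

shards = np.split(logits, tp, axis=-1)                       # vocab/tp per device
gmax = np.max([s.max(axis=-1) for s in shards], axis=0)      # combined per-token max
sumexp = sum(np.exp(s - gmax[:, None]).sum(axis=-1) for s in shards)
target = np.zeros(tokens)
for rank, s in enumerate(shards):
    lo = rank * (vocab // tp)
    mask = (labels >= lo) & (labels < lo + vocab // tp)
    target[mask] = s[mask, labels[mask] - lo]                # target logit, if local

loss = -(target - gmax - np.log(sumexp))                     # per-token NLL
ref = -(logits[np.arange(tokens), labels]
        - np.log(np.exp(logits).sum(axis=-1)))               # unsharded reference
assert np.allclose(loss, ref)
```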
  9. Make the self-attention LlamaAttention tensor-parallelizable.

    • Background on LlamaAttention

      • We assume basic self-attention knowledge; here are features specific to LlamaAttention:

        • It supports Grouped Query Attention (GQA). GQA reduces compute/memory by sharing Key/Value heads among Query heads. For example, with 32 Q heads and 8 KV heads, many Q heads share the same K/V, reducing KV cache cost while retaining Q expressiveness.
    • LlamaAttention can be divided into four major parts:

      1. Computing Q, K, V from input (hidden_states)
      2. Applying RoPE to Q, K
      3. Computing O from Q, K, V
      4. Computing the output (next hidden states) from O

      There are many required changes — explained step by step below.

      • (i) Compute Q, K, V from input

        • Replace the three linear layers for Q/K/V with GQAQKVColumnParallelLinear.

          • In __init__:

            • Before

              self.q_proj = nn.Linear(
                  config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
              )
              self.k_proj = nn.Linear(
                  config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
              )
              self.v_proj = nn.Linear(
                  config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
              )
              
            • After (requires from neuronx_distributed.modules.qkv_linear import GQAQKVColumnParallelLinear):

              self.qkv_proj = GQAQKVColumnParallelLinear(
                  config.hidden_size,
                  [config.num_attention_heads * self.head_dim, config.num_key_value_heads * self.head_dim],
                  bias=config.attention_bias,
                  gather_output=False,
                  kv_size_multiplier=self.config.kv_shared_group_size,
                  fuse_qkv=self.config.fuse_qkv,
                  sequence_parallel_enabled=self.config.sequence_parallel_enabled
              )
              
            • Explanation

              • GQAQKVColumnParallelLinear (Docs) (source) computes Q/K/V via linear transforms and splits the parameters across the output dimension for TP.
              • The second argument is [q_proj_output_dim, k_proj(v_proj)_output_dim].
              • It returns (query_states, key_states, value_states) with axis order [batch, seq_len, head_count*head_dim].
              • kv_size_multiplier corresponds to KV_REPLICATOR (how many copies of KV weights to keep).
              • fuse_qkv=True stores Q/K/V parameters fused (names like (weight|bias)_qkv) which is more efficient.
          • In __init__, variable definitions change:

            • Before

              self.num_heads = config.num_attention_heads
              self.num_key_value_heads = config.num_key_value_heads
              self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
              
            • After (requires import neuronx_distributed.parallel_layers.utils as neuronx_dist_utils)

              self.num_heads = neuronx_dist_utils.divide(config.num_attention_heads, get_tensor_model_parallel_size())
              self.num_key_value_heads = neuronx_dist_utils.divide(
                  config.num_key_value_heads * self.config.kv_shared_group_size, get_tensor_model_parallel_size()
              )
              self.num_key_value_groups = self.num_heads // self.num_key_value_heads
              
            • Explanation

              • self.num_(heads|key_value_heads|key_value_groups) are adjusted to be the values after dividing by TP size. These values are used in forward.
              • neuronx_dist_utils.divide(x, y) behaves like integer division but throws if x % y != 0.
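The head bookkeeping can be worked through concretely. A small sketch (the TP=32 and kv_replicator=4 values are assumptions for illustration; the divide helper mirrors the described behavior of neuronx_dist_utils.divide, not its actual source):

```python
# Illustrative sketch of the per-device head counts, using tanuki-8b's
# 32 Q heads and 8 KV heads with assumed TP=32 and kv_replicator=4.
def divide(x, y):
    # mirrors the described behavior: exact integer division, error otherwise
    if x % y != 0:
        raise ValueError(f"{x} is not divisible by {y}")
    return x // y

tp, kv_replicator = 32, 4
num_heads = divide(32, tp)                           # 1 Q head per device
num_key_value_heads = divide(8 * kv_replicator, tp)  # replication makes 8 KV heads splittable 32 ways
num_key_value_groups = num_heads // num_key_value_heads
print(num_heads, num_key_value_heads, num_key_value_groups)  # 1 1 1
```

Without the kv_replicator factor, divide(8, 32) would raise, which is exactly why KV replication is needed when TP size exceeds the KV head count.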
          • In forward:

            • Before

              query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
              key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
              value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
              
            • After

              bsz, q_len, _ = hidden_states.size()
              if self.config.sequence_parallel_enabled:
                  q_len, bsz, _ = hidden_states.size()
                  q_len = q_len * get_tensor_model_parallel_size()
              
              query_states, key_states, value_states = self.qkv_proj(hidden_states)
              query_states, key_states, value_states, seq_len_dim_index = self.permute_qkv_for_attn(
                  query_states, key_states, value_states, bsz, q_len, self.num_heads,
                  self.num_key_value_heads, self.head_dim, self.config
              )
              

              Additionally define helper functions inside LlamaAttention:

                def reshape_and_permute_states_for_fa(self, states, bsz, q_len, num_heads, head_dim, use_sequence_parallel):
                    if use_sequence_parallel:
                        return states.view(q_len, bsz, num_heads, head_dim).permute(1, 2, 3, 0)
                    else:
                        return states.view(bsz, q_len, num_heads, head_dim).permute(0, 2, 3, 1)
              
                def permute_qkv_for_attn(
                        self, query_states, key_states, value_states, bsz, q_len, num_heads, num_key_value_heads, head_dim, config
                    ):
              
                    if config.transpose_nki_inputs and config.use_flash_attention:
                        query_states = self.reshape_and_permute_states_for_fa(query_states, bsz, q_len, num_heads, head_dim, config.sequence_parallel_enabled)
                        key_states = self.reshape_and_permute_states_for_fa(key_states, bsz, q_len, num_key_value_heads, head_dim, config.sequence_parallel_enabled)
                        value_states = self.reshape_and_permute_states_for_fa(value_states, bsz, q_len, num_key_value_heads, head_dim, config.sequence_parallel_enabled)
                        dim_index = -1
                    elif config.sequence_parallel_enabled:
                        query_states = query_states.view(q_len, bsz, num_heads, head_dim).permute(1, 2, 0, 3)
                        key_states = key_states.view(q_len, bsz, num_key_value_heads, head_dim).permute(1, 2, 0, 3)
                        value_states = value_states.view(q_len, bsz, num_key_value_heads, head_dim).permute(1, 2, 0, 3)
                        dim_index = -2
                    else:
                        query_states = query_states.view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
                        key_states = key_states.view(bsz, q_len, num_key_value_heads, head_dim).transpose(1, 2)
                        value_states = value_states.view(bsz, q_len, num_key_value_heads, head_dim).transpose(1, 2)
                        dim_index = -2
                    
                    return query_states, key_states, value_states, dim_index
              
            • Explanation

              • Reorders the Q/K/V outputs from GQAQKVColumnParallelLinear into the layout convenient for downstream attention computation. Axis ordering depends on flags like transpose_nki_inputs and use_flash_attention.
      • (ii) Apply RoPE to Q, K

        • Imports:

          • Before

            from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
            
          • After

            from neuronx_distributed.overrides.transformer_overrides import apply_rotary_pos_emb
            
        • In forward:

          • Before

            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
            
          • After

            query_states, key_states = apply_rotary_pos_emb(
                query_states, key_states, cos, sin, None, 
                self.config.use_flash_attention, self.config.transpose_nki_inputs
            )
            
        • Explanation

          • The NxD version of apply_rotary_pos_emb supports cases where use_flash_attention==True and transpose_nki_inputs==True (i.e., different axis order), otherwise it behaves the same as HF.
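For reference, the underlying rotate_half-style RoPE that both versions apply (in the default axis order) can be sketched in numpy; this is an illustration of the standard formulation, not the NxD kernel. A useful sanity check is that the rotation preserves each head vector's norm:

```python
import numpy as np

# Illustrative sketch of rotate_half-style RoPE (standard HF formulation,
# default [batch, heads, seq, head_dim] axis order).
def rotate_half(x):
    x1, x2 = np.split(x, 2, axis=-1)
    return np.concatenate((-x2, x1), axis=-1)

def apply_rope(x, cos, sin):
    return x * cos + rotate_half(x) * sin

rng = np.random.default_rng(0)
q = rng.standard_normal((1, 2, 5, 8))   # (batch, heads, seq, head_dim)
theta = rng.standard_normal((5, 4))     # one angle per (position, dim pair)
cos = np.tile(np.cos(theta), 2)         # angles duplicated across the two halves
sin = np.tile(np.sin(theta), 2)

q_rot = apply_rope(q, cos, sin)
# a rotation: per-position head-vector norms are unchanged
assert np.allclose(np.linalg.norm(q_rot, axis=-1), np.linalg.norm(q, axis=-1))
```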
      • (iii) Compute O from Q, K, V

        • First, reflect the KV_REPLICATOR by repeating K and V so their head counts match Q heads:

          • Add:

            # repeat k/v heads if n_kv_heads < n_heads
            key_states = repeat_kv(key_states, self.num_key_value_groups)
            value_states = repeat_kv(value_states, self.num_key_value_groups)
            
          • Explanation

            • repeat_kv expands a tensor shaped (batch, n_kv_heads, seq_len, head_dim) to (batch, n_kv_heads * n_rep, seq_len, head_dim) by repeating each KV head n_rep times.
            • (In HF this replication was done inside eager_attention_forward.)
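The expansion can be sketched in numpy (an illustrative re-implementation, assumed to mirror HF's repeat_kv; not the library source):

```python
import numpy as np

# Illustrative sketch: expand (batch, n_kv_heads, seq, head_dim) so the KV
# head count matches the Q head count, repeating each KV head n_rep times.
def repeat_kv_np(states, n_rep):
    b, kv, s, d = states.shape
    if n_rep == 1:
        return states
    expanded = np.broadcast_to(states[:, :, None], (b, kv, n_rep, s, d))
    return expanded.reshape(b, kv * n_rep, s, d)

kv = np.arange(2 * 2 * 3 * 4).reshape(2, 2, 3, 4)
out = repeat_kv_np(kv, 2)
assert out.shape == (2, 4, 3, 4)
assert np.array_equal(out[:, 0], out[:, 1])  # each KV head appears n_rep consecutive times
```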
        • Now the core Q,K,V → O computation. On GPUs this typically uses Flash Attention for memory efficiency; for Neuron chips the equivalent implementation is nki_flash_attn_func.

          • Before

            attention_interface: Callable = eager_attention_forward
            if self.config._attn_implementation != "eager":
                if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
                    logger.warning_once(
                        "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
                        'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
                    )
                else:
                    attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
            
            attn_output, attn_weights = attention_interface(
                self,
                query_states,
                key_states,
                value_states,
                attention_mask,
                dropout=0.0 if not self.training else self.attention_dropout,
                scaling=self.scaling,
                **kwargs,
            )
            
          • After (requires from neuronx_distributed.kernels.flash_attn import nki_flash_attn_func)

            attn_output = nki_flash_attn_func(
                query_states, key_states, value_states, self.config.lnc,
                transpose_nki_inputs=self.config.transpose_nki_inputs
            )
            if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
                raise ValueError(
                    f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                    f" {attn_output.size()}"
                )
            
          • Explanation

            • nki_flash_attn_func is a Neuron kernel implementation of Flash Attention written with the Neuron Kernel Interface (NKI).

            • Input axis order must match transpose_nki_inputs:

              • If transpose_nki_inputs==True: [batch, heads, head_dim, seq_len]
              • Else: [batch, heads, seq_len, head_dim]
            • attn_output always comes out as [batch, heads, seq_len, head_dim].
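When debugging the port, it helps to have a plain reference to compare the kernel's output against. Below is a numpy sketch of causal scaled-dot-product attention in the transpose_nki_inputs==False layout (no dropout, no padding mask; an illustration only, not the NKI kernel itself):

```python
import numpy as np

def reference_attention(q, k, v, causal=True):
    # q, k, v: [batch, heads, seq_len, head_dim]
    scale = 1.0 / np.sqrt(q.shape[-1])
    scores = np.einsum("bhqd,bhkd->bhqk", q, k) * scale
    if causal:
        seq_len = q.shape[2]
        # Mask out positions where key index > query index.
        mask = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)
        scores = np.where(mask, -np.inf, scores)
    # Numerically stable softmax over the key axis.
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return np.einsum("bhqk,bhkd->bhqd", weights, v)  # [batch, heads, seq_len, head_dim]
```

A handy property for spot checks: with a causal mask, the output at position 0 must equal V at position 0, since the first query can only attend to the first key.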

      • (iv) Compute output from O

        • In __init__:

          • Before

            self.o_proj = nn.Linear(
                config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
            )
            
          • After

            self.o_proj = RowParallelLinear(
                config.num_attention_heads * self.head_dim,
                config.hidden_size,
                bias=config.attention_bias,
                input_is_parallel=True,
                sequence_parallel_enabled=self.config.sequence_parallel_enabled
            )
            
        • In forward, just before passing through o_proj, reshape:

          • Before

            attn_output = attn_output.reshape(*input_shape, -1).contiguous()
            
          • After

            if self.config.sequence_parallel_enabled:
                attn_output = attn_output.permute(2, 0, 1, 3)
                attn_output = attn_output.reshape(q_len, bsz, self.hidden_size // get_tensor_model_parallel_size())
            else:
                attn_output = attn_output.transpose(1, 2).contiguous()
                attn_output = attn_output.reshape(bsz, q_len, self.hidden_size // get_tensor_model_parallel_size())
            
          • Explanation

            • Reorder and reshape the nki_flash_attn_func output into the usual activation axis order (if sequence parallel: [seq_len, batch, head_dim*heads], else [batch, seq_len, head_dim*heads]). Note that the hidden dimension is still TP-split at this point; o_proj will merge it.
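The two reshape paths can be verified with a quick numpy shape check (toy sizes; hidden_size is the full model hidden size, and the tensor-parallel degree is assumed to be 2 here):

```python
import numpy as np

bsz, q_len, head_dim, tp, hidden_size = 2, 16, 4, 2, 64
local_heads = hidden_size // tp // head_dim           # heads held by this TP rank
attn_output = np.zeros((bsz, local_heads, q_len, head_dim))

# sequence-parallel path: permute(2, 0, 1, 3), then flatten heads * head_dim
sp = attn_output.transpose(2, 0, 1, 3).reshape(q_len, bsz, hidden_size // tp)

# non-sequence-parallel path: transpose(1, 2), then flatten
no_sp = attn_output.transpose(0, 2, 1, 3).reshape(bsz, q_len, hidden_size // tp)
```

Both paths end with the TP-local hidden dimension (hidden_size // tp); only the position of the sequence axis differs.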
  10. Re-run the AOT compilation.

    • If errors occur, debug. General tips:

      • Debugging tips

        • Logs can be thousands of lines and interleaved from parallel processes; to find root cause:

          • Search for the earliest occurrence of the string "error". If that is inconclusive, search for "Error:" (Python errors) or "Error |" (Neuron compiler errors).
          • Narrow the log down to the relevant excerpt first, then feed it to an LLM to analyze the root cause.
        • Use xm.master_print instead of print (requires import torch_xla.core.xla_model as xm) so only the master process prints and you avoid duplicated logs across processes.

        • Locating the problematic layer

          • Compiler errors do not always identify the exact layer. A useful technique is layer short-circuiting: temporarily skip executing the entire LlamaDecoderLayer loop. If the error disappears, the issue is inside the decoder layers; if not, it is likely earlier (embedding) or later (final linear or loss).
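The short-circuiting idea in miniature (toy functions stand in for decoder layers; in the real model the change is a temporary `continue` or `if` guard around the layer call in forward):

```python
def run_layers(hidden_states, layers, skip_layers=False):
    # Bisection aid: with skip_layers=True the decoder stack is bypassed entirely,
    # so a surviving compile error must come from before or after the stack.
    for layer in layers:
        if skip_layers:
            continue
        hidden_states = layer(hidden_states)
    return hidden_states
```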
        • Inspecting tensors safely

          • Inserting print on tensor values forces materialization under lazy (XLA) execution and can change how the graph is compiled. Inspecting metadata such as .shape, .dtype, and .device is usually safe, since it does not force a recompile.
      • Common compiler error messages

        • Estimated peak HBM usage (18.179819) exceeds 16GB. Neff won't be able to load on chip

          • Model does not fit on each Neuron core. Reduce memory (seq length / batch / model size) or increase model parallelism (TP/PP).
        • Couldn't color the DRAM even with 100GB of DRAM space assumption, model needs too much HBM memory !

          • Same: model too big to fit.
        • DRAM usage for Internal DRAM tensor exceeds 16GB of device space limit, cannot fit into device, model requires too much HBM memory !

          • Same.
        • Internal tensorizer error: VectorizeDMA:Illegal after shrink dst!

          • Indicates a tensor shape that varies per operation; some shapes must be static per operator.
        • RuntimeError: (1, 32768, 1, 80) and (1, 32768, 1, 80)

          • Shape mismatch during elementwise ops; note sometimes the message incorrectly repeats the same shape twice.
        • CCOM WARN No transport found between devices 8 and 7. Possible replica group misconfiguration

          • Inter-core communication issue — certain TP/PP/KV_REPLICATOR combos may cause no full connectivity.
        • ERROR: Unsupported operation: mhlo.set_dimension_size

          • mhlo.set_dimension_size represents a dynamically sized dimension, so this likely means some tensor shape depends on runtime data; the Neuron compiler requires static shapes.

        • Number of instructions (8146371) is over the threshold (5000000). - Compile under --optlevel=1 to create smaller subgraphs or use pipeline parallelism.

          • Graph too large; use lower opt level or pipeline parallelism.
        • RuntimeError: Bad StatusOr access: RESOURCE_EXHAUSTED: LoadCollectives: error condition NRT_RESOURCE == rt_status:

          • Updating Neuron SDK can sometimes resolve this.
  11. Re-run the actual training (“train”).

    Finally, enable checkpoint loading support.

  12. Re-enable checkpoint loading and re-run AOT compilation and training.

    • You will likely see the following Python error when loading checkpoints:

      RuntimeError: Missing keys when loading state dictionary: model.layers.0.mlp.gate_proj.weight, model.layers.0.mlp.up_proj.weight, model.layers.1.mlp.gate_proj.weight, ...(omitted)... model.layers.1.mlp.up_proj.weight,model.layers.30.mlp.gate_proj.weight, model.layers.30.mlp.up_proj.weight, model.layers.31.mlp.gate_proj.weight, model.layers.31.mlp.up_proj.weight
      
    • This happens because the HF→NxD checkpoint conversion expects the MLP’s gate_proj and up_proj linear layers to be combined into a single gate_up_proj linear layer, but your ported implementation did not account for that. Therefore the checkpoint lacks gate_proj/up_proj keys while your model expects them, causing the missing keys error.

    • Fix as follows:

      • Modify LlamaMLP.__init__ from:

        • Before

          self.gate_proj = ColumnParallelLinear(
              self.hidden_size, self.intermediate_size, bias=config.mlp_bias,
              gather_output=False, sequence_parallel_enabled=self.config.sequence_parallel_enabled
          )
          self.up_proj = ColumnParallelLinear(
              self.hidden_size, self.intermediate_size, bias=config.mlp_bias,
              gather_output=False, sequence_parallel_enabled=self.config.sequence_parallel_enabled
          )
          
          self.act_fn = ACT2FN[config.hidden_act]
          
        • After

          self.gate_up_proj = ColumnParallelLinear(
              self.hidden_size,
              2 * self.intermediate_size,
              stride=2,
              bias=config.mlp_bias,
              gather_output=False,
              sequence_parallel_enabled=self.config.sequence_parallel_enabled,
          )
          
          self.activation_multiply = ActivationMultiplyMLP(config)
          
        • Explanation

          • The stride parameter of ColumnParallelLinear (default: 1) changes how the parameter tensor is TP-partitioned. With stride=2, each device’s stored slice contains two adjacent blocks: the first half corresponds to gate_proj and the second half corresponds to up_proj. This matches the HF→NxD checkpoint packing.
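The resulting layout can be sketched in numpy (toy sizes, TP degree 4; the actual partitioning is done inside ColumnParallelLinear, this just illustrates what each rank stores with stride=2):

```python
import numpy as np

tp, intermediate, hidden = 4, 8, 3
gate = np.arange(intermediate * hidden).reshape(intermediate, hidden)
up = 100 + gate
packed = np.concatenate([gate, up], axis=0)   # gate_up_proj weight, (2*intermediate, hidden)

# stride=2: each of the 2 strided blocks is split across tp ranks;
# rank r stores its gate slice followed by its up slice.
per_rank = intermediate // tp
shards = []
for r in range(tp):
    g = packed[r * per_rank:(r + 1) * per_rank]
    u = packed[intermediate + r * per_rank:intermediate + (r + 1) * per_rank]
    shards.append(np.concatenate([g, u], axis=0))
```

Each rank's output therefore has size 2 * intermediate_size / tp and splits cleanly into a gate half followed by an up half, which is exactly the property ActivationMultiplyMLP relies on.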
      • Also add:

        class ActivationMultiplyMLP(torch.nn.Module):
            def __init__(self, config):
                super().__init__()
                self.act_fn = ACT2FN[config.hidden_act]
                self.split_size = config.intermediate_size // get_tensor_model_parallel_size()
            
            def forward(self, x):
                gate_proj, up_proj = x.split(self.split_size, dim=2)
                intermediate_states = self.act_fn(gate_proj) * up_proj
                return intermediate_states
        
      • Modify LlamaMLP.forward:

        • Before

          down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
          
        • After

          intermediate_states = self.activation_multiply(self.gate_up_proj(x))
          down_proj = self.down_proj(intermediate_states)
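
As a sanity check that the packed formulation is equivalent, here is a single-rank (TP=1) numpy sketch comparing the original two-matmul path with the packed gate_up path (silu stands in for ACT2FN[config.hidden_act]):

```python
import numpy as np

def silu(x):
    return x / (1.0 + np.exp(-x))

hidden, inter = 6, 8
rng = np.random.default_rng(0)
w_gate = rng.standard_normal((inter, hidden))
w_up = rng.standard_normal((inter, hidden))
x = rng.standard_normal((2, 4, hidden))          # [batch, seq, hidden]

# original HF path: two separate projections
ref = silu(x @ w_gate.T) * (x @ w_up.T)

# packed path: one gate_up projection, then split along the last axis
w_gate_up = np.concatenate([w_gate, w_up], axis=0)
gate, up = np.split(x @ w_gate_up.T, 2, axis=2)
packed = silu(gate) * up
```

Both paths produce identical results; the packed version simply fuses the two matmuls into one, matching how the converted checkpoint stores the weights.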
          
  13. Finally, to verify the port is correct, fix a test input sequence and pass it through the forward of the pre-port (original HF) model and the post-port model and confirm that logits are the same (in eval mode).

    • Note that small numerical differences are expected; exact equality is not guaranteed. However, if the logits are obviously different, some layer was ported incorrectly; inspect the intermediate hidden_states layer by layer to find where the two models diverge.
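One way to frame this comparison is with a small helper (hypothetical; the tolerances are illustrative and depend on dtype, with bf16 runs needing looser thresholds than fp32):

```python
import numpy as np

def compare_logits(ref, ported, rtol=1e-2, atol=1e-2):
    # Bitwise equality is unrealistic across different kernels and dtypes;
    # check loose numerical closeness plus top-1 token agreement instead.
    max_abs_diff = float(np.abs(ref - ported).max())
    close = bool(np.allclose(ref, ported, rtol=rtol, atol=atol))
    top1_match = float((ref.argmax(-1) == ported.argmax(-1)).mean())
    return close, max_abs_diff, top1_match
```

If `close` is False but `top1_match` is near 1.0, the port is probably fine and only precision differs; if top-1 agreement is also poor, some layer is genuinely wrong.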

If you get through all the steps above, you should have covered the core practices for porting a model.

Footnotes
  1. We would like to thank Mr. Tokoyo from AWS for his supervision of this material. ↩︎

KARAKURI Techblog
