AWS Trainium 50 Exercises #6: Re-porting Llama 3 to Trainium
This chapter assumes the following:
- You have completed Chapter 4 “Training NxD-ready models”.
- You have basic knowledge of distributed training (the contents of Chapter 5).
Problems (38–50)
Two chapters ago (Chapter 4) we trained tanuki-8b. As its `config.json` shows, the model architecture of tanuki-8b is `LlamaForCausalLM`, which is defined in `modeling_llama.py` in the transformers library. There is also a `modeling_llama.py` in the NxDT (NeuronX Distributed Training) library: it is the transformers implementation "ported" so that the model can be trained in a distributed way on Neuron chips.

New model architectures appear every day. The models directory of the transformers library contains many architecture implementations besides `LlamaForCausalLM`, but most of them do not yet have official NxDT-ported versions. For Llama an official NxDT `modeling_llama.py` does exist, but in this chapter we assume the case where no such port is provided and practice porting a model ourselves.
If you inspect the NxDT `modeling_llama.py`, you will see that it basically inherits the layer classes defined in the original `modeling_llama.py` (hereafter the "HF version") and overrides only the necessary parts. ⚠️ Note: because of this inheritance/override structure, NxDT behavior is highly sensitive to the transformers library version. When using NxDT, we strongly recommend pinning the transformers version exactly.
- Study the HF version of `modeling_llama.py` thoroughly. (Note: HF sources are updated frequently, so check the source code of the transformers package installed in your current Python environment rather than relying on GitHub.) Precisely understand which layer calls which layer, and in what order, as a full tree structure. It is convenient to feed the entire `modeling_llama.py` source as context to an LLM such as Claude, Gemini, or ChatGPT and ask something like: "Can you draw a simple textual tree diagram showing which layer calls which layer and the overall flow of computation?"
Example answer from an LLM:
```
LlamaForCausalLM
├── LlamaModel
│   ├── embed_tokens (embedding layer)
│   ├── Rotary Position Embedding (RoPE)
│   ├── LlamaDecoderLayer (decoder layer) × N
│   │   ├── LlamaRMSNorm (normalization)
│   │   ├── LlamaAttention (self-attention)
│   │   │   ├── query, key, value computation
│   │   │   ├── apply Rotary Position Embedding
│   │   │   ├── attention weight computation (Scaled Dot-Product Attention)
│   │   │   └── attention output computation
│   │   ├── residual connection
│   │   ├── LlamaRMSNorm (normalization)
│   │   ├── LlamaMLP (feed-forward)
│   │   │   ├── gate_proj & up_proj (dim transforms)
│   │   │   ├── activation function
│   │   │   ├── down_proj (dim reduction)
│   │   │   └── output
│   │   └── residual connection
│   ├── LlamaRMSNorm (normalization)
│   └── return hidden states
└── lm_head (logit computation)
```
- Create a new `models` directory in the same hierarchy as `train.sh`, and inside it create a file named `modeling_my_llama.py`. You will develop the porting code here. Start by inheriting the HF classes without overriding anything, so they behave identically. Use the following content as a starting point:

```python
from transformers.models.llama.modeling_llama import (
    LlamaForCausalLM as LlamaForCausalLMHF,
    LlamaRotaryEmbedding as LlamaRotaryEmbeddingHF,
    LlamaDecoderLayer as LlamaDecoderLayerHF,
    LlamaAttention as LlamaAttentionHF,
    LlamaRMSNorm as LlamaRMSNormHF,
    LlamaMLP as LlamaMLPHF,
)


class LlamaForCausalLM(LlamaForCausalLMHF):
    pass


class LlamaRotaryEmbedding(LlamaRotaryEmbeddingHF):
    pass


class LlamaDecoderLayer(LlamaDecoderLayerHF):
    pass


class LlamaAttention(LlamaAttentionHF):
    pass


class LlamaRMSNorm(LlamaRMSNormHF):
    pass


class LlamaMLP(LlamaMLPHF):
    pass
```
Run it as-is to see what currently fails.
-
Modify your training code so it calls the hastily created model definition above. (Reference: official Docs)
- Modify `training.py` like this:
  - Before:

```python
from neuronx_distributed_training.lightning_modules.model.hf_models.llama_model import (
    HFLLamaModule,
)
# ...(omitted)...
model = HFLLamaModule(cfg, trainer)
```

  - After:

```python
from models.my_llama_model import HFMyLLamaModule
# ...(omitted)...
model = HFMyLLamaModule(cfg, trainer)
```
-
- Create a new file `my_llama_model.py` inside the `models` folder and define the `HFMyLLamaModule` that the import above expects. Example `my_llama_model.py`:
```python
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

import os
import sys

import neuronx_distributed as nxd
import torch
from transformers import LlamaConfig

from neuronx_distributed.utils.utils import hardware
from neuronx_distributed_training.utils import get_dtype, get_attribute_from_cfg
from torch_neuronx.utils import get_platform_target

from models.modeling_my_llama import (
    LlamaAttention as CoreAttention,
    LlamaDecoderLayer,
    LlamaForCausalLM,
    LlamaRMSNorm,
    LlamaMLP,
    LlamaRotaryEmbedding,
)
from neuronx_distributed_training.lightning_modules.model.hf_models.base_model import BaseHfModel


class HFMyLLamaModule(BaseHfModel):
    def _get_model(self):
        config = LlamaConfig.from_pretrained(self.config.model.model_config)
        config.use_cache = False
        config.return_dict = False
        config.sequence_parallel_enabled = self.config.distributed_strategy.get("sequence_parallel", False)
        config.qkv_linear = self.config.model.get("qkv_linear", False)
        config.fuse_qkv = self.config.model.get("fuse_qkv", True)
        config.kv_shared_group_size = self.config.distributed_strategy.get("kv_replicator", 1)
        config.max_position_embeddings = self.config.model.get("max_position_embeddings", config.max_position_embeddings)
        config.use_flash_attention = self.config.model.fusions.flash_attention
        config.use_ring_attention = get_attribute_from_cfg(self.config, 'ring_attention', False)

        hardware_type = hardware(get_platform_target())
        if hardware_type == hardware.TRN1:
            config.lnc = self.config.trainer.get("lnc", 1)
        if hardware_type == hardware.TRN2:
            config.lnc = self.config.trainer.get("lnc", 2)

        if self.config.model.get('num_layers', -1) != -1:
            config.num_hidden_layers = self.config.model.get('num_layers')
        if self.config.model.get('hidden_size', -1) != -1:
            config.hidden_size = self.config.model.get('hidden_size')
        if self.config.model.get('rope_theta', -1) != -1:
            config.rope_theta = self.config.model.get('rope_theta')

        # overriding head_dim value, which was set in transformers code
        config.head_dim = get_attribute_from_cfg(self.config, 'hidden_size', config.hidden_size) // config.num_attention_heads
        config.transpose_nki_inputs = self.config.model.get('transpose_nki_inputs', True)  # transpose_nki_inputs by default

        if get_attribute_from_cfg(self.config, "peft", False):
            lora_config = nxd.modules.lora.LoraConfig(
                lora_rank=get_attribute_from_cfg(self.config, 'lora_rank', 16),
                lora_alpha=get_attribute_from_cfg(self.config, 'lora_alpha', 32),
                lora_dropout=get_attribute_from_cfg(self.config, 'lora_dropout', 0.05),
                bias=get_attribute_from_cfg(self.config, 'lora_bias', "none"),
                lora_verbose=get_attribute_from_cfg(self.config, 'lora_verbose', True),
                target_modules=get_attribute_from_cfg(self.config, 'target_modules', ["qkv_proj"]),
                load_lora_from_ckpt=get_attribute_from_cfg(self.config, 'load_lora_from_ckpt', False),
                save_lora_base=get_attribute_from_cfg(self.config, 'save_lora_base', False),
                merge_lora=get_attribute_from_cfg(self.config, 'merge_lora', False),
                save_lora_config_adapter=get_attribute_from_cfg(self.config, 'save_lora_config_adapter', True),
                merge_sharded_lora=get_attribute_from_cfg(self.config, 'merge_sharded_lora', False),
            )
            self.nxd_config["lora_config"] = lora_config

        if self.config.precision.type == "fp32":
            # RS would be in fp32 as there is no implicit downcasting
            config.reduce_dtype = get_dtype(self.config.precision.get('parallel_layers_reduce_dtype', 'fp32'))
            config.torch_dtype = torch.float32
        else:
            # default RS type, this won't get downcasted to anything else, so RS will happen at bf16
            config.reduce_dtype = torch.bfloat16
            if get_dtype(self.config.precision.get('parallel_layers_reduce_dtype', 'bf16')) == torch.float32:
                config.reduce_dtype = torch.float64
            config.torch_dtype = torch.bfloat16

        leaf_module_cls = [LlamaRMSNorm.__name__, LlamaRotaryEmbedding.__name__]

        activation_recompute_modules = []
        recompute_modules = self.config.model.get("activations_checkpoint_recompute", [])
        granularity = self.config.model.get("activations_checkpoint_granularity", None)
        if granularity == "selective":
            for module in recompute_modules:
                module_obj = getattr(sys.modules[__name__], module, None)
                if module_obj is not None:
                    activation_recompute_modules.append(module_obj)
        elif granularity == "full":
            activation_recompute_modules = "full"
        elif not self.config.model.fusions.get("flash_attention", False):
            # do CoreAttention checkpointing if flash_attention is off
            activation_recompute_modules.append(CoreAttention)
        else:
            activation_recompute_modules = None
        self.nxd_config["activation_checkpoint_config"] = activation_recompute_modules

        self.nxd_config["pipeline_config"].update(
            {
                "transformer_layer_cls": LlamaDecoderLayer,
                "output_loss_value_spec": (True, False),
                "input_names": ["input_ids", "attention_mask", "labels"],
                "leaf_module_cls": leaf_module_cls,
            }
        )
        include_buffers = True
        return nxd.initialize_parallel_model(self.nxd_config, self.model_provider_func, include_buffers, config)

    def model_provider_func(self, config):
        model = LlamaForCausalLM(config)
        # Here we make sure we use the same sine and cosine matrices for all layers.
        # Making use of same tensors would make the CSE algorithm eliminate the lookup call
        # from layers, keeping only lookup from first layer.
        # with torch.no_grad():
        #     cos, sin = self.get_sin_cos_matrix(config)
        #     for layer in model.model.layers:
        #         layer.self_attn.rotary_emb.cos_cached = cos
        #         layer.self_attn.rotary_emb.sin_cached = sin
        if os.environ.get("XLA_DOWNCAST_BF16", None) == "0" and config.torch_dtype == torch.bfloat16:
            model = model.to(torch.bfloat16)
        return model

    def get_sin_cos_matrix(self, config):
        head_dim = config.hidden_size // config.num_attention_heads
        base = config.rope_theta
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        t = torch.arange(config.max_position_embeddings, dtype=inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos()[None, None, :, :].to(torch.float32), emb.sin()[None, None, :, :].to(torch.float32)

    def init_weights(self, module, device):
        """
        Re-init weights after partition
        Referred from HF transformers
        https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L690
        """
        # Last else should always call super().init_weights() to allow initializing
        # pre-defined layers.
        for key, nested_module in module._modules.items():
            if isinstance(nested_module, LlamaRotaryEmbedding):
                module._modules[key] = LlamaRotaryEmbedding(nested_module.config, device)
        if isinstance(module, LlamaRMSNorm):
            module.weight.data.fill_(1.0)
        else:
            super().init_weights(module, device)
```
-
-
-
Run the AOT compilation `sbatch` command, but configure it to not load a checkpoint and to use from-scratch weights (i.e., `resume_from_checkpoint: null`). After the dataset-preprocessing loop finishes, the process will most likely OOM and die immediately: without parallelization the model does not fit in device memory. To enable parallelism you must modify the model definition as described below.
-
Enumerate all the "parameters" contained in the model and check the shape of the parameters in each layer.
Explanation
-
For example, you can check like this:
```python
import torch
from transformers import AutoModel

model = AutoModel.from_pretrained("/fsx/models/Tanuki-8B-dpo-v1.0/")


def print_param_shapes(model: torch.nn.Module):
    total_params = 0
    for name, param in model.named_parameters():
        print(f"{name:<60} {tuple(param.shape)}")
        total_params += param.numel()
    print(f"\nTotal parameters: {total_params:,}")


# Run
print_param_shapes(model)
```
Summarizing the results yields something like:
```
model.embed_tokens.weight                        (65024, 4096)
model.layers.*.self_attn.q_proj.weight           (4096, 4096)
model.layers.*.self_attn.k_proj.weight           (1024, 4096)
model.layers.*.self_attn.v_proj.weight           (1024, 4096)
model.layers.*.self_attn.o_proj.weight           (4096, 4096)
model.layers.*.mlp.gate_proj.weight              (14336, 4096)
model.layers.*.mlp.up_proj.weight                (14336, 4096)
model.layers.*.mlp.down_proj.weight              (4096, 14336)
model.layers.*.input_layernorm.weight            (4096,)
model.layers.*.post_attention_layernorm.weight   (4096,)
model.norm.weight                                (4096,)
lm_head.weight                                   (65024, 4096)
```
Among these, the large parameter blocks such as `embed_tokens`, `self_attn.[qkvo]_proj`, `mlp.(gate|up|down)_proj`, and `lm_head` must be converted into tensor-parallelizable forms.
-
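As a rough illustration of why these blocks dominate memory, here is a plain NumPy sketch of sharding one of them; the TP degree of 8 is an assumed example, not a value fixed by the text.

```python
import numpy as np

# Shapes from the parameter table above (tanuki-8b, a Llama-style 8B config).
vocab, hidden, intermediate = 65024, 4096, 14336
tp_size = 8  # hypothetical tensor-parallel degree

# One of the large blocks, mlp.gate_proj: (intermediate, hidden)
gate_proj = np.zeros((intermediate, hidden), dtype=np.float32)
print(f"full gate_proj: {gate_proj.nbytes / 2**20:.0f} MiB")  # 224 MiB in fp32

# Tensor parallelism shards such a block along one dimension, so each
# device stores (and multiplies by) only a 1/tp_size slice of it.
shards = np.split(gate_proj, tp_size, axis=0)
print(f"per-device shard: {shards[0].shape}, {shards[0].nbytes / 2**20:.0f} MiB")  # (1792, 4096), 28 MiB
```

Multiplied across 32 decoder layers plus the embedding and `lm_head`, this is the gap between "does not fit on one core" and "fits once sharded".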
First, make the token-ID-to-vector embedding layer tensor-parallelizable by using a `ParallelEmbedding` layer.
- In the HF source, `LlamaModel.__init__` instantiates `self.embed_tokens` as an `nn.Embedding`. Replace it with `ParallelEmbedding` by overriding `LlamaModel.__init__` in `modeling_my_llama.py`. (Also be aware that `LlamaForCausalLM.__init__`, which instantiates `LlamaModel`, must be overridden so it constructs the new `LlamaModel`; adjust the `super().__init__` calls at the top of each `__init__` accordingly.)
- Before:

```python
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
```

- After (requires `from neuronx_distributed.parallel_layers.layers import ParallelEmbedding`):

```python
self.embed_tokens = ParallelEmbedding(
    config.vocab_size,
    config.hidden_size,
    self.padding_idx,
    sequence_parallel_enabled=self.config.sequence_parallel_enabled,
)
```
Explanation
- `ParallelEmbedding` (official Docs) (source) plays the same role as `torch.nn.Embedding`, but its parameter tensor is split along the vocabulary dimension across TP devices.
- `sequence_parallel_enabled` toggles sequence parallelism:
  - `False` (default): sequence parallelism is not used; each TP device returns a tensor with the full original shape.
  - `True`: outputs are returned in sequence-parallel form (each device holds its split of the tensor; the sequence and batch axis ordering may also differ, see the docs).
-
-
-
-
-
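The vocabulary-dimension split can be emulated in plain PyTorch. This is an illustrative sketch, not the actual `ParallelEmbedding` code: the rank loop stands in for the TP devices and the final sum for the all-reduce.

```python
import torch

torch.manual_seed(0)
vocab, hidden, tp = 16, 8, 4
full = torch.nn.Embedding(vocab, hidden)
ids = torch.tensor([[1, 5, 11, 15]])

# Each "rank" owns a contiguous vocab range, looks up only its own ids,
# zeroes the rows it does not own, and the partials are summed.
per_rank = vocab // tp
partials = []
for rank in range(tp):
    lo, hi = rank * per_rank, (rank + 1) * per_rank
    shard = full.weight[lo:hi]                     # this rank's weight slice
    mask = (ids >= lo) & (ids < hi)
    local_ids = (ids - lo).clamp(0, per_rank - 1)  # clamp out-of-range ids
    out = torch.nn.functional.embedding(local_ids, shard)
    out = out * mask.unsqueeze(-1)                 # zero rows this rank doesn't own
    partials.append(out)

reduced = sum(partials)  # the "all-reduce" step
assert torch.allclose(reduced, full(ids))
```

Each device stores only `vocab/tp` rows of the embedding table, yet the summed result is identical to the unsharded lookup.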
Make the MLP layer `LlamaMLP` tensor-parallelizable using `ColumnParallelLinear` and `RowParallelLinear`.
- In HF, `LlamaMLP.__init__` instantiates `self.gate_proj`, `self.up_proj`, and `self.down_proj` as `nn.Linear`. Replace them by overriding `LlamaMLP.__init__` (as before, `LlamaDecoderLayer.__init__` and `LlamaForCausalLM.__init__`, which call these constructors, need corresponding overrides):
- Before:

```python
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
```

- After (requires `from neuronx_distributed.parallel_layers.layers import ColumnParallelLinear, RowParallelLinear`):

```python
self.gate_proj = ColumnParallelLinear(
    self.hidden_size, self.intermediate_size, bias=config.mlp_bias,
    gather_output=False, sequence_parallel_enabled=self.config.sequence_parallel_enabled,
)
self.up_proj = ColumnParallelLinear(
    self.hidden_size, self.intermediate_size, bias=config.mlp_bias,
    gather_output=False, sequence_parallel_enabled=self.config.sequence_parallel_enabled,
)
self.down_proj = RowParallelLinear(
    self.intermediate_size, self.hidden_size, bias=config.mlp_bias,
    input_is_parallel=True, sequence_parallel_enabled=self.config.sequence_parallel_enabled,
)
```
Explanation
- `ColumnParallelLinear` (Docs) (source) behaves like `torch.nn.Linear` but splits its parameter along the output dimension across TP devices.
  - `gather_output` controls whether the output is gathered across TP devices:
    - `False`: the output stays sharded (each device holds only its TP slice).
    - `True` (default): outputs are gathered, so each device gets the full output shape.
  - `sequence_parallel_enabled` toggles sequence parallelism, as before.
- `RowParallelLinear` (Docs) (source) splits its parameter along the input dimension across TP devices.
  - `input_is_parallel` indicates whether the input is already sharded across TP devices.
- When two linear layers are chained, the standard efficient pattern is to make the first a `ColumnParallelLinear(gather_output=False)` and the second a `RowParallelLinear(input_is_parallel=True)`.
-
-
-
-
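Why this Column-then-Row pairing needs no communication between the two matmuls can be checked numerically. A plain PyTorch sketch with made-up sizes: the loop stands in for the TP ranks and the final sum for the single all-reduce after the row-parallel layer.

```python
import torch

torch.manual_seed(0)
hidden, inter, tp = 8, 16, 4
x = torch.randn(2, hidden)
w1 = torch.randn(inter, hidden)   # "ColumnParallel": split along its output dim
w2 = torch.randn(hidden, inter)   # "RowParallel": split along its input dim

full = x @ w1.T @ w2.T            # reference: two chained (bias-free) linear layers

# Each device holds one row-block of w1 and the matching column-block of w2.
# The sharded intermediate activation feeds straight into the second matmul;
# only the partial outputs need to be summed (the all-reduce).
partial_sum = torch.zeros(2, hidden)
per = inter // tp
for rank in range(tp):
    w1_shard = w1[rank * per:(rank + 1) * per]      # (inter/tp, hidden)
    w2_shard = w2[:, rank * per:(rank + 1) * per]   # (hidden, inter/tp)
    partial_sum += (x @ w1_shard.T) @ w2_shard.T

assert torch.allclose(partial_sum, full, atol=1e-4)
```

Note this sketch ignores the elementwise activation between the layers; since the activation acts per element of the sharded intermediate tensor, it does not change the argument.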
Convert the final fully connected layer `lm_head` to a tensor-parallel form.
- In HF, `LlamaForCausalLM.__init__` instantiates `self.lm_head` as `nn.Linear`. Replace it with `ColumnParallelLinear`:
- Before:

```python
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
```

- After:

```python
self.lm_head = ColumnParallelLinear(
    config.hidden_size, config.vocab_size, bias=False,
    gather_output=False, sequence_parallel_enabled=self.config.sequence_parallel_enabled,
)
```
-
-
In this design, the TP merge point comes just before the cross-entropy calculation. NxD provides `parallel_cross_entropy` to compute softmax cross-entropy over TP-split logits. Modify the loss computation at the end of `LlamaForCausalLM.forward` as follows:
- Before:

```python
loss = None
if labels is not None:
    loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
```

- After (requires `from neuronx_distributed.parallel_layers.loss_functions import parallel_cross_entropy`):

```python
if self.config.sequence_parallel_enabled:
    logits = logits.transpose(0, 1).contiguous()
loss = None
if labels is not None:
    # Shift so that tokens < n predict n
    shift_logits = logits[..., :-1, :].clone().contiguous()
    shift_labels = labels[..., 1:].contiguous()
    shift_logits = shift_logits.view(-1, shift_logits.size(-1))
    shift_labels = shift_labels.view(-1)
    shift_labels = shift_labels.to(shift_logits.device)
    loss = parallel_cross_entropy(shift_logits, shift_labels)
    loss = torch.mean(loss)
```
Explanation
- `parallel_cross_entropy` (Docs) (source) computes cross-entropy from TP-sharded logits and (non-sharded) labels.
- Expected shapes: `logits.shape == (batch, seq_len, vocab_size/TP)` and `labels.shape == (batch, seq_len)` (or flattened equivalents).
- If sequence parallel is enabled, the batch and sequence dimensions may be transposed, hence the initial `transpose`.
- The index-shifting logic is written out explicitly here because in the HF version it was handled inside `self.loss_function`.
-
-
-
-
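The shifting logic can be sanity-checked in isolation with ordinary `F.cross_entropy` standing in for `parallel_cross_entropy` (illustrative sizes; the real logits would carry only a `vocab/TP` slice in the last dimension):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, seq_len, vocab = 2, 5, 11
logits = torch.randn(batch, seq_len, vocab)
labels = torch.randint(0, vocab, (batch, seq_len))

# Shift so that tokens < n predict n: position t's logits are scored
# against the label at position t + 1; the last position has no target.
shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab)
shift_labels = labels[..., 1:].contiguous().view(-1)

loss = F.cross_entropy(shift_logits, shift_labels)
print(shift_logits.shape, shift_labels.shape, loss.item())
```

After flattening, there are `batch * (seq_len - 1)` (logit row, label) pairs, which is exactly what the loss function in the ported `forward` receives.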
Make the self-attention layer `LlamaAttention` tensor-parallelizable.
- Background on `LlamaAttention` (we assume basic self-attention knowledge; these are the features specific to `LlamaAttention`):
  - It supports Grouped Query Attention (GQA). GQA reduces compute and memory by sharing Key/Value heads among Query heads. For example, with 32 Q heads and 8 KV heads, every 4 Q heads share the same K/V, reducing KV-cache cost while retaining Q expressiveness.
- `LlamaAttention` can be divided into four major parts:
  1. Computing Q, K, V from the input (`hidden_states`)
  2. Applying RoPE to Q and K
  3. Computing O from Q, K, V
  4. Computing the output (the next hidden states) from O
- There are many required changes; they are explained step by step below.
-
(i) Compute Q, K, V from the input
- Replace the three Q/K/V linear layers with `GQAQKVColumnParallelLinear`.
- In `__init__`:
- Before:

```python
self.q_proj = nn.Linear(
    config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
)
self.k_proj = nn.Linear(
    config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
)
self.v_proj = nn.Linear(
    config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
)
```

- After (requires `from neuronx_distributed.modules.qkv_linear import GQAQKVColumnParallelLinear`):

```python
self.qkv_proj = GQAQKVColumnParallelLinear(
    config.hidden_size,
    [config.num_attention_heads * self.head_dim, config.num_key_value_heads * self.head_dim],
    bias=config.attention_bias,
    gather_output=False,
    kv_size_multiplier=self.config.kv_shared_group_size,
    fuse_qkv=self.config.fuse_qkv,
    sequence_parallel_enabled=self.config.sequence_parallel_enabled,
)
```
Explanation
- `GQAQKVColumnParallelLinear` (Docs) (source) computes Q/K/V with linear transforms whose parameters are split along the output dimension for TP.
- The second argument is `[q_proj_output_dim, k_proj(v_proj)_output_dim]`.
- It returns `(query_states, key_states, value_states)` with axis order `[batch, seq_len, head_count * head_dim]`.
- `kv_size_multiplier` corresponds to KV_REPLICATOR (how many copies of the KV weights to keep).
- `fuse_qkv=True` stores the Q/K/V parameters fused (under names like `(weight|bias)_qkv`), which is more efficient.
-
-
-
In `__init__`, the head-count variables also change:
- Before:

```python
self.num_heads = config.num_attention_heads
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
```

- After (requires `import neuronx_distributed.parallel_layers.utils as neuronx_dist_utils`):

```python
self.num_heads = neuronx_dist_utils.divide(config.num_attention_heads, get_tensor_model_parallel_size())
self.num_key_value_heads = neuronx_dist_utils.divide(
    config.num_key_value_heads * self.config.kv_shared_group_size,
    get_tensor_model_parallel_size(),
)
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
```
Explanation
- `self.num_(heads|key_value_heads|key_value_groups)` now hold the per-device values after dividing by the TP size; they are used in `forward`.
- `neuronx_dist_utils.divide(x, y)` behaves like integer division but raises if `x % y != 0`.
-
-
-
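An illustrative re-implementation of the `divide` contract described above (not the NxD source, just the behavior):

```python
def divide(numerator: int, denominator: int) -> int:
    # Integer division that fails loudly on a remainder, mirroring the
    # contract described for neuronx_dist_utils.divide.
    if numerator % denominator != 0:
        raise ValueError(f"{numerator} is not evenly divisible by {denominator}")
    return numerator // denominator


print(divide(32, 8))  # 4: e.g. 32 query heads over TP=8 gives 4 heads per device
# divide(30, 8) would raise ValueError: head counts must divide evenly by TP size
```

Failing loudly here is the point: a head count that does not divide the TP size would otherwise silently drop heads.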
In `forward`:
- Before:

```python
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
```

- After:

```python
bsz, q_len, _ = hidden_states.size()
if self.config.sequence_parallel_enabled:
    q_len, bsz, _ = hidden_states.size()
    q_len = q_len * get_tensor_model_parallel_size()
query_states, key_states, value_states = self.qkv_proj(hidden_states)
query_states, key_states, value_states, seq_len_dim_index = self.permute_qkv_for_attn(
    query_states, key_states, value_states,
    bsz, q_len, self.num_heads, self.num_key_value_heads, self.head_dim, self.config,
)
```

- Additionally, define these helper methods inside `LlamaAttention`:

```python
def reshape_and_permute_states_for_fa(self, states, bsz, q_len, num_heads, head_dim, use_sequence_parallel):
    if use_sequence_parallel:
        return states.view(q_len, bsz, num_heads, head_dim).permute(1, 2, 3, 0)
    else:
        return states.view(bsz, q_len, num_heads, head_dim).permute(0, 2, 3, 1)

def permute_qkv_for_attn(
    self, query_states, key_states, value_states, bsz, q_len, num_heads, num_key_value_heads, head_dim, config
):
    if config.transpose_nki_inputs and config.use_flash_attention:
        query_states = self.reshape_and_permute_states_for_fa(
            query_states, bsz, q_len, num_heads, head_dim, config.sequence_parallel_enabled)
        key_states = self.reshape_and_permute_states_for_fa(
            key_states, bsz, q_len, num_key_value_heads, head_dim, config.sequence_parallel_enabled)
        value_states = self.reshape_and_permute_states_for_fa(
            value_states, bsz, q_len, num_key_value_heads, head_dim, config.sequence_parallel_enabled)
        dim_index = -1
    elif config.sequence_parallel_enabled:
        query_states = query_states.view(q_len, bsz, num_heads, head_dim).permute(1, 2, 0, 3)
        key_states = key_states.view(q_len, bsz, num_key_value_heads, head_dim).permute(1, 2, 0, 3)
        value_states = value_states.view(q_len, bsz, num_key_value_heads, head_dim).permute(1, 2, 0, 3)
        dim_index = -2
    else:
        query_states = query_states.view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, num_key_value_heads, head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, num_key_value_heads, head_dim).transpose(1, 2)
        dim_index = -2
    return query_states, key_states, value_states, dim_index
```
Explanation
- These helpers reorder the Q/K/V outputs of `GQAQKVColumnParallelLinear` into the layout expected by the downstream attention computation. The axis ordering depends on flags such as `transpose_nki_inputs` and `use_flash_attention`.
-
-
-
-
(ii) Apply RoPE to Q, K
-
Imports:
- Before:

```python
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
```

- After:

```python
from neuronx_distributed.overrides.transformer_overrides import apply_rotary_pos_emb
```
-
-
In `forward`:
- Before:

```python
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
```

- After:

```python
query_states, key_states = apply_rotary_pos_emb(
    query_states, key_states, cos, sin, None,
    self.config.use_flash_attention, self.config.transpose_nki_inputs,
)
```
-
-
Explanation
- The NxD version of `apply_rotary_pos_emb` additionally supports the case `use_flash_attention == True` with `transpose_nki_inputs == True` (i.e., the different axis order); otherwise it behaves the same as the HF version.
-
-
(iii) Compute O from Q, K, V
-
First, account for KV_REPLICATOR by repeating K and V so that their head counts match the Q heads:
- Add:

```python
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
```
Explanation
- `repeat_kv` expands a tensor of shape `(batch, n_kv_heads, seq_len, head_dim)` to `(batch, n_kv_heads * n_rep, seq_len, head_dim)` by repeating each KV head group.
- (In HF this replication was done inside `eager_attention_forward`.)
-
-
-
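For reference, `repeat_kv` can be written as follows; this mirrors the HF transformers helper, so treat it as an illustrative sketch rather than the exact NxD code.

```python
import torch


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Expands (batch, n_kv_heads, seq_len, head_dim) to
    # (batch, n_kv_heads * n_rep, seq_len, head_dim), duplicating each KV
    # head n_rep times so it lines up with the query heads of its group.
    batch, n_kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, n_kv_heads, n_rep, seq_len, head_dim)
    return hidden_states.reshape(batch, n_kv_heads * n_rep, seq_len, head_dim)


kv = torch.randn(1, 2, 4, 8)   # 2 KV heads
out = repeat_kv(kv, 4)         # now serves 8 query heads
print(out.shape)               # torch.Size([1, 8, 4, 8])
```

Note `expand` creates no new data; only the final `reshape` materializes the copies.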
Now the core Q, K, V to O computation. On GPUs this is typically done with Flash Attention for memory efficiency; on Neuron chips the equivalent implementation is `nki_flash_attn_func`.
- Before:

```python
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
    if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
        logger.warning_once(
            "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`."
            " Falling back to eager attention. This warning can be removed using the argument"
            ' `attn_implementation="eager"` when loading the model.'
        )
    else:
        attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

attn_output, attn_weights = attention_interface(
    self,
    query_states,
    key_states,
    value_states,
    attention_mask,
    dropout=0.0 if not self.training else self.attention_dropout,
    scaling=self.scaling,
    **kwargs,
)
```

- After (requires `from neuronx_distributed.kernels.flash_attn import nki_flash_attn_func`):

```python
attn_output = nki_flash_attn_func(
    query_states, key_states, value_states,
    self.config.lnc, transpose_nki_inputs=self.config.transpose_nki_inputs,
)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
    raise ValueError(
        f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
        f" {attn_output.size()}"
    )
```
Explanation
- `nki_flash_attn_func` (source) is a Flash Attention kernel implemented with the Neuron Kernel Interface (NKI).
- The input axis order must match `transpose_nki_inputs`:
  - If `transpose_nki_inputs == True`: `[batch, heads, head_dim, seq_len]`
  - Otherwise: `[batch, heads, seq_len, head_dim]`
- `attn_output` always comes out as `[batch, heads, seq_len, head_dim]`.
-
-
-
-
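What the kernel computes can be cross-checked against a plain eager implementation. This is a conceptual reference only (sizes are made up, and treating the mask as causal is an assumption here, the usual setting for causal LM training):

```python
import math
import torch

torch.manual_seed(0)
bsz, heads, q_len, head_dim = 2, 4, 16, 8
q = torch.randn(bsz, heads, q_len, head_dim)
k = torch.randn(bsz, heads, q_len, head_dim)
v = torch.randn(bsz, heads, q_len, head_dim)

# Eager causal attention: the tiled NKI kernel computes the same values
# without ever materializing the full (q_len, q_len) score matrix.
scores = q @ k.transpose(-1, -2) / math.sqrt(head_dim)
causal = torch.triu(torch.ones(q_len, q_len, dtype=torch.bool), diagonal=1)
scores = scores.masked_fill(causal, float("-inf"))
attn_output = torch.softmax(scores, dim=-1) @ v

# Matches the shape check in the ported forward: [batch, heads, seq, head_dim]
print(attn_output.shape)
```

With `transpose_nki_inputs=True` the q/k/v here would instead be laid out as `[batch, heads, head_dim, seq_len]` before the kernel call; the output layout is unchanged.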
(iv) Compute the output from O
- In `__init__`:
- Before:

```python
self.o_proj = nn.Linear(
    config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
```

- After:

```python
self.o_proj = RowParallelLinear(
    config.num_attention_heads * self.head_dim,
    config.hidden_size,
    bias=config.attention_bias,
    input_is_parallel=True,
    sequence_parallel_enabled=self.config.sequence_parallel_enabled,
)
```
-
-
In `forward`, reshape just before passing through `o_proj`:
- Before:

```python
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
```

- After:

```python
if self.config.sequence_parallel_enabled:
    attn_output = attn_output.permute(2, 0, 1, 3)
    attn_output = attn_output.reshape(q_len, bsz, self.hidden_size // get_tensor_model_parallel_size())
else:
    attn_output = attn_output.transpose(1, 2).contiguous()
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size // get_tensor_model_parallel_size())
```
Explanation
- This reorders and reshapes the `nki_flash_attn_func` output into the usual activation axis order (with sequence parallel: `[seq_len, batch, heads * head_dim]`; otherwise `[batch, seq_len, heads * head_dim]`). Note that the hidden dimension is still TP-split at this point; `o_proj` merges it.
-
-
-
-
Re-run the AOT compilation.
-
If errors occur, debug. General tips:
-
Debugging tips
-
Logs can run to thousands of lines, interleaved from parallel processes. To find the root cause:
- Search for the earliest occurrence of the string `error`. If that is unclear, search for `Error:` (Python errors) or `Error |` (Neuron compiler errors).
- Narrow the log down to the relevant range first, then let an LLM read the excerpt and analyze the root cause.
-
Use `xm.master_print` instead of `print` (requires `import torch_xla.core.xla_model as xm`) so that only the master process prints and you avoid duplicated logs across processes.
Locating the problematic layer
- Compiler errors do not always identify the exact layer. A useful technique is layer short-circuiting: temporarily skip executing the entire `LlamaDecoderLayer` loop. If the error disappears, the issue is inside the decoder layers; if not, it is earlier (the embedding) or later (the final linear layer or the loss).
-
Inspecting tensors safely
- Inserting `print` calls may change lazy-compilation behavior. Inspecting `.shape`, `.dtype`, and `.device` is usually safe and does not trigger problematic recompilation.
-
-
Common compiler error messages
-
`Estimated peak HBM usage (18.179819) exceeds 16GB. Neff won't be able to load on chip`
  - The model does not fit on each Neuron core. Reduce memory use (sequence length, batch size, model size) or increase model parallelism (TP/PP).
-
`Couldn't color the DRAM even with 100GB of DRAM space assumption, model needs too much HBM memory !`
  - Same as above: the model is too big to fit.
-
`DRAM usage for Internal DRAM tensor exceeds 16GB of device space limit, cannot fit into device, model requires too much HBM memory !`
  - Same as above.
-
`Internal tensorizer error: VectorizeDMA:Illegal after shrink dst!`
  - Indicates a tensor whose shape varies per operation; some shapes must be static for each operator.
-
`RuntimeError: (1, 32768, 1, 80) and (1, 32768, 1, 80)`
  - A shape mismatch during elementwise ops; note that the message sometimes confusingly prints the same shape twice.
-
`CCOM WARN No transport found between devices 8 and 7. Possible replica group misconfiguration`
  - An inter-core communication issue; certain TP/PP/KV_REPLICATOR combinations can leave cores without full connectivity.
-
`ERROR: Unsupported operation: mhlo.set_dimension_size`
  - The Neuron compiler does not support every PyTorch op; replace the offending op with supported ones (see the reference).
-
-
`Number of instructions (8146371) is over the threshold (5000000). - Compile under --optlevel=1 to create smaller subgraphs or use pipeline parallelism.`
  - The graph is too large; compile at a lower optimization level to create smaller subgraphs, or use pipeline parallelism.
-
`RuntimeError: Bad StatusOr access: RESOURCE_EXHAUSTED: LoadCollectives: error condition NRT_RESOURCE == rt_status:`
  - Updating the Neuron SDK can sometimes resolve this.
-
-
-
-
Re-run the actual training (“train”).
Finally, enable checkpoint loading support.
-
Re-enable checkpoint loading and run the AOT compilation and training again.
-
You will likely see the following Python error when loading checkpoints:
```
RuntimeError: Missing keys when loading state dictionary: model.layers.0.mlp.gate_proj.weight,
model.layers.0.mlp.up_proj.weight, model.layers.1.mlp.gate_proj.weight, ...(omitted)...
model.layers.1.mlp.up_proj.weight, model.layers.30.mlp.gate_proj.weight,
model.layers.30.mlp.up_proj.weight, model.layers.31.mlp.gate_proj.weight,
model.layers.31.mlp.up_proj.weight
```
This happens because the HF-to-NxD checkpoint conversion combines the MLP's `gate_proj` and `up_proj` linear layers into a single `gate_up_proj` linear layer, which the ported implementation did not account for. The converted checkpoint therefore lacks the `gate_proj`/`up_proj` keys that the model still expects, causing the missing-keys error.
Fix as follows:
-
Modify `LlamaMLP.__init__`:
- Before:

```python
self.gate_proj = ColumnParallelLinear(
    self.hidden_size, self.intermediate_size, bias=config.mlp_bias,
    gather_output=False, sequence_parallel_enabled=self.config.sequence_parallel_enabled,
)
self.up_proj = ColumnParallelLinear(
    self.hidden_size, self.intermediate_size, bias=config.mlp_bias,
    gather_output=False, sequence_parallel_enabled=self.config.sequence_parallel_enabled,
)
self.act_fn = ACT2FN[config.hidden_act]
```

- After:

```python
self.gate_up_proj = ColumnParallelLinear(
    self.hidden_size,
    2 * self.intermediate_size,
    stride=2,
    bias=config.mlp_bias,
    gather_output=False,
    sequence_parallel_enabled=self.config.sequence_parallel_enabled,
)
self.activation_multiply = ActivationMultiplyMLP(config)
```
Explanation
- The `stride` parameter of `ColumnParallelLinear` (default 1) changes how the parameter tensor is partitioned for TP. With `stride=2`, each device's slice consists of two adjacent blocks: the first half corresponds to `gate_proj` and the second half to `up_proj`, which matches the HF-to-NxD checkpoint packing.
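The packing can be checked numerically. This sketch assumes the stride partitioning scheme of splitting the output dimension into `stride * tp` chunks and giving rank r chunks r and r + tp (my reading of the NxD partitioning; names and sizes are illustrative):

```python
import torch

intermediate, hidden, tp = 8, 4, 2
gate = torch.arange(intermediate * hidden).reshape(intermediate, hidden).float()
up = gate + 100.0  # distinguishable values

# Checkpoint conversion packs gate_proj and up_proj into one master weight.
master = torch.cat([gate, up], dim=0)  # (2 * intermediate, hidden)

# Assumed stride=2 partitioning: 2 * tp chunks; rank r gets chunks r and r + tp.
chunks = torch.split(master, intermediate // tp, dim=0)
per = intermediate // tp
for rank in range(tp):
    shard = torch.cat([chunks[rank], chunks[rank + tp]], dim=0)
    # First half of the device's slice is its gate shard, second half its up
    # shard, exactly what ActivationMultiplyMLP's split() expects.
    assert torch.equal(shard[:per], gate[rank * per:(rank + 1) * per])
    assert torch.equal(shard[per:], up[rank * per:(rank + 1) * per])
print("stride=2 packing verified")
```

With `stride=1`, rank 0 would instead get the first `2 * intermediate / tp` rows of the packed weight, i.e. gate rows only, and the `split()` in the forward pass would silently mix gate and up values.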
-
-
Also add:

```python
class ActivationMultiplyMLP(torch.nn.Module):
    def __init__(self, config):
        nn.Module.__init__(self)
        self.act_fn = ACT2FN[config.hidden_act]
        self.split_size = config.intermediate_size // get_tensor_model_parallel_size()

    def forward(self, x):
        gate_proj, up_proj = x.split(self.split_size, dim=2)
        intermediate_states = self.act_fn(gate_proj) * up_proj
        return intermediate_states
```
Modify `LlamaMLP.forward`:
- Before:

```python
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
```

- After:

```python
intermediate_states = self.activation_multiply(self.gate_up_proj(x))
down_proj = self.down_proj(intermediate_states)
```
-
-
-
-
Finally, to verify that the port is correct, fix a test input sequence, pass it through the `forward` of both the pre-port (original HF) model and the ported model in eval mode, and confirm that the logits match.
- Small numerical differences are expected; exact equality is not guaranteed. But if the logits are obviously different, some layer was ported incorrectly: inspect the intermediate `hidden_states` layer by layer and debug.
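A hypothetical comparison sketch (the two logits tensors below are synthetic stand-ins for the HF and ported outputs; with the real models you would gather the TP-split logits first):

```python
import torch

torch.manual_seed(0)
logits_hf = torch.randn(1, 16, 65024)
logits_ported = logits_hf + 1e-3 * torch.randn_like(logits_hf)  # stand-in for numeric noise

# Exact equality is not expected across different kernels/dtypes; compare
# with a tolerance, and also check that the argmax tokens largely agree.
close = torch.allclose(logits_hf, logits_ported, atol=1e-2, rtol=1e-2)
match = (logits_hf.argmax(-1) == logits_ported.argmax(-1)).float().mean()
print(close, match.item())
```

A tolerance failure combined with a low argmax-agreement rate usually means a structural bug (wrong permutation, wrong shard) rather than mere precision drift.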
If you get through all the steps above, you should have covered the core practices for porting a model.
-
We would like to thank Mr. Tokoyo from AWS for his supervision of this material.