【AWS Trainium 50本ノック #6】Llama3をTrainiumに移植しなおす編
第 6 章 Llama3をTrainiumに移植しなおす編
本章では以下を仮定します。
- 第 4 章「NxD対応済モデルの学習編」を実施済みであること
- 分散学習の基礎知識(第 5 章の内容)
問題 (38-50)
前々章では、tanuki-8b の学習を行いました。tanuki-8b のモデルアーキテクチャは、config.json にあるように「LlamaForCausalLM」です。このアーキテクチャは、 transformers ライブラリの
modeling_llama.py中に定義されています。
一方、NxDT ライブラリ中にも
modeling_llama.pyが存在します。こちらは、上記 transformers ライブラリの実装を Neuron チップ上での分散学習が可能な形に「移植」したものです。
世の中では日々新しいモデルアーキテクチャが登場しています。transformers ライブラリの models ディレクトリの中には、上記の LlamaForCausalLM 以外にも様々なアーキテクチャの定義コードが格納されています。しかしながら、これらの多くは、まだ NxDT 公式で移植版実装が提供されていません。
今回は NxDT 版の
modeling_llama.pyが公式提供されていましたが、これが仮に提供されていなかった場合を想定して、モデルの「移植」の練習をしてみましょう。
NxDT 版
modeling_llama.pyの中身を見るとわかりますが、基本的にはオリジナル(以下 HF 版と呼びます)のmodeling_llama.pyで定義されている層のクラスたちを継承し、必要な部分だけオーバーライド(上書き)している構成となっています。⚠️注意:このような構成のため、NxDTライブラリの挙動はtransformersライブラリのバージョンにかなり敏感になってしまっています。NxDTライブラリを使用する際は、transformersライブラリのバージョンは厳格に固定することをお勧めします。
-
HF 版の
modeling_llama.pyの中身を十分に理解してください(注意:HF版も日々更新されていくため、GitHubではなく、現在のPython環境にインストールされているtransformersライブラリの中にあるソースを確認するようにしてください)。どの層がどの層をどの順で呼び出すのか、そのツリー構造を正確に把握してください。Claude, Gemini, ChatGPT等のLLMにmodeling_llama.pyの全体をコンテキストとして与えた上で、「どのレイヤーがどのレイヤーを呼び出しているのか、全体の処理の流れがどうなっているのか、テキストで簡単なツリー図にまとめてもらえますか?」等と尋ねるのがお勧めです。-
LLMからの返答例
LlamaForCausalLM ├── LlamaModel │ ├── embed_tokens(埋め込み層) │ ├── Rotary Position Embedding(RoPE) │ ├── LlamaDecoderLayer(デコーダ層)× N層 │ │ ├── LlamaRMSNorm(正規化) │ │ ├── LlamaAttention(自己注意) │ │ │ ├── query, key, value 計算 │ │ │ ├── Rotary Position Embedding 適用 │ │ │ ├── アテンション重み計算(Scaled Dot-Product Attention) │ │ │ └── アテンション出力計算 │ │ ├── 残差接続(Residual Connection) │ │ ├── LlamaRMSNorm(正規化) │ │ ├── LlamaMLP(全結合層) │ │ │ ├── gate_proj & up_proj(次元変換) │ │ │ ├── 活性化関数(Activation Function) │ │ │ ├── down_proj(次元縮小) │ │ │ └── 出力 │ │ └── 残差接続(Residual Connection) │ ├── LlamaRMSNorm(正規化) │ └── 隠れ状態(Hidden States)を返す └── lm_head(ロジット計算)
-
-
train.shの階層にmodelsディレクトリを新規作成し、この中にmodeling_my_llama.pyというファイルを作成します。ここに移植コードを開発していきます。まずは、HF版のmodeling_llama.pyのクラスを継承して全くそのままとするところから始めましょう。内容を以下としてください。from transformers.models.llama.modeling_llama import ( LlamaForCausalLM as LlamaForCausalLMHF, LlamaRotaryEmbedding as LlamaRotaryEmbeddingHF, LlamaDecoderLayer as LlamaDecoderLayerHF, LlamaAttention as LlamaAttentionHF, LlamaRMSNorm as LlamaRMSNormHF, LlamaMLP as LlamaMLPHF ) class LlamaForCausalLM(LlamaForCausalLMHF): pass class LlamaRotaryEmbedding(LlamaRotaryEmbeddingHF): pass class LlamaDecoderLayer(LlamaDecoderLayerHF): pass class LlamaAttention(LlamaAttentionHF): pass class LlamaRMSNorm(LlamaRMSNormHF): pass class LlamaMLP(LlamaMLPHF): pass -
現時点で何がダメかを確認するため、とりあえず動かしてください。
- 訓練コードから、突貫で作成した上記のモデル定義を呼び出すように変更します。(参考:公式Docs)
-
training.pyを以下のように修正します:-
変更前
from neuronx_distributed_training.lightning_modules.model.hf_models.llama_model import ( HFLLamaModule, ) ...(中略)... model = HFLLamaModule(cfg, trainer) -
変更後
from models.my_llama_model import HFMyLLamaModule ...(中略)... model = HFMyLLamaModule(cfg, trainer)
-
-
modelsフォルダ内にファイルmy_llama_model.pyを新規作成し、その中で、上記でインポートするHFMyLLamaModuleを以下のように定義します。my_llama_model.py
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 import os import neuronx_distributed as nxd import torch from transformers import LlamaConfig import sys from neuronx_distributed.utils.utils import hardware from neuronx_distributed_training.utils import get_dtype, get_attribute_from_cfg from torch_neuronx.utils import get_platform_target from models.modeling_my_llama import ( LlamaAttention as CoreAttention, LlamaDecoderLayer, LlamaForCausalLM, LlamaRMSNorm, LlamaMLP, LlamaRotaryEmbedding ) from neuronx_distributed_training.lightning_modules.model.hf_models.base_model import BaseHfModel class HFMyLLamaModule(BaseHfModel): def _get_model(self): config = LlamaConfig.from_pretrained(self.config.model.model_config) config.use_cache = False config.return_dict = False config.sequence_parallel_enabled = self.config.distributed_strategy.get("sequence_parallel", False) config.qkv_linear = self.config.model.get("qkv_linear", False) config.fuse_qkv = self.config.model.get("fuse_qkv", True) config.kv_shared_group_size = self.config.distributed_strategy.get("kv_replicator", 1) config.max_position_embeddings = self.config.model.get("max_position_embeddings", config.max_position_embeddings) config.use_flash_attention = self.config.model.fusions.flash_attention config.use_ring_attention = get_attribute_from_cfg(self.config, 'ring_attention', False) hardware_type = hardware(get_platform_target()) if hardware_type==hardware.TRN1: config.lnc = self.config.trainer.get("lnc", 1) if hardware_type==hardware.TRN2: config.lnc = self.config.trainer.get("lnc", 2) if self.config.model.get('num_layers', -1) != -1: config.num_hidden_layers = self.config.model.get('num_layers') if self.config.model.get('hidden_size', -1) != -1: config.hidden_size = self.config.model.get('hidden_size') if self.config.model.get('rope_theta', -1) != -1: config.rope_theta = self.config.model.get('rope_theta') config.head_dim = get_attribute_from_cfg(self.config, 'hidden_size', config.hidden_size) // config.num_attention_heads # overriding head_dim value, which was set in transformers code config.transpose_nki_inputs = self.config.model.get('transpose_nki_inputs', True) # transpose_nki_inputs by default if get_attribute_from_cfg(self.config, "peft", False): lora_config = nxd.modules.lora.LoraConfig( lora_rank=get_attribute_from_cfg(self.config, 'lora_rank', 16), lora_alpha=get_attribute_from_cfg(self.config, 'lora_alpha', 32), lora_dropout=get_attribute_from_cfg(self.config, 'lora_dropout', 0.05), bias=get_attribute_from_cfg(self.config, 'lora_bias', "none"), lora_verbose=get_attribute_from_cfg(self.config, 'lora_verbose', True), target_modules=get_attribute_from_cfg(self.config, 'target_modules', ["qkv_proj"]), load_lora_from_ckpt=get_attribute_from_cfg(self.config, 'load_lora_from_ckpt', False), save_lora_base=get_attribute_from_cfg(self.config, 'save_lora_base', False), merge_lora=get_attribute_from_cfg(self.config, 'merge_lora', False), save_lora_config_adapter=get_attribute_from_cfg(self.config, 'save_lora_config_adapter', True), merge_sharded_lora=get_attribute_from_cfg(self.config, 'merge_sharded_lora', False), ) self.nxd_config["lora_config"] = lora_config if self.config.precision.type == "fp32": config.reduce_dtype = get_dtype(self.config.precision.get('parallel_layers_reduce_dtype', 'fp32')) # RS would be in fp32 as there is no implicit downcasting config.torch_dtype = torch.float32 else: config.reduce_dtype = torch.bfloat16 # default RS type, this wont get downcasted to anything else, so RS will happen at bf16 if get_dtype(self.config.precision.get('parallel_layers_reduce_dtype', 'bf16')) == torch.float32: config.reduce_dtype = torch.float64 config.torch_dtype = torch.bfloat16 leaf_module_cls = [LlamaRMSNorm.__name__, LlamaRotaryEmbedding.__name__] activation_recompute_modules = [] recompute_modules = self.config.model.get("activations_checkpoint_recompute", []) granularity = self.config.model.get("activations_checkpoint_granularity", None) if granularity == "selective": for module in recompute_modules: module_obj = getattr(sys.modules[__name__], module, None) if module_obj is not None: activation_recompute_modules.append(module_obj) elif granularity == "full": activation_recompute_modules = "full" elif not self.config.model.fusions.get("flash_attention", False): activation_recompute_modules.append(CoreAttention) # do CoreAttention checkpointing if flash_attention is off else: activation_recompute_modules = None self.nxd_config["activation_checkpoint_config"] = activation_recompute_modules self.nxd_config["pipeline_config"].update( { "transformer_layer_cls": LlamaDecoderLayer, "output_loss_value_spec": (True, False), "input_names": ["input_ids", "attention_mask", "labels"], "leaf_module_cls": leaf_module_cls, } ) include_buffers = True return nxd.initialize_parallel_model(self.nxd_config, self.model_provider_func, include_buffers, config) def model_provider_func(self, config): model = LlamaForCausalLM(config) # Here we make sure we use the same sine and cosine matrices for all layers. # Making use of same tensors would make the CSE algorithm eliminate the lookup call # from layers, keeping only lookup from first layer. # with torch.no_grad(): # cos, sin = self.get_sin_cos_matrix(config) # for layer in model.model.layers: # layer.self_attn.rotary_emb.cos_cached = cos # layer.self_attn.rotary_emb.sin_cached = sin if os.environ.get("XLA_DOWNCAST_BF16", None) == "0" and config.torch_dtype == torch.bfloat16: model = model.to(torch.bfloat16) return model def get_sin_cos_matrix(self, config): head_dim = config.hidden_size // config.num_attention_heads base = config.rope_theta inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim)) t = torch.arange(config.max_position_embeddings, dtype=inv_freq.dtype) freqs = torch.einsum("i,j->ij", t, inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) return emb.cos()[None, None, :, :].to(torch.float32), emb.sin()[None, None, :, :].to(torch.float32) def init_weights(self, module, device): """ Re-init weights after partition Referred from HF transformers https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L690 """ # Last else should always call super().init_weights() to allow initializing # pre-defined layers. for key, nested_module in module._modules.items(): if isinstance(nested_module, LlamaRotaryEmbedding): module._modules[key] = LlamaRotaryEmbedding(nested_module.config, device) if isinstance(module, LlamaRMSNorm): module.weight.data.fill_(1.0) else: super().init_weights(module, device)
-
- 訓練コードから、突貫で作成した上記のモデル定義を呼び出すように変更します。(参考:公式Docs)
-
AOTコンパイルの
sbatchコマンドを実施してください。ただし、チェックポイントの読み込みは行わないことにして、フルスクラッチ重みを利用する設定に変更して実施してください。(resume_from_checkpoint: null)データセットの前処理ループが終了後、直ちにOOMでプロセスが終了すると予想されます。並列化を行わないとメモリがデバイスに乗り切りません。並列化を行うためには、以下のようにモデル定義に修正を施す必要があります。
-
モデルに含まれる「パラメータ」をすべて列挙してください。また、それぞれの層のパラメータのサイズを確認してください。
解説
-
例えば以下のような方法で確認できます:
import torch from transformers import AutoModel model = AutoModel.from_pretrained("/fsx/models/Tanuki-8B-dpo-v1.0/") def print_param_shapes(model: torch.nn.Module): total_params = 0 for name, param in model.named_parameters(): print(f"{name:<60} {tuple(param.shape)}") total_params += param.numel() print(f"\nTotal parameters: {total_params:,}") # 実行 print_param_shapes(model) -
結果をまとめると以下のようになります(上記の出力を要約したもの)
model.embed_tokens.weight (65024, 4096) model.layers.*.self_attn.q_proj.weight (4096, 4096) model.layers.*.self_attn.k_proj.weight (1024, 4096) model.layers.*.self_attn.v_proj.weight (1024, 4096) model.layers.*.self_attn.o_proj.weight (4096, 4096) model.layers.*.mlp.gate_proj.weight (14336, 4096) model.layers.*.mlp.up_proj.weight (14336, 4096) model.layers.*.mlp.down_proj.weight (4096, 14336) model.layers.*.input_layernorm.weight (4096,) model.layers.*.post_attention_layernorm.weight (4096,) model.norm.weight (4096,) lm_head.weight (65024, 4096)
上記の中で、パラメータサイズが大きい
embed_tokens,self_attn.[qkvo]_proj,mlp.(gate|up|down)_proj,lm_headは、テンソル並列可能な形に変更する必要があります。 -
-
最初に、トークンIDをベクトルに変換する「埋め込み層」をテンソル並列可能な形に変更してください。これには
ParallelEmbedding層を使用します。- HF 版のソースを確認すると、
LlamaModelの__init__中でself.embed_tokensというnn.Embedding層がインスタンス化されています。ここをParallelEmbedding層で置き換えます。modeling_my_llama.pyの中で、LlamaModelの__init__関数を、以下の変更を適用した版でオーバーライドしてください(同時に、LlamaModelの呼び出し元であるLlamaForCausalLMの__init__も、新版のLlamaModelをインスタンス化するためにオーバーライドが必要になります。合わせて、各__init__冒頭のsuper().__init__の修正も適宜必要なのでご注意ください):-
変更前
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) -
変更後(
from neuronx_distributed.parallel_layers.layers import ParallelEmbeddingが必要)self.embed_tokens = ParallelEmbedding(config.vocab_size, config.hidden_size, self.padding_idx, sequence_parallel_enabled=self.config.sequence_parallel_enabled) -
解説
-
ParallelEmbedding層(公式Docs)(ソース)は、torch.nn.Embeddingと同じ役割の層ですが、そのパラメータテンソルは「語彙数の次元」でTPサイズに分割されて各デバイスに分散して保持されます。 -
sequence_parallel_enabledは、シーケンス並列の使用有無を設定するオプションです。-
False(デフォルト)の場合:シーケンス並列を使用しません。TP個すべてのデバイスに、本来のシェイプのテンソルが返されます。 -
Trueの場合:シーケンス並列を使用します。この場合、この層からの出力は「シーケンス並列モード」(前述)で返ってきます。すなわち、TP個それぞれのデバイスには、TP分割後のテンソルが返されます(一般にそれらは互いに異なります)。それらのシーケンス方向の次元は、本来の1/TP倍です。また、通常の (バッチサイズ, シーケンス長, 隠れ次元) という軸順と異なり、0-dimと1-dimが転置された (シーケンス長, バッチサイズ, 隠れ次元) の軸順のテンソルが返ります。
-
-
-
- HF 版のソースを確認すると、
-
MLP層
LlamaMLPをテンソル並列可能な形に変更してください。これにはColumnParallelLinear層とRowParallelLinear層を使用します。- HF 版のソースを確認すると、
LlamaMLPの__init__でself.gate_proj,self.up_proj,self.down_projという3つのnn.Linear層がインスタンス化されていますが、ここを以下で置き換えます。前項同様、LlamaMLPの__init__関数を修正版でオーバーライドしてください(また、呼び出し元のLlamaDecoderLayerの__init__およびその呼び出し元のLlamaForCausalLMの__init__もオーバーライドが必要です。以降、この括弧内と同様の注意書きは省略します):-
変更前
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) -
変更後(
from neuronx_distributed.parallel_layers.layers import ColumnParallelLinear, RowParallelLinearが必要)self.gate_proj = ColumnParallelLinear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias, gather_output=False, sequence_parallel_enabled=self.config.sequence_parallel_enabled) self.up_proj = ColumnParallelLinear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias, gather_output=False, sequence_parallel_enabled=self.config.sequence_parallel_enabled) self.down_proj = RowParallelLinear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias, input_is_parallel=True, sequence_parallel_enabled=self.config.sequence_parallel_enabled) -
解説
-
ColumnParallelLinear層(公式Docs)(ソース)は、torch.nn.Linearと同様の線形層ですが、そのパラメータテンソルは「出力次元」でTPサイズに分割されて各デバイスに分散して保持されます。-
gather_outputは、この層からの出力をTP個のデバイスで分散保持するか否かを設定するパラメータです。-
Falseの場合:分散保持されます。すなわち、TP個それぞれのデバイスには、TP分割後のテンソルが返されます(一般にそれらは互いに異なります)。その隠れ次元は、本来の1/TP倍です。 -
True(デフォルト)の場合:分散保持はされません。TP個すべてのデバイスに、本来の出力次元サイズの(分割されていない)テンソルが返されます。
-
-
sequence_parallel_enabledは、シーケンス並列の使用有無を設定するオプションです。-
False(デフォルト)の場合:シーケンス並列を使用しません。 -
Trueの場合:シーケンス並列を使用します。この層に入力されるテンソルは「シーケンス並列モード」で保持されていると仮定され、またこの層からの出力テンソルは「シーケンス並列モード」で返ってきます。
-
-
-
RowParallelLinear層(公式Docs)(ソース)は、torch.nn.Linearと同様の線形層ですが、そのパラメータテンソルは「入力次元」でTPサイズに分割されて各デバイスに分散して保持されます。-
input_is_parallelは、この層への入力がデバイス間で分散保持されているか否かを伝えるパラメータです。-
False(デフォルト)の場合:分散保持されていないと仮定されます。 -
Trueの場合:分散保持されていると仮定されます。入力するテンソルのサイズは、TP分割後のサイズである必要があります。
-
-
- 一般に、線形層が 2つ直列に続く場合、1つ目を
ColumnParallelLinear(gather_output=False)とし、2つ目をRowParallelLinear(input_is_parallel=True)とすることで、効率的な計算が可能です。これが「基本パターン」となります。以降の変更箇所でも、このパターンが現れます。
-
-
- HF 版のソースを確認すると、
-
最後の全結合層
lm_headをテンソル並列可能な形に変更してください。-
HF 版のソースを確認すると、
LlamaForCausalLMの__init__でself.lm_headというnn.Linear層がインスタンス化されていますが、これをColumnParallelLinear層で置き換えます。-
変更前
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) -
変更後
self.lm_head = ColumnParallelLinear(config.hidden_size, config.vocab_size, bias=False, gather_output=False, sequence_parallel_enabled=self.config.sequence_parallel_enabled)
-
-
今回は、TP分割が「合流」するポイントはクロスエントロピーを計算する直前となります。「TPを合流した後ソフトマックスクロスエントロピーを計算する関数」として、NxDライブラリに
parallel_cross_entropyが用意されています。LlamaForCausalLMのforwardの最後に loss を計算する部分がありますが、そこを以下のように修正します:-
変更前
loss = None if labels is not None: loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) -
変更後(
from neuronx_distributed.parallel_layers.loss_functions import parallel_cross_entropyが必要)if self.config.sequence_parallel_enabled: logits = logits.transpose(0, 1).contiguous() loss = None if labels is not None: # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].clone().contiguous() shift_labels = labels[..., 1:].contiguous() shift_logits = shift_logits.view(-1, shift_logits.size(-1)) shift_labels = shift_labels.view(-1) shift_labels = shift_labels.to(shift_logits.device) loss = parallel_cross_entropy(shift_logits, shift_labels) loss = torch.mean(loss) -
解説
-
parallel_cross_entropy関数(公式Docs)(ソース)は、「TP分割されている logits」と「(TP分割されていない)labels」からソフトマックスクロスエントロピーロスを計算します。 -
logits.shape == (バッチサイズ, シーケンス長, 語彙サイズ/TP),labels.shape == (バッチサイズ, シーケンス長)、あるいはlogits.shape == (バッチサイズ * シーケンス長, 語彙サイズ/TP),labels.shape == (バッチサイズ * シーケンス長)である必要があります。また、labels[..., i](labels[..., i+1]ではなく)に対する予測がlogits[..., i, :]であると解釈されます。 - シーケンス並列が有効の場合には、0-dim(バッチ次元)と 1-dim(シーケンス次元)が入れ替わって logits に保持されているため、最初に transpose を実施しています。
- 変更前→変更後で「インデクスずらし」のロジックが追加されているように見えますが、変更前のソースでは
self.loss_functionの内部でインデクスずらしが実行されているため、実質は同じです。
-
-
-
-
セルフアテンション層
LlamaAttentionをテンソル並列可能な形に変更してください。-
LlamaAttention層についての基礎知識- セルフアテンション層の基本構造についての知識は仮定しますが、その上で、
LlamaAttentionの特徴を補足します:-
LlamaAttentionでは Grouped Query Attention (GQA) がサポートされています。GQA とは、複数の Query ヘッドに対して Key/Value ヘッドをまとめて共有させることで、計算量やメモリ使用量を削減する仕組みです。通常の Multi-Head Attention では Q/K/V のヘッド数が同じですが、GQA では Q の表現力を保ちながら K/V の数を減らします。これにより、推論時の KV キャッシュの負担を大幅に軽減しつつ、性能をほぼ維持できます。- 例えば Qヘッド数が 32、KVヘッド数が 8 の場合、
Q0, ..., Q31とK0, ..., K7とV0, ..., V7が計算されますが、(Q0, K0, V0), (Q1, K0, V0), (Q2, K0, V0), (Q3, K0, V0), (Q4, K1, V1), (Q5, K1, V1), …, (Q30, K7, V7), (Q31, K7, V7)という32組からそれぞれの O が計算されます。
- 例えば Qヘッド数が 32、KVヘッド数が 8 の場合、
-
- セルフアテンション層の基本構造についての知識は仮定しますが、その上で、
-
LlamaAttentionは、大きく分けると以下の4パートに分かれます:- 入力 (hidden_states) から Q, K, V を計算するパート
- Q, K に RoPE (Rotary Position Embedding) を適用するパート
- Q, K, V から O を計算するパート
- O から出力(次の hidden_states)を計算するパート
変更の必要な点が非常に多いですが、1つずつ解説します。
- (i) 入力 (hidden_states) から Q, K, V を計算するパート
- Q, K, V を計算する 3 つの線形層の代わりに、
GQAQKVColumnParallelLinear層を使用します。-
__init__内-
変更前
self.q_proj = nn.Linear( config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias ) self.k_proj = nn.Linear( config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias ) self.v_proj = nn.Linear( config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias ) -
変更後(
from neuronx_distributed.modules.qkv_linear import GQAQKVColumnParallelLinearが必要)self.qkv_proj = GQAQKVColumnParallelLinear( config.hidden_size, [config.num_attention_heads * self.head_dim, config.num_key_value_heads * self.head_dim], bias=config.attention_bias, gather_output=False, kv_size_multiplier=self.config.kv_shared_group_size, fuse_qkv=self.config.fuse_qkv, sequence_parallel_enabled=self.config.sequence_parallel_enabled ) -
解説
-
GQAQKVColumnParallelLinear層(公式Docs)(ソース)は、入力から Q, K, V を線形変換により計算しますが、そのパラメータテンソルは「出力次元」でTPサイズに分割されて各デバイスに分散して保持されます。- 第二引数には
[q_projの出力次元, k_proj(v_proj)の出力次元]を指定します。 - 出力は
(query_states, key_states, value_states)の tuple で返ってきます。それぞれのテンソルの軸順は[バッチサイズ, シーケンス長, ヘッド数*ヘッド次元]です。 -
kv_size_multiplierに、先述のKV_REPLICATORの値(i.e. KVヘッドの重みを何倍に複製して保持するか)を指定します。 -
fuse_qkv(デフォルト: True) が Trueの場合、Q, K, V を計算する線形変換のパラメータは結合された状態で保持されます(パラメータ名は(weight|bias)_qkvとなります)。Falseの場合は、Q, K, Vごとに別々のパラメータとして保持されます(パラメータ名は(weight|bias)_(q|k|v)。前者の方が計算効率が良いです。
- 第二引数には
-
-
-
__init__内、変数定義変更-
変更前
self.num_heads = config.num_attention_heads self.num_key_value_heads = config.num_key_value_heads self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads -
変更後(要
import neuronx_distributed.parallel_layers.utils as neuronx_dist_utils)self.num_heads = neuronx_dist_utils.divide(config.num_attention_heads, get_tensor_model_parallel_size()) self.num_key_value_heads = neuronx_dist_utils.divide( config.num_key_value_heads * self.config.kv_shared_group_size, get_tensor_model_parallel_size() ) self.num_key_value_groups = self.num_heads // self.num_key_value_heads -
解説
-
self.num_(heads|key_value_heads|key_value_groups)はそれぞれ「注意ヘッド数(Qヘッド数)」「KVヘッド数」「KVグループ数」ですが、それぞれを「TP で割った後の値」に変更しています。これらはforwardの処理で参照します。- ヘルパー関数
neuronx_dist_utils.divide(x, y)(ソース)は、x // yと等価ですが、x % y != 0の場合に例外を送出する点のみ異なります。
- ヘルパー関数
-
-
-
forward内-
変更前
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) -
変更後
bsz, q_len, _ = hidden_states.size() if self.config.sequence_parallel_enabled: q_len, bsz, _ = hidden_states.size() q_len = q_len * get_tensor_model_parallel_size() query_states, key_states, value_states = self.qkv_proj(hidden_states) query_states, key_states, value_states, seq_len_dim_index = self.permute_qkv_for_attn( query_states, key_states, value_states, bsz, q_len, self.num_heads, self.num_key_value_heads, self.head_dim, self.config )さらに以下ヘルパー関数を LlamaAttention 内に定義します
def reshape_and_permute_states_for_fa(self, states, bsz, q_len, num_heads, head_dim, use_sequence_parallel): if use_sequence_parallel: return states.view(q_len, bsz, num_heads, head_dim).permute(1, 2, 3, 0) else: return states.view(bsz, q_len, num_heads, head_dim).permute(0, 2, 3, 1) def permute_qkv_for_attn( self, query_states, key_states, value_states, bsz, q_len, num_heads, num_key_value_heads, head_dim, config ): if config.transpose_nki_inputs and config.use_flash_attention: query_states = self.reshape_and_permute_states_for_fa(query_states, bsz, q_len, num_heads, head_dim, config.sequence_parallel_enabled) key_states = self.reshape_and_permute_states_for_fa(key_states, bsz, q_len, num_key_value_heads, head_dim, config.sequence_parallel_enabled) value_states = self.reshape_and_permute_states_for_fa(value_states, bsz, q_len, num_key_value_heads, head_dim, config.sequence_parallel_enabled) dim_index = -1 elif config.sequence_parallel_enabled: query_states = query_states.view(q_len, bsz, num_heads, head_dim).permute(1, 2, 0, 3) key_states = key_states.view(q_len, bsz, num_key_value_heads, head_dim).permute(1, 2, 0, 3) value_states = value_states.view(q_len, bsz, num_key_value_heads, head_dim).permute(1, 2, 0, 3) dim_index = -2 else: query_states = query_states.view(bsz, q_len, num_heads, head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, num_key_value_heads, head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, num_key_value_heads, head_dim).transpose(1, 2) dim_index = -2 return query_states, key_states, value_states, dim_index -
解説
-
GQAQKVColumnParallelLinear層から出力された(query|key|value)_statesの軸順を、Attention のコア部分(Q, K, V から O を計算する部分)を計算しやすい順番に変更しています。具体的には、以下の軸順に変更しています:- config のフラグ
transpose_nki_inputsとuse_flash_attentionの両方が True である場合:[バッチサイズ, ヘッド数, ヘッド次元, シーケンス長](dim_index = -2) - 上記以外:
[バッチサイズ, ヘッド数, シーケンス長, ヘッド次元](dim_index = -1)
- config のフラグ
-
-
-
- Q, K, V を計算する 3 つの線形層の代わりに、
- (ii) Q, K に RoPE (Rotary Position Embedding) を適用するパート
-
インポート
-
変更前
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb -
変更後
from neuronx_distributed.overrides.transformer_overrides import apply_rotary_pos_emb
-
-
forward内-
変更前
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) -
変更後
query_states, key_states = apply_rotary_pos_emb( query_states, key_states, cos, sin, None, self.config.use_flash_attention, self.config.transpose_nki_inputs )
-
-
解説
-
apply_rotary_pos_embはRoPEを計算する関数ですが、これをHF版からNxD版に差し替えています。NxD版の実装は、use_flash_attention==Trueかつtranspose_nki_inputs==Trueであるケース(query_state,key_statesの軸順が前述の通り入れ替わっている)に対応している点以外はHF版と同じです。
-
-
- (iii) Q, K, V から O を計算するパート
- まず K, V について、
KV_REPLICATORを反映する(すなわち、複製を行って「Q ヘッド数」に揃える)必要があります。-
追加
# repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) -
解説
-
repeat_kv関数は、シェイプが(バッチサイズ, KVヘッド数, シーケンス長, ヘッド次元)のテンソルを受け取り、(バッチサイズ, KVヘッド数 * 複製数, シーケンス長, ヘッド次元)にして返します(1-dim が元々[A, B, C, D]となっている場合、[A, A, A, B, B, B, C, C, C, D, D, D]となって返ってきます)。 - (なお、HF版ではこのKV複製処理は次項の「変更前」コードの
eager_attention_forward関数内で行われています)
-
-
- いよいよコア部分である Q, K, V → O の計算です。これを GPU で実施する際には、GPUの特性に依存した Flash Attention という手法がデファクトで使用され、素朴に計算するよりもメモリ使用量を著しく削減できます。これの Neuron チップ版が
nki_flash_attn_funcとして実装されています。-
変更前
attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): logger.warning_once( "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' ) else: attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, query_states, key_states, value_states, attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, **kwargs, ) -
変更後(要
from neuronx_distributed.kernels.flash_attn import nki_flash_attn_func)attn_output = nki_flash_attn_func(query_states, key_states, value_states, self.config.lnc, transpose_nki_inputs=self.config.transpose_nki_inputs) if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): raise ValueError( f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" f" {attn_output.size()}" ) -
解説
-
nki_flash_attn_func(ソース)は、Neuron チップ向けの Flash Attention 実装です。Neuron チップの低レイヤーな制御が可能な、NKI (Neuron Kernel Interface) によって実装されています。 - 入力される
(query|key|value)_statesの軸順は以下である必要があります:-
transpose_nki_inputs==Trueの場合:[バッチサイズ, ヘッド数, ヘッド次元, シーケンス長] -
transpose_nki_inputs==Falseの場合:[バッチサイズ, ヘッド数, シーケンス長, ヘッド次元]
-
- 出力される
attn_outputの軸順は、いつでも[バッチサイズ, ヘッド数, シーケンス長, ヘッド次元]となります。
-
-
- まず K, V について、
- (iv) O から出力(次の hidden_states)を計算するパート
-
__init__内-
変更前
self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) -
変更後
self.o_proj = RowParallelLinear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias, input_is_parallel=True, sequence_parallel_enabled=self.config.sequence_parallel_enabled )
-
-
forward内、o_proj層に通す直前の reshape-
変更前
attn_output = attn_output.reshape(*input_shape, -1).contiguous() -
変更後
if self.config.sequence_parallel_enabled: attn_output = attn_output.permute(2, 0, 1, 3) attn_output = attn_output.reshape(q_len, bsz, self.hidden_size // get_tensor_model_parallel_size()) else: attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.hidden_size // get_tensor_model_parallel_size()) -
解説
-
nki_flash_attn_funcからの出力テンソルを、通常のアクティベーションの軸順(シーケンス並列有りの場合は[シーケンス長, バッチサイズ, ヘッド数*ヘッド次元]、無しの場合は[バッチサイズ, シーケンス長, ヘッド数*ヘッド次元])になるようにtransposeとreshapeを実施しています。(TP並列はまだ合流していないため、隠れ次元が TP サイズで割られていることに注意してください。この後のo_projで合流します)
-
-
-
-
-
「AOTコンパイル」を再実行してください。
- もしエラーが出た場合、デバッグが必要です。一般的なアドバイスを以下に記します。
- デバッグのTips
- 冗長なログファイルからエラーの根本原因を見つけるコツ
- ログはしばしば数千行〜数万行となり、さらに並列で走る多数のプロセスからのログが混じり合うため、ログ行の順番もしばしば乱れます。エラーで落ちた場合に、エラーの原因を表すログ行をスムーズに特定することが重要となります。
- 最も冒頭に登場する
errorという文字列を検索:非常にシンプルな方法ですが、多くの場合これでエラーの原因に到達できます。- 上記ではっきりしない場合、冒頭に現れる
Error:やError |を検索すると多くのケースでエラーの原因に到達することができます。(前者は Python Error を、後者は Neuron コンパイラからのエラーを引っ掛けられます)
- 上記ではっきりしない場合、冒頭に現れる
- ログ全体をLLMに読ませてエラーの原因を調べさせる:非常に有効です。全体だと行数が多すぎて読ませられない場合は、上記方法で行範囲を絞ってからLLMに読ませると良いでしょう。
- 最も冒頭に登場する
- ログはしばしば数千行〜数万行となり、さらに並列で走る多数のプロセスからのログが混じり合うため、ログ行の順番もしばしば乱れます。エラーで落ちた場合に、エラーの原因を表すログ行をスムーズに特定することが重要となります。
-
printの代わりにxm.master_print-
print関数でデバッグ情報を出力させようとすると、プログラムの並列実行数だけ出力が重複してしまい、視認性が悪くなります。xm.master_print(要import torch_xla.core.xla_model as xm)を使用すると、マスターのプロセスからのみ出力が行われるため、デバッグログがすっきりします。
-
- 「問題のある層」の特定方法
- コンパイラが出すエラーからは、どの層に問題があるかは直ちにはわかりません。
- 層と層の間にprint関数を挟みまくるような方法では、問題のある層を特定できません。
- 問題のある層を特定するには「層の短絡」がしばしば有効です。例えば、
LlamaModelの中でLlamaDecoderLayerを繰り返すforループを丸ごとスキップするように改変してみます。これでエラーが解消されれば、原因はLlamaDecoderLayerの中にあることがわかりますし、エラーが解消されない場合は、最初のエンベディング、あるいは最後の線形層やloss計算に問題があることがわかります。
- テンソルを覗き見する際の注意
- 変数に格納されているテンソルの値を確認するために、デバッグ目的で
print(あるいはxm.master_print)すると、第1章で解説した通り、その時点で遅延評価が走ってしまいます。デバッグコードの挿入前と挿入後とでコンパイルのタイミングが変化することで、エラー内容が変化してしまう場合があります。 - 一方、テンソルの
.shape,.dtype,.device等だけの確認であれば、遅延評価は走らず、上記のような問題は生じません。
- 変数に格納されているテンソルの値を確認するために、デバッグ目的で
- 冗長なログファイルからエラーの根本原因を見つけるコツ
- よくあるコンパイラのエラーメッセージ
-
Estimated peak HBM usage (18.179819) exceeds 16GB. Neff won't be able to load on chip- 「それぞれのNeuronコアにモデルが載りきらない」ことを表しています。メモリ使用量を減らす(最大シーケンス長やバッチサイズを小さく設定する、同アーキテクチャのより小規模なモデルに変更する)、モデル並列度を高める(TPやPPを大きくする)等を行う必要があります。
-
Couldn't color the DRAM even with 100GB of DRAM space assumption, model needs too much HBM memory !- 上記同様、モデルが載りきらないことを表しています。
-
DRAM usage for Internal DRAM tensor exceeds 16GB of device space limit, cannot fit into device, model requires too much HBM memory !- 上記同様、モデルが載りきらないことを表しています。
-
Internal tensorizer error: VectorizeDMA:Illegal after shrink dst!- 特定の変数に格納されるテンソルのシェイプが、処理のたびに可変になっているようなケースで、上記エラーが生じる場合があります。
-
RuntimeError: (1, 32768, 1, 80) and (1, 32768, 1, 80)- シェイプの等しいテンソルAとBについて
A * Bを計算する際に、それらの shape が揃っておらず計算不可能なとき、上記のようなエラーメッセージが表示されます。ただし、このエラーメッセージは誤っており、本来ならば「Aのshape」と「Bのshape」を並べるべきところに、「Aのshape」を2回並べてしまっていることに留意する必要があります。
- シェイプの等しいテンソルAとBについて
-
CCOM WARN No transport found between devices 8 and 7. Possible replica group misconfiguration- コア間の通信に問題がある場合に表示されるエラーです。Neuronチップ内のコア間が全結合ではないことに起因して、TPサイズ・PPサイズ・KV_REPLICATORの組み合わせによっては、本エラーが生じる場合があります。
-
ERROR: Unsupported operation: mhlo.set_dimension_size- 現状、Neuron コンパイラは「torch.tensor に対して可能な全ての(最小単位の)操作」に対応している訳ではありません。Unsupported operation で始まるエラーが生じた場合は、対応していない操作が含まれています。この場合、対応している操作で代替する必要があります。
-
Number of instructions (8146371) is over the threshold (5000000). - Compile under --optlevel=1 to create smaller subgraphs or use pipeline parallelism.- サイズの巨大なテンソルを載せようとすると上記エラーが出る場合があります。
-
RuntimeError: Bad StatusOr access: RESOURCE_EXHAUSTED: LoadCollectives: error condition NRT_RESOURCE == rt_status:- Neuron SDKのバージョンをアップデートすることで、本エラーが解消する場合があります。
-
- デバッグのTips
- もしエラーが出た場合、デバッグが必要です。一般的なアドバイスを以下に記します。
-
「学習」を再実行してください。
最後に、チェックポイントの読み込みに対応しましょう。
-
チェックポイントの読み込みを行う設定に戻して、再度 AOT コンパイル・学習を実行してください。
-
実はこのままでは、チェックポイント読み込み時に以下のPythonエラーが生じます:
RuntimeError: Missing keys when loading state dictionary: model.layers.0.mlp.gate_proj.weight, model.layers.0.mlp.up_proj.weight, model.layers.1.mlp.gate_proj.weight, ...(中略)... model.layers.1.mlp.up_proj.weight,model.layers.30.mlp.gate_proj.weight, model.layers.30.mlp.up_proj.weight, model.layers.31.mlp.gate_proj.weight, model.layers.31.mlp.up_proj.weight -
これは、HF→NxDのチェックポイント変換処理が「MLP 層内の
gate_proj線形層とup_proj線形層はまとめてgate_up_proj線形層とする」仕様となっているにも関わらず、上記で移植実装を行った時にそれが考慮されていないためです。その結果として、モデル定義にgate_projやup_proj層が存在するにも関わらず、読み込もうとしているチェックポイントファイルにはそれらのデータが存在しないため、上記エラー文言となります)。 -
以下のように修正を行います:
-
LlamaMLPの__init__内を以下のように変更します:-
変更前
self.gate_proj = ColumnParallelLinear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias, gather_output=False, sequence_parallel_enabled=self.config.sequence_parallel_enabled) self.up_proj = ColumnParallelLinear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias, gather_output=False, sequence_parallel_enabled=self.config.sequence_parallel_enabled) self.act_fn = ACT2FN[config.hidden_act] -
変更後
self.gate_up_proj = ColumnParallelLinear( self.hidden_size, 2 * self.intermediate_size, stride=2, bias=config.mlp_bias, gather_output=False, sequence_parallel_enabled=self.config.sequence_parallel_enabled, ) self.activation_multiply = ActivationMultiplyMLP(config) -
解説
-
ColumnParallelLinearのstrideパラメータ(デフォルト: 1)が設定されていると、パラメータテンソルのTP分割のされ方が変化します。具体的には、分割方向の次元全体を「stride 個の区間」に等分した後、それぞれを TP 個に等分します(結果として、それぞれのデバイスが担当するインデクスは stride個の区間に分かれます)。 - stride=2 と設定したことにより、各デバイスに格納される出テンソルは、その前半が
gate_projの計算結果となり、その後半はup_projの計算結果となります。
-
-
-
追加で以下を定義します:
class ActivationMultiplyMLP(torch.nn.Module): def __init__(self, config): nn.Module.__init__(self) self.act_fn = ACT2FN[config.hidden_act] self.split_size = config.intermediate_size // get_tensor_model_parallel_size() def forward(self, x): gate_proj, up_proj = x.split(self.split_size, dim=2) intermediate_states = self.act_fn(gate_proj) * up_proj return intermediate_states -
LlamaMLPのforward内を以下のように修正します:-
変更前
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) -
変更後
intermediate_states = self.activation_multiply(self.gate_up_proj(x)) down_proj = self.down_proj(intermediate_states)
-
-
-
-
最後に、移植が正常に行えているかどうかを確認するため、テスト用の入力系列を1つ固定して、それを「移植前のモデル」と「移植後のモデル」の
forwardに通してみて、同じ logits が得られることを確認してください。- evalモードで評価することに留意してください。
- 数値誤差があるため、完全に一致するとは限りません。ただし、明らかに異なる logits が出てしまう場合は、途中のどこかの層の移植に問題があると考えられます。途中の層時点での hidden_states を確認するなどして、デバッグを実施してください。
ここまで実行できれば、モデルの基本的な移植方法は一通り抑えられているはずです。
-
本50本ノックの内容を監修くださった AWS の常世様に、感謝を申し上げます。 ↩︎
Discussion