
Reading the AutoModelForCausalLM.from_pretrained code

Published on 2024/02/21

Overview

https://huggingface.co/docs/transformers/model_doc/auto#transformers.AutoModelForCausalLM.from_pretrained

This post walks through the code of AutoModelForCausalLM.from_pretrained.

Example

We assume a run that loads the Swallow-7b model.

import torch
from transformers import AutoModelForCausalLM

model_name = "tokyotech-llm/Swallow-7b-instruct-hf"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, device_map="auto")

Key methods

AutoModelForCausalLM

Inherits from _BaseAutoModelClass
https://github.com/huggingface/transformers/blob/ce4fff0be7f6464d713f7ac3e0bbaafbc6959ae5/src/transformers/models/auto/modeling_auto.py#L1336

_BaseAutoModelClass.from_pretrained

Execution starts here
https://github.com/huggingface/transformers/blob/ce4fff0be7f6464d713f7ac3e0bbaafbc6959ae5/src/transformers/models/auto/auto_factory.py#L442-L443

Fetching the config file

If no config is passed in, the default file is used
https://github.com/huggingface/transformers/blob/7d312ad2e9473cd3a0ea3e9b206b8ed3c147e9be/src/transformers/models/auto/auto_factory.py#L482-L489

https://huggingface.co/tokyotech-llm/Swallow-7b-instruct-hf/blob/main/config.json
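
This step can be reproduced on its own with AutoConfig; a minimal sketch, not the exact internal call path:

from transformers import AutoConfig

config = AutoConfig.from_pretrained("tokyotech-llm/Swallow-7b-instruct-hf")
print(type(config))  # <class 'transformers.models.llama.configuration_llama.LlamaConfig'>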

When a PEFT model is specified

If the specified model or directory is a PEFT checkpoint, it is handled accordingly, as sketched below.
https://github.com/huggingface/transformers/blob/b8b16475d41b66ab0e1fe9d1cb82bbff65e5f6d6/src/transformers/models/auto/auto_factory.py#L500-L509
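
A sketch of that case, using a hypothetical adapter repo name; when the target contains adapter_config.json, from_pretrained loads the base model and attaches the adapter:

# "your-user/swallow-7b-lora" is a hypothetical PEFT adapter repo (requires peft installed)
model = AutoModelForCausalLM.from_pretrained("your-user/swallow-7b-lora")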

Fetching the model class

https://github.com/huggingface/transformers/blob/b8b16475d41b66ab0e1fe9d1cb82bbff65e5f6d6/src/transformers/models/auto/auto_factory.py#L560

model_class.from_pretrained

Instantiates the model
https://github.com/huggingface/transformers/blob/b8b16475d41b66ab0e1fe9d1cb82bbff65e5f6d6/src/transformers/models/auto/auto_factory.py#L561-L563

_get_model_class

Retrieves the model class
https://github.com/huggingface/transformers/blob/b8b16475d41b66ab0e1fe9d1cb82bbff65e5f6d6/src/transformers/models/auto/auto_factory.py#L380

The mapping lives here
https://github.com/huggingface/transformers/blob/ce4fff0be7f6464d713f7ac3e0bbaafbc6959ae5/src/transformers/models/auto/modeling_auto.py#L406

In this run, LlamaForCausalLM is used
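
The mapping is keyed by config.model_type; you can peek at it directly (assuming this internal mapping name, which may change between versions):

from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES

print(MODEL_FOR_CAUSAL_LM_MAPPING_NAMES["llama"])  # 'LlamaForCausalLM'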

LlamaForCausalLM

Class definition
Inherits from PreTrainedModel
https://github.com/huggingface/transformers/blob/ce4fff0be7f6464d713f7ac3e0bbaafbc6959ae5/src/transformers/models/llama/modeling_llama.py#L1070

PreTrainedModel.from_pretrained

https://github.com/huggingface/transformers/blob/1c81132e80478e278681686fe44dfec793d5dee9/src/transformers/modeling_utils.py#L2578-L2592

Fetching the config file

I think we saw something similar in AutoModelForCausalLM
https://github.com/huggingface/transformers/blob/1c81132e80478e278681686fe44dfec793d5dee9/src/transformers/modeling_utils.py#L2872-L2888

When a PEFT model is specified

We saw this in AutoModelForCausalLM as well
https://github.com/huggingface/transformers/blob/1c81132e80478e278681686fe44dfec793d5dee9/src/transformers/modeling_utils.py#L2893-L2910

Setting up device_map

The various ways of specifying device_map are handled here
https://github.com/huggingface/transformers/blob/1c81132e80478e278681686fe44dfec793d5dee9/src/transformers/modeling_utils.py#L2914-L2931
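
device_map accepts either a strategy string or an explicit mapping; a couple of illustrative forms:

# strategy string: let accelerate place the layers automatically
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")

# explicit dict: put the entire model on GPU 0
model = AutoModelForCausalLM.from_pretrained(model_name, device_map={"": 0})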

Quantization-related settings

https://github.com/huggingface/transformers/blob/1c81132e80478e278681686fe44dfec793d5dee9/src/transformers/modeling_utils.py#L2950-L2966
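
Quantized loading is requested via quantization_config; a sketch using bitsandbytes 4-bit, which is not used in this article's run:

from transformers import BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(load_in_4bit=True)
model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=bnb_config)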

Fetching the config

https://github.com/huggingface/transformers/blob/1c81132e80478e278681686fe44dfec793d5dee9/src/transformers/modeling_utils.py#L2981-L2995

Setting up AutoHfQuantizer

Quantization setup
https://github.com/huggingface/transformers/blob/1c81132e80478e278681686fe44dfec793d5dee9/src/transformers/modeling_utils.py#L3013-L3019

Branching on pretrained_model_name_or_path

In this run we fall into the else branch
https://github.com/huggingface/transformers/blob/1c81132e80478e278681686fe44dfec793d5dee9/src/transformers/modeling_utils.py#L3137

resolved_archive_file

In the end, this holds the path to the following file
https://huggingface.co/tokyotech-llm/Swallow-7b-instruct-hf/blob/main/model.safetensors.index.json

Downloading the weights

The weight files are downloaded based on model.safetensors.index.json
https://github.com/huggingface/transformers/blob/1c81132e80478e278681686fe44dfec793d5dee9/src/transformers/modeling_utils.py#L3264-L3277
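
The index file maps each tensor name to the shard that contains it. A sketch of fetching the index and listing the shards with huggingface_hub directly, rather than the internal helper:

import json
from huggingface_hub import hf_hub_download

index_path = hf_hub_download("tokyotech-llm/Swallow-7b-instruct-hf", "model.safetensors.index.json")
with open(index_path) as f:
    index = json.load(f)

shard_files = sorted(set(index["weight_map"].values()))  # shard filenames to download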

Setting the torch dtype

https://github.com/huggingface/transformers/blob/1c81132e80478e278681686fe44dfec793d5dee9/src/transformers/modeling_utils.py#L3315-L3338
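
torch_dtype can be a torch.dtype or the string "auto", which falls back to config.torch_dtype (bfloat16 for this model), so both of these load the weights in bfloat16:

model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto")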

DeepSpeed ZeRO-3 settings

https://github.com/huggingface/transformers/blob/1c81132e80478e278681686fe44dfec793d5dee9/src/transformers/modeling_utils.py#L3360-L3364

Setting up flash_attention

https://github.com/huggingface/transformers/blob/1c81132e80478e278681686fe44dfec793d5dee9/src/transformers/modeling_utils.py#L3369-L3371
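
FlashAttention can be requested explicitly via attn_implementation (requires the flash-attn package); a sketch:

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)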

Instantiating the model

https://github.com/huggingface/transformers/blob/1c81132e80478e278681686fe44dfec793d5dee9/src/transformers/modeling_utils.py#L3373-L3375

config

LlamaConfig {
  "_name_or_path": "tokyotech-llm/Swallow-7b-instruct-hf",
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 11008,
  "max_position_embeddings": 4096,
  "max_sequence_length": 4096,
  "model_type": "llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "num_key_value_heads": 32,
  "pad_token_id": 0,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": null,
  "rope_theta": 10000.0,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.35.2",
  "use_cache": true,
  "vocab_size": 43176
}

model_args

()

model_kwargs

{}
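
Ignoring the low_cpu_mem_usage / meta-device machinery, this instantiation step is roughly equivalent to constructing the class directly from the config:

from transformers import LlamaForCausalLM

model = LlamaForCausalLM(config)  # weights stay randomly initialized until the state dict is loaded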

Quantization setup

https://github.com/huggingface/transformers/blob/1c81132e80478e278681686fe44dfec793d5dee9/src/transformers/modeling_utils.py#L3389-L3391

Setting up device_map

https://github.com/huggingface/transformers/blob/1c81132e80478e278681686fe44dfec793d5dee9/src/transformers/modeling_utils.py#L3399-L3458

Loading the weights into memory

https://github.com/huggingface/transformers/blob/1c81132e80478e278681686fe44dfec793d5dee9/src/transformers/modeling_utils.py#L3495-L3518

Dispatching according to device_map

https://github.com/huggingface/transformers/blob/1c81132e80478e278681686fe44dfec793d5dee9/src/transformers/modeling_utils.py#L3558
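
This hands off to accelerate; a minimal sketch of the equivalent call, assuming device_map has already been resolved:

from accelerate import dispatch_model

model = dispatch_model(model, device_map=device_map)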

Applying quantization

https://github.com/huggingface/transformers/blob/1c81132e80478e278681686fe44dfec793d5dee9/src/transformers/modeling_utils.py#L3561

Applying PEFT

https://github.com/huggingface/transformers/blob/1c81132e80478e278681686fe44dfec793d5dee9/src/transformers/modeling_utils.py#L3565-L3570

LlamaForCausalLM.__init__

https://github.com/huggingface/transformers/blob/ce4fff0be7f6464d713f7ac3e0bbaafbc6959ae5/src/transformers/models/llama/modeling_llama.py#L1073

Call to LlamaModel.__init__

https://github.com/huggingface/transformers/blob/ce4fff0be7f6464d713f7ac3e0bbaafbc6959ae5/src/transformers/models/llama/modeling_llama.py#L1075

LlamaModel.__init__

https://github.com/huggingface/transformers/blob/ce4fff0be7f6464d713f7ac3e0bbaafbc6959ae5/src/transformers/models/llama/modeling_llama.py#L891

nn.Embedding

https://github.com/huggingface/transformers/blob/ce4fff0be7f6464d713f7ac3e0bbaafbc6959ae5/src/transformers/models/llama/modeling_llama.py#L896

(config.vocab_size, config.hidden_size, self.padding_idx)
(43176, 4096, 0)

For background on nn.Embedding, see:
https://qiita.com/typecprint/items/35c4cc9e73da49695f2b

sum(p.numel() for p in self.embed_tokens.parameters())
176848896
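
The count is simply vocab_size * hidden_size; a standalone check:

import torch.nn as nn

emb = nn.Embedding(43176, 4096, padding_idx=0)
sum(p.numel() for p in emb.parameters())  # 176848896 == 43176 * 4096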

nn.ModuleList

https://github.com/huggingface/transformers/blob/ce4fff0be7f6464d713f7ac3e0bbaafbc6959ae5/src/transformers/models/llama/modeling_llama.py#L897-L899

config.num_hidden_layers
32
sum(p.numel() for p in self.layers.parameters())
6476267520
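
The per-layer breakdown is consistent with this total:

# q/k/v/o projections: 4 * 4096 * 4096  =  67,108,864
# MLP gate/up/down:    3 * 4096 * 11008 = 135,266,304
# 2 RMSNorm weights:   2 * 4096         =       8,192
# per layer:                              202,383,360
# 32 layers:           32 * 202,383,360 = 6,476,267,520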

LlamaRMSNorm

https://github.com/huggingface/transformers/blob/ce4fff0be7f6464d713f7ac3e0bbaafbc6959ae5/src/transformers/models/llama/modeling_llama.py#L900

(config.hidden_size, config.rms_norm_eps)
(4096, 1e-05)
sum(p.numel() for p in self.norm.parameters())
4096
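
LlamaRMSNorm holds a single weight vector of size hidden_size, hence the 4096 parameters. A simplified sketch of what it computes, with the dtype handling omitted:

import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # normalize by the root-mean-square over the hidden dimension
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states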

Setting up causal_mask

Registers the mask matrix as a buffer

https://github.com/huggingface/transformers/blob/ce4fff0be7f6464d713f7ac3e0bbaafbc6959ae5/src/transformers/models/llama/modeling_llama.py#L903-L905

(config.max_position_embeddings, config.max_position_embeddings)
(4096, 4096)
causal_mask = torch.full((4096, 4096), fill_value=1)
torch.triu(causal_mask, diagonal=1)
tensor([[0, 1, 1,  ..., 1, 1, 1],
        [0, 0, 1,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1],
        ...,
        [0, 0, 0,  ..., 0, 1, 1],
        [0, 0, 0,  ..., 0, 0, 1],
        [0, 0, 0,  ..., 0, 0, 0]])
torch.triu(causal_mask, diagonal=1).size()
torch.Size([4096, 4096])

post_init

This initializes the weights
https://github.com/huggingface/transformers/blob/ce4fff0be7f6464d713f7ac3e0bbaafbc6959ae5/src/transformers/models/llama/modeling_llama.py#L907

LlamaDecoderLayer

https://github.com/huggingface/transformers/blob/ce4fff0be7f6464d713f7ac3e0bbaafbc6959ae5/src/transformers/models/llama/modeling_llama.py#L671

self_attn

https://github.com/huggingface/transformers/blob/ce4fff0be7f6464d713f7ac3e0bbaafbc6959ae5/src/transformers/models/llama/modeling_llama.py#L675

LLAMA_ATTENTION_CLASSES[config._attn_implementation]
<class 'transformers.models.llama.modeling_llama.LlamaAttention'>

mlp

https://github.com/huggingface/transformers/blob/ce4fff0be7f6464d713f7ac3e0bbaafbc6959ae5/src/transformers/models/llama/modeling_llama.py#L677

input_layernorm

https://github.com/huggingface/transformers/blob/ce4fff0be7f6464d713f7ac3e0bbaafbc6959ae5/src/transformers/models/llama/modeling_llama.py#L678

post_attention_layernorm

https://github.com/huggingface/transformers/blob/ce4fff0be7f6464d713f7ac3e0bbaafbc6959ae5/src/transformers/models/llama/modeling_llama.py#L679

LlamaAttention

https://github.com/huggingface/transformers/blob/ce4fff0be7f6464d713f7ac3e0bbaafbc6959ae5/src/transformers/models/llama/modeling_llama.py#L238

QKVO

https://github.com/huggingface/transformers/blob/ce4fff0be7f6464d713f7ac3e0bbaafbc6959ae5/src/transformers/models/llama/modeling_llama.py#L265-L268

  (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
  (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
  (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
  (o_proj): Linear(in_features=4096, out_features=4096, bias=False)

Initialization of LlamaRotaryEmbedding

https://github.com/huggingface/transformers/blob/ce4fff0be7f6464d713f7ac3e0bbaafbc6959ae5/src/transformers/models/llama/modeling_llama.py#L269

_load_pretrained_model

https://github.com/huggingface/transformers/blob/594c1277b2fcc1c1aed252d320359101409e0407/src/transformers/modeling_utils.py#L3584-L3602

load_state_dict

Loads the weights into memory
https://github.com/huggingface/transformers/blob/594c1277b2fcc1c1aed252d320359101409e0407/src/transformers/modeling_utils.py#L3903

Loading the weights into the model

https://github.com/huggingface/transformers/blob/594c1277b2fcc1c1aed252d320359101409e0407/src/transformers/modeling_utils.py#L3926-L3943
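
Conceptually, each safetensors shard is read into a state dict and copied into the module parameters; a simplified sketch (the shard filename is illustrative):

from safetensors.torch import load_file

state_dict = load_file("model-00001-of-00003.safetensors")  # illustrative shard name
missing, unexpected = model.load_state_dict(state_dict, strict=False)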

Summary

I got a rough feel for the overall flow, but there was too much code to trace it all the way through.
