Overview
https://huggingface.co/docs/transformers/model_doc/auto#transformers.AutoModelForCausalLM.from_pretrained
We'll read through the code behind AutoModelForCausalLM.from_pretrained.
Example
This assumes the Swallow-7b model is being loaded.
import torch
from transformers import AutoModelForCausalLM

model_name = "tokyotech-llm/Swallow-7b-instruct-hf"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, device_map="auto")
Key methods
AutoModelForCausalLM
Inherits from _BaseAutoModelClass
https://github.com/huggingface/transformers/blob/ce4fff0be7f6464d713f7ac3e0bbaafbc6959ae5/src/transformers/models/auto/modeling_auto.py#L1336
_BaseAutoModelClass.from_pretrained
This is where everything starts
https://github.com/huggingface/transformers/blob/ce4fff0be7f6464d713f7ac3e0bbaafbc6959ae5/src/transformers/models/auto/auto_factory.py#L442-L443
Fetching the config file
If no config is passed, the default config.json of the repo is used
https://github.com/huggingface/transformers/blob/7d312ad2e9473cd3a0ea3e9b206b8ed3c147e9be/src/transformers/models/auto/auto_factory.py#L482-L489
https://huggingface.co/tokyotech-llm/Swallow-7b-instruct-hf/blob/main/config.json
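As a rough sketch of this step: when no config is given, the auto factory fetches config.json from the Hub and builds the matching config class, which you can reproduce by hand with AutoConfig:

```python
from transformers import AutoConfig

# Roughly what happens internally when `config` is not passed:
# config.json is downloaded and turned into the matching config class.
config = AutoConfig.from_pretrained("tokyotech-llm/Swallow-7b-instruct-hf")
print(type(config).__name__)  # LlamaConfig
```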
When a PEFT model is specified
If the specified model or directory is a PEFT checkpoint, it is detected and handled accordingly.
https://github.com/huggingface/transformers/blob/b8b16475d41b66ab0e1fe9d1cb82bbff65e5f6d6/src/transformers/models/auto/auto_factory.py#L500-L509
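A hedged sketch of this branch: if the path points at a PEFT adapter (i.e. it contains adapter_config.json), from_pretrained resolves the base model from the adapter config and loads the adapter on top, provided peft is installed. The repo name below is hypothetical:

```python
from transformers import AutoModelForCausalLM

# Hypothetical adapter repo containing adapter_config.json.
# from_pretrained would load the base model referenced by the adapter config
# and then attach the adapter weights (requires the `peft` package).
model = AutoModelForCausalLM.from_pretrained("your-org/swallow-7b-lora-adapter", device_map="auto")
```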
Fetching the model class
https://github.com/huggingface/transformers/blob/b8b16475d41b66ab0e1fe9d1cb82bbff65e5f6d6/src/transformers/models/auto/auto_factory.py#L560
model_class.from_pretrained
Instantiating the model
https://github.com/huggingface/transformers/blob/b8b16475d41b66ab0e1fe9d1cb82bbff65e5f6d6/src/transformers/models/auto/auto_factory.py#L561-L563
_get_model_class
Fetches the model class
https://github.com/huggingface/transformers/blob/b8b16475d41b66ab0e1fe9d1cb82bbff65e5f6d6/src/transformers/models/auto/auto_factory.py#L380
The mapping lives here
https://github.com/huggingface/transformers/blob/ce4fff0be7f6464d713f7ac3e0bbaafbc6959ae5/src/transformers/models/auto/modeling_auto.py#L406
In this case, LlamaForCausalLM is used
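The lookup can be reproduced by hand; a small sketch (the mapping object is internal to transformers, so treat this as illustrative):

```python
from transformers import AutoConfig
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING

config = AutoConfig.from_pretrained("tokyotech-llm/Swallow-7b-instruct-hf")
model_class = MODEL_FOR_CAUSAL_LM_MAPPING[type(config)]
print(model_class)  # <class 'transformers.models.llama.modeling_llama.LlamaForCausalLM'>
```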
LlamaForCausalLM
Class definition
Inherits from PreTrainedModel
https://github.com/huggingface/transformers/blob/ce4fff0be7f6464d713f7ac3e0bbaafbc6959ae5/src/transformers/models/llama/modeling_llama.py#L1070
PreTrainedModel.from_pretrained
https://github.com/huggingface/transformers/blob/1c81132e80478e278681686fe44dfec793d5dee9/src/transformers/modeling_utils.py#L2578-L2592
Fetching the config file
We saw something similar in AutoModelForCausalLM
https://github.com/huggingface/transformers/blob/1c81132e80478e278681686fe44dfec793d5dee9/src/transformers/modeling_utils.py#L2872-L2888
When a PEFT model is specified
We also saw this in AutoModelForCausalLM
https://github.com/huggingface/transformers/blob/1c81132e80478e278681686fe44dfec793d5dee9/src/transformers/modeling_utils.py#L2893-L2910
Setting up device_map
There are several ways to specify device_map
https://github.com/huggingface/transformers/blob/1c81132e80478e278681686fe44dfec793d5dee9/src/transformers/modeling_utils.py#L2914-L2931
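For reference, device_map accepts either a string strategy or an explicit module-to-device mapping (requires accelerate); a sketch:

```python
from transformers import AutoModelForCausalLM

model_name = "tokyotech-llm/Swallow-7b-instruct-hf"

# String strategies are resolved by accelerate: "auto", "balanced", "sequential", ...
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")

# An explicit dict from module name to device also works, e.g. (illustrative names):
# device_map = {"model.embed_tokens": 0, "model.layers": 0, "model.norm": 0, "lm_head": 0}
```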
Quantization-related settings
https://github.com/huggingface/transformers/blob/1c81132e80478e278681686fe44dfec793d5dee9/src/transformers/modeling_utils.py#L2950-L2966
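This is where a quantization_config passed by the caller gets picked up. A hedged example of passing one (requires bitsandbytes; not used in the Swallow example above):

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
    "tokyotech-llm/Swallow-7b-instruct-hf",
    quantization_config=bnb_config,
    device_map="auto",
)
```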
Fetching the config
https://github.com/huggingface/transformers/blob/1c81132e80478e278681686fe44dfec793d5dee9/src/transformers/modeling_utils.py#L2981-L2995
Setting up AutoHfQuantizer
Sets up the quantizer when quantization is requested
https://github.com/huggingface/transformers/blob/1c81132e80478e278681686fe44dfec793d5dee9/src/transformers/modeling_utils.py#L3013-L3019
Branching on pretrained_model_name_or_path
In this case we fall into the else branch
https://github.com/huggingface/transformers/blob/1c81132e80478e278681686fe44dfec793d5dee9/src/transformers/modeling_utils.py#L3137
resolved_archive_file
In the end, this holds the path to the following file
https://huggingface.co/tokyotech-llm/Swallow-7b-instruct-hf/blob/main/model.safetensors.index.json
Downloading the weights
The weight shards are downloaded based on model.safetensors.index.json
https://github.com/huggingface/transformers/blob/1c81132e80478e278681686fe44dfec793d5dee9/src/transformers/modeling_utils.py#L3264-L3277
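A small sketch of what the index file contains and how the shard list can be derived from it with huggingface_hub (shard file names depend on the repo):

```python
import json
from huggingface_hub import hf_hub_download

repo = "tokyotech-llm/Swallow-7b-instruct-hf"
index_path = hf_hub_download(repo, "model.safetensors.index.json")
with open(index_path) as f:
    index = json.load(f)

# weight_map maps each parameter name to the shard file that stores it.
shard_files = sorted(set(index["weight_map"].values()))
print(shard_files)  # e.g. ['model-00001-of-....safetensors', ...]
```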
Setting the torch dtype
https://github.com/huggingface/transformers/blob/1c81132e80478e278681686fe44dfec793d5dee9/src/transformers/modeling_utils.py#L3315-L3338
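In the example we passed torch_dtype=torch.bfloat16 explicitly; with torch_dtype="auto" this code would instead pick up the dtype recorded in the config. A quick check:

```python
import torch
from transformers import AutoConfig

config = AutoConfig.from_pretrained("tokyotech-llm/Swallow-7b-instruct-hf")
# "torch_dtype": "bfloat16" in config.json is turned into an actual torch.dtype.
print(config.torch_dtype)  # torch.bfloat16 -- what torch_dtype="auto" would resolve to
```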
DeepSpeed ZeRO-3 settings
https://github.com/huggingface/transformers/blob/1c81132e80478e278681686fe44dfec793d5dee9/src/transformers/modeling_utils.py#L3360-L3364
flash_attention settings
https://github.com/huggingface/transformers/blob/1c81132e80478e278681686fe44dfec793d5dee9/src/transformers/modeling_utils.py#L3369-L3371
Instantiating the model
https://github.com/huggingface/transformers/blob/1c81132e80478e278681686fe44dfec793d5dee9/src/transformers/modeling_utils.py#L3373-L3375
config
LlamaConfig {
"_name_or_path": "tokyotech-llm/Swallow-7b-instruct-hf",
"architectures": [
"LlamaForCausalLM"
],
"attention_bias": false,
"bos_token_id": 1,
"eos_token_id": 2,
"hidden_act": "silu",
"hidden_size": 4096,
"initializer_range": 0.02,
"intermediate_size": 11008,
"max_position_embeddings": 4096,
"max_sequence_length": 4096,
"model_type": "llama",
"num_attention_heads": 32,
"num_hidden_layers": 32,
"num_key_value_heads": 32,
"pad_token_id": 0,
"pretraining_tp": 1,
"rms_norm_eps": 1e-05,
"rope_scaling": null,
"rope_theta": 10000.0,
"tie_word_embeddings": false,
"torch_dtype": "bfloat16",
"transformers_version": "4.35.2",
"use_cache": true,
"vocab_size": 43176
}
model_args
model_kwargs
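This step builds the model from the config alone (the pretrained weights come later). You can do the same thing directly; a sketch that skips the checkpoint download and yields randomly initialized weights:

```python
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("tokyotech-llm/Swallow-7b-instruct-hf")
# Builds LlamaForCausalLM from the config only; no pretrained weights are loaded.
model = AutoModelForCausalLM.from_config(config)
print(type(model).__name__)  # LlamaForCausalLM
```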
Quantization settings
https://github.com/huggingface/transformers/blob/1c81132e80478e278681686fe44dfec793d5dee9/src/transformers/modeling_utils.py#L3389-L3391
Setting up device_map
https://github.com/huggingface/transformers/blob/1c81132e80478e278681686fe44dfec793d5dee9/src/transformers/modeling_utils.py#L3399-L3458
Loading the weights into memory
https://github.com/huggingface/transformers/blob/1c81132e80478e278681686fe44dfec793d5dee9/src/transformers/modeling_utils.py#L3495-L3518
Dispatching according to device_map
https://github.com/huggingface/transformers/blob/1c81132e80478e278681686fe44dfec793d5dee9/src/transformers/modeling_utils.py#L3558
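Under the hood this relies on accelerate; a hedged, self-contained sketch of the equivalent call (the device_map here is deliberately trivial):

```python
from accelerate import dispatch_model
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("tokyotech-llm/Swallow-7b-instruct-hf")
model = AutoModelForCausalLM.from_config(config)  # random weights, just for illustration

# Put every module on GPU 0; the device_map computed above would be more fine-grained.
model = dispatch_model(model, device_map={"": 0})
```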
Applying quantization
https://github.com/huggingface/transformers/blob/1c81132e80478e278681686fe44dfec793d5dee9/src/transformers/modeling_utils.py#L3561
Applying PEFT
https://github.com/huggingface/transformers/blob/1c81132e80478e278681686fe44dfec793d5dee9/src/transformers/modeling_utils.py#L3565-L3570
LlamaForCausalLM.__init__
https://github.com/huggingface/transformers/blob/ce4fff0be7f6464d713f7ac3e0bbaafbc6959ae5/src/transformers/models/llama/modeling_llama.py#L1073
Calls LlamaModel.__init__
https://github.com/huggingface/transformers/blob/ce4fff0be7f6464d713f7ac3e0bbaafbc6959ae5/src/transformers/models/llama/modeling_llama.py#L1075
LlamaModel.__init__
https://github.com/huggingface/transformers/blob/ce4fff0be7f6464d713f7ac3e0bbaafbc6959ae5/src/transformers/models/llama/modeling_llama.py#L891
nn.Embedding
https://github.com/huggingface/transformers/blob/ce4fff0be7f6464d713f7ac3e0bbaafbc6959ae5/src/transformers/models/llama/modeling_llama.py#L896
(config.vocab_size, config.hidden_size, self.padding_idx)
For details on nn.Embedding, see:
https://qiita.com/typecprint/items/35c4cc9e73da49695f2b
sum(p.numel() for p in self.embed_tokens.parameters())
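Plugging in the Swallow-7b config values from above, the embedding table and its parameter count look like this:

```python
import torch.nn as nn

# config.vocab_size=43176, config.hidden_size=4096, padding_idx=pad_token_id=0
embed_tokens = nn.Embedding(43176, 4096, padding_idx=0)
print(sum(p.numel() for p in embed_tokens.parameters()))  # 176848896 (= 43176 * 4096)
```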
nn.ModuleList
https://github.com/huggingface/transformers/blob/ce4fff0be7f6464d713f7ac3e0bbaafbc6959ae5/src/transformers/models/llama/modeling_llama.py#L897-L899
sum(p.numel() for p in self.layers.parameters())
LlamaRMSNorm
https://github.com/huggingface/transformers/blob/ce4fff0be7f6464d713f7ac3e0bbaafbc6959ae5/src/transformers/models/llama/modeling_llama.py#L900
(config.hidden_size, config.rms_norm_eps)
sum(p.numel() for p in self.norm.parameters())
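RMSNorm itself is tiny: it only owns a weight vector of length hidden_size (4096 parameters here). A simplified sketch of what LlamaRMSNorm computes (the real implementation also handles dtype casting):

```python
import torch
import torch.nn as nn

class SimpleRMSNorm(nn.Module):
    # Simplified sketch of LlamaRMSNorm: scale by the inverse RMS, then a learned weight.
    def __init__(self, hidden_size: int, eps: float = 1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        variance = x.pow(2).mean(-1, keepdim=True)
        return self.weight * x * torch.rsqrt(variance + self.eps)

norm = SimpleRMSNorm(4096)
print(sum(p.numel() for p in norm.parameters()))  # 4096
```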
Setting up the causal_mask
Registers the matrix used for masking as a buffer
https://github.com/huggingface/transformers/blob/ce4fff0be7f6464d713f7ac3e0bbaafbc6959ae5/src/transformers/models/llama/modeling_llama.py#L903-L905
(config.max_position_embeddings, config.max_position_embeddings)
torch.triu(causal_mask, diagonal=1)
tensor([[0, 1, 1, ..., 1, 1, 1],
[0, 0, 1, ..., 1, 1, 1],
[0, 0, 0, ..., 1, 1, 1],
...,
[0, 0, 0, ..., 0, 1, 1],
[0, 0, 0, ..., 0, 0, 1],
[0, 0, 0, ..., 0, 0, 0]])
torch.triu(causal_mask, diagonal=1).size()
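A small reproduction of the idea behind this buffer, using n=6 instead of max_position_embeddings=4096 so the output is readable (the exact fill value and dtype in the library may differ):

```python
import torch

n = 6  # the real buffer is (config.max_position_embeddings, config.max_position_embeddings) = (4096, 4096)
causal_mask = torch.ones(n, n, dtype=torch.int64)
mask = torch.triu(causal_mask, diagonal=1)
print(mask)
# Positions marked 1 (strictly above the diagonal) are the ones attention must not look at.
print(mask.size())  # torch.Size([6, 6])
```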
post_init
Initializes the weights
https://github.com/huggingface/transformers/blob/ce4fff0be7f6464d713f7ac3e0bbaafbc6959ae5/src/transformers/models/llama/modeling_llama.py#L907
LlamaDecoderLayer
https://github.com/huggingface/transformers/blob/ce4fff0be7f6464d713f7ac3e0bbaafbc6959ae5/src/transformers/models/llama/modeling_llama.py#L671
self_attn
https://github.com/huggingface/transformers/blob/ce4fff0be7f6464d713f7ac3e0bbaafbc6959ae5/src/transformers/models/llama/modeling_llama.py#L675
LLAMA_ATTENTION_CLASSES[config._attn_implementation]
<class 'transformers.models.llama.modeling_llama.LlamaAttention'>
mlp
https://github.com/huggingface/transformers/blob/ce4fff0be7f6464d713f7ac3e0bbaafbc6959ae5/src/transformers/models/llama/modeling_llama.py#L677
input_layernorm
https://github.com/huggingface/transformers/blob/ce4fff0be7f6464d713f7ac3e0bbaafbc6959ae5/src/transformers/models/llama/modeling_llama.py#L678
post_attention_layernorm
https://github.com/huggingface/transformers/blob/ce4fff0be7f6464d713f7ac3e0bbaafbc6959ae5/src/transformers/models/llama/modeling_llama.py#L679
LlamaAttention
https://github.com/huggingface/transformers/blob/ce4fff0be7f6464d713f7ac3e0bbaafbc6959ae5/src/transformers/models/llama/modeling_llama.py#L238
The Q/K/V/O projections
https://github.com/huggingface/transformers/blob/ce4fff0be7f6464d713f7ac3e0bbaafbc6959ae5/src/transformers/models/llama/modeling_llama.py#L265-L268
(q_proj): Linear(in_features=4096, out_features=4096, bias=False)
(k_proj): Linear(in_features=4096, out_features=4096, bias=False)
(v_proj): Linear(in_features=4096, out_features=4096, bias=False)
(o_proj): Linear(in_features=4096, out_features=4096, bias=False)
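With hidden_size=4096 and num_attention_heads=32, each head gets head_dim=128 and all four projections are 4096x4096 with no bias. A small sketch of how a query projection output gets split into heads:

```python
import torch
import torch.nn as nn

hidden_size, num_heads = 4096, 32
head_dim = hidden_size // num_heads  # 128

q_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=False)

x = torch.randn(1, 10, hidden_size)            # (batch, seq_len, hidden)
q = q_proj(x)                                  # (1, 10, 4096)
q = q.view(1, 10, num_heads, head_dim).transpose(1, 2)
print(q.shape)                                 # torch.Size([1, 32, 10, 128])
```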
Initializing LlamaRotaryEmbedding
https://github.com/huggingface/transformers/blob/ce4fff0be7f6464d713f7ac3e0bbaafbc6959ae5/src/transformers/models/llama/modeling_llama.py#L269
_load_pretrained_model
https://github.com/huggingface/transformers/blob/594c1277b2fcc1c1aed252d320359101409e0407/src/transformers/modeling_utils.py#L3584-L3602
load_state_dict
Loads the weights into memory
https://github.com/huggingface/transformers/blob/594c1277b2fcc1c1aed252d320359101409e0407/src/transformers/modeling_utils.py#L3903
Loads the weights into the model
https://github.com/huggingface/transformers/blob/594c1277b2fcc1c1aed252d320359101409e0407/src/transformers/modeling_utils.py#L3926-L3943
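A hedged sketch of the core of this step: reading a safetensors shard into a state dict and copying it into the module. The shard file name below is hypothetical, and the real code handles sharding, device placement, and dtype far more carefully:

```python
from safetensors.torch import load_file

# Hypothetical local shard file downloaded in the step above.
state_dict = load_file("model-00001-of-00003.safetensors")
print(list(state_dict.keys())[:3])  # e.g. ['model.embed_tokens.weight', ...]

# `model` is the LlamaForCausalLM instance built earlier; strict=False because
# a single shard only contains part of the parameters.
missing, unexpected = model.load_state_dict(state_dict, strict=False)
```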
Summary
I got a rough feel for the overall flow, but it was too long for me to follow every line of the code.