OpenCALM-7Bのコードリーディング(基本編)
概要
CyberAgentさんが公開してくれているLLMモデルであるOpenCALMを動かして単純な質問をした際、どの様なコードが動いているのか読んでみたいと思います。
実行例
google colab proで実行します
※無料版ではメモリが足りずに動きませんでした
ライブラリのインストール
!pip install torch transformers accelerate
コード
以下は
に記載されているサンプルコードになります。import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("cyberagent/open-calm-7b", device_map="auto", torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained("cyberagent/open-calm-7b")
inputs = tokenizer("AIによって私達の暮らしは、", return_tensors="pt").to(model.device)
with torch.no_grad():
tokens = model.generate(
**inputs,
max_new_tokens=64,
do_sample=True,
temperature=0.7,
top_p=0.9,
repetition_penalty=1.05,
pad_token_id=tokenizer.pad_token_id,
)
output = tokenizer.decode(tokens[0], skip_special_tokens=True)
print(output)
以下の様な出力がされます。
※途中で切れているように見えるのはmax_new_tokens=64のためです。
AIによって私達の暮らしは、日々大きく変化しています。しかし、人工知能やロボット技術が発達しても、人は人の手で物を創り出したい欲求を捨てられないのではないでしょうか?そして、それはきっとこの先も変わらないと思います。
そのような背景から、当社では製品の企画・開発・販売という一連の流れを、一貫して手掛けることで、よりスピーディーに製品をお客さまにお届け
コードを読む
AutoModelForCausalLM.from_pretrained
transformersを使ってモデルを生成するコードについてみていきます。
transformersとはhuggingface(機械学習モデルのgithubみたいなサイト)からモデルをダウンロードするためのライブラリです。
model = AutoModelForCausalLM.from_pretrained("cyberagent/open-calm-7b", device_map="auto", torch_dtype=torch.float16)
コードを読む前にモデルのサマリーを出力してみます。
from torchinfo import summary
summary(model=model, depth=100)
結果
==========================================================================
Layer (type:depth-idx) Param #
================================================================================
GPTNeoXForCausalLM --
├─GPTNeoXModel: 1-1 --
│ └─Embedding: 2-1 213,909,504
│ └─Dropout: 2-2 --
│ └─ModuleList: 2-3 --
│ │ └─GPTNeoXLayer: 3-1 --
│ │ │ └─LayerNorm: 4-1 8,192
│ │ │ └─LayerNorm: 4-2 8,192
│ │ │ └─Dropout: 4-3 --
│ │ │ └─Dropout: 4-4 --
│ │ │ └─GPTNeoXAttention: 4-5 --
│ │ │ │ └─GPTNeoXRotaryEmbedding: 5-1 --
│ │ │ │ └─Linear: 5-2 50,343,936
│ │ │ │ └─Linear: 5-3 16,781,312
│ │ │ │ └─Dropout: 5-4 --
│ │ │ └─GPTNeoXMLP: 4-6 --
│ │ │ │ └─Linear: 5-5 67,125,248
│ │ │ │ └─Linear: 5-6 67,112,960
│ │ │ │ └─GELUActivation: 5-7 --
│ │ └─GPTNeoXLayer: 3-2 --
│ │ │ └─LayerNorm: 4-7 8,192
│ │ │ └─LayerNorm: 4-8 8,192
│ │ │ └─Dropout: 4-9 --
│ │ │ └─Dropout: 4-10 --
│ │ │ └─GPTNeoXAttention: 4-11 --
│ │ │ │ └─GPTNeoXRotaryEmbedding: 5-8 --
│ │ │ │ └─Linear: 5-9 50,343,936
│ │ │ │ └─Linear: 5-10 16,781,312
│ │ │ │ └─Dropout: 5-11 --
│ │ │ └─GPTNeoXMLP: 4-12 --
│ │ │ │ └─Linear: 5-12 67,125,248
│ │ │ │ └─Linear: 5-13 67,112,960
│ │ │ │ └─GELUActivation: 5-14 --
│ │ └─GPTNeoXLayer: 3-3 --
│ │ │ └─LayerNorm: 4-13 8,192
│ │ │ └─LayerNorm: 4-14 8,192
│ │ │ └─Dropout: 4-15 --
│ │ │ └─Dropout: 4-16 --
│ │ │ └─GPTNeoXAttention: 4-17 --
│ │ │ │ └─GPTNeoXRotaryEmbedding: 5-15 --
│ │ │ │ └─Linear: 5-16 50,343,936
│ │ │ │ └─Linear: 5-17 16,781,312
│ │ │ │ └─Dropout: 5-18 --
│ │ │ └─GPTNeoXMLP: 4-18 --
│ │ │ │ └─Linear: 5-19 67,125,248
│ │ │ │ └─Linear: 5-20 67,112,960
│ │ │ │ └─GELUActivation: 5-21 --
│ │ └─GPTNeoXLayer: 3-4 --
│ │ │ └─LayerNorm: 4-19 8,192
│ │ │ └─LayerNorm: 4-20 8,192
│ │ │ └─Dropout: 4-21 --
│ │ │ └─Dropout: 4-22 --
│ │ │ └─GPTNeoXAttention: 4-23 --
│ │ │ │ └─GPTNeoXRotaryEmbedding: 5-22 --
│ │ │ │ └─Linear: 5-23 50,343,936
│ │ │ │ └─Linear: 5-24 16,781,312
│ │ │ │ └─Dropout: 5-25 --
│ │ │ └─GPTNeoXMLP: 4-24 --
│ │ │ │ └─Linear: 5-26 67,125,248
│ │ │ │ └─Linear: 5-27 67,112,960
│ │ │ │ └─GELUActivation: 5-28 --
│ │ └─GPTNeoXLayer: 3-5 --
│ │ │ └─LayerNorm: 4-25 8,192
│ │ │ └─LayerNorm: 4-26 8,192
│ │ │ └─Dropout: 4-27 --
│ │ │ └─Dropout: 4-28 --
│ │ │ └─GPTNeoXAttention: 4-29 --
│ │ │ │ └─GPTNeoXRotaryEmbedding: 5-29 --
│ │ │ │ └─Linear: 5-30 50,343,936
│ │ │ │ └─Linear: 5-31 16,781,312
│ │ │ │ └─Dropout: 5-32 --
│ │ │ └─GPTNeoXMLP: 4-30 --
│ │ │ │ └─Linear: 5-33 67,125,248
│ │ │ │ └─Linear: 5-34 67,112,960
│ │ │ │ └─GELUActivation: 5-35 --
│ │ └─GPTNeoXLayer: 3-6 --
│ │ │ └─LayerNorm: 4-31 8,192
│ │ │ └─LayerNorm: 4-32 8,192
│ │ │ └─Dropout: 4-33 --
│ │ │ └─Dropout: 4-34 --
│ │ │ └─GPTNeoXAttention: 4-35 --
│ │ │ │ └─GPTNeoXRotaryEmbedding: 5-36 --
│ │ │ │ └─Linear: 5-37 50,343,936
│ │ │ │ └─Linear: 5-38 16,781,312
│ │ │ │ └─Dropout: 5-39 --
│ │ │ └─GPTNeoXMLP: 4-36 --
│ │ │ │ └─Linear: 5-40 67,125,248
│ │ │ │ └─Linear: 5-41 67,112,960
│ │ │ │ └─GELUActivation: 5-42 --
│ │ └─GPTNeoXLayer: 3-7 --
│ │ │ └─LayerNorm: 4-37 8,192
│ │ │ └─LayerNorm: 4-38 8,192
│ │ │ └─Dropout: 4-39 --
│ │ │ └─Dropout: 4-40 --
│ │ │ └─GPTNeoXAttention: 4-41 --
│ │ │ │ └─GPTNeoXRotaryEmbedding: 5-43 --
│ │ │ │ └─Linear: 5-44 50,343,936
│ │ │ │ └─Linear: 5-45 16,781,312
│ │ │ │ └─Dropout: 5-46 --
│ │ │ └─GPTNeoXMLP: 4-42 --
│ │ │ │ └─Linear: 5-47 67,125,248
│ │ │ │ └─Linear: 5-48 67,112,960
│ │ │ │ └─GELUActivation: 5-49 --
│ │ └─GPTNeoXLayer: 3-8 --
│ │ │ └─LayerNorm: 4-43 8,192
│ │ │ └─LayerNorm: 4-44 8,192
│ │ │ └─Dropout: 4-45 --
│ │ │ └─Dropout: 4-46 --
│ │ │ └─GPTNeoXAttention: 4-47 --
│ │ │ │ └─GPTNeoXRotaryEmbedding: 5-50 --
│ │ │ │ └─Linear: 5-51 50,343,936
│ │ │ │ └─Linear: 5-52 16,781,312
│ │ │ │ └─Dropout: 5-53 --
│ │ │ └─GPTNeoXMLP: 4-48 --
│ │ │ │ └─Linear: 5-54 67,125,248
│ │ │ │ └─Linear: 5-55 67,112,960
│ │ │ │ └─GELUActivation: 5-56 --
│ │ └─GPTNeoXLayer: 3-9 --
│ │ │ └─LayerNorm: 4-49 8,192
│ │ │ └─LayerNorm: 4-50 8,192
│ │ │ └─Dropout: 4-51 --
│ │ │ └─Dropout: 4-52 --
│ │ │ └─GPTNeoXAttention: 4-53 --
│ │ │ │ └─GPTNeoXRotaryEmbedding: 5-57 --
│ │ │ │ └─Linear: 5-58 50,343,936
│ │ │ │ └─Linear: 5-59 16,781,312
│ │ │ │ └─Dropout: 5-60 --
│ │ │ └─GPTNeoXMLP: 4-54 --
│ │ │ │ └─Linear: 5-61 67,125,248
│ │ │ │ └─Linear: 5-62 67,112,960
│ │ │ │ └─GELUActivation: 5-63 --
│ │ └─GPTNeoXLayer: 3-10 --
│ │ │ └─LayerNorm: 4-55 8,192
│ │ │ └─LayerNorm: 4-56 8,192
│ │ │ └─Dropout: 4-57 --
│ │ │ └─Dropout: 4-58 --
│ │ │ └─GPTNeoXAttention: 4-59 --
│ │ │ │ └─GPTNeoXRotaryEmbedding: 5-64 --
│ │ │ │ └─Linear: 5-65 50,343,936
│ │ │ │ └─Linear: 5-66 16,781,312
│ │ │ │ └─Dropout: 5-67 --
│ │ │ └─GPTNeoXMLP: 4-60 --
│ │ │ │ └─Linear: 5-68 67,125,248
│ │ │ │ └─Linear: 5-69 67,112,960
│ │ │ │ └─GELUActivation: 5-70 --
│ │ └─GPTNeoXLayer: 3-11 --
│ │ │ └─LayerNorm: 4-61 8,192
│ │ │ └─LayerNorm: 4-62 8,192
│ │ │ └─Dropout: 4-63 --
│ │ │ └─Dropout: 4-64 --
│ │ │ └─GPTNeoXAttention: 4-65 --
│ │ │ │ └─GPTNeoXRotaryEmbedding: 5-71 --
│ │ │ │ └─Linear: 5-72 50,343,936
│ │ │ │ └─Linear: 5-73 16,781,312
│ │ │ │ └─Dropout: 5-74 --
│ │ │ └─GPTNeoXMLP: 4-66 --
│ │ │ │ └─Linear: 5-75 67,125,248
│ │ │ │ └─Linear: 5-76 67,112,960
│ │ │ │ └─GELUActivation: 5-77 --
│ │ └─GPTNeoXLayer: 3-12 --
│ │ │ └─LayerNorm: 4-67 8,192
│ │ │ └─LayerNorm: 4-68 8,192
│ │ │ └─Dropout: 4-69 --
│ │ │ └─Dropout: 4-70 --
│ │ │ └─GPTNeoXAttention: 4-71 --
│ │ │ │ └─GPTNeoXRotaryEmbedding: 5-78 --
│ │ │ │ └─Linear: 5-79 50,343,936
│ │ │ │ └─Linear: 5-80 16,781,312
│ │ │ │ └─Dropout: 5-81 --
│ │ │ └─GPTNeoXMLP: 4-72 --
│ │ │ │ └─Linear: 5-82 67,125,248
│ │ │ │ └─Linear: 5-83 67,112,960
│ │ │ │ └─GELUActivation: 5-84 --
│ │ └─GPTNeoXLayer: 3-13 --
│ │ │ └─LayerNorm: 4-73 8,192
│ │ │ └─LayerNorm: 4-74 8,192
│ │ │ └─Dropout: 4-75 --
│ │ │ └─Dropout: 4-76 --
│ │ │ └─GPTNeoXAttention: 4-77 --
│ │ │ │ └─GPTNeoXRotaryEmbedding: 5-85 --
│ │ │ │ └─Linear: 5-86 50,343,936
│ │ │ │ └─Linear: 5-87 16,781,312
│ │ │ │ └─Dropout: 5-88 --
│ │ │ └─GPTNeoXMLP: 4-78 --
│ │ │ │ └─Linear: 5-89 67,125,248
│ │ │ │ └─Linear: 5-90 67,112,960
│ │ │ │ └─GELUActivation: 5-91 --
│ │ └─GPTNeoXLayer: 3-14 --
│ │ │ └─LayerNorm: 4-79 8,192
│ │ │ └─LayerNorm: 4-80 8,192
│ │ │ └─Dropout: 4-81 --
│ │ │ └─Dropout: 4-82 --
│ │ │ └─GPTNeoXAttention: 4-83 --
│ │ │ │ └─GPTNeoXRotaryEmbedding: 5-92 --
│ │ │ │ └─Linear: 5-93 50,343,936
│ │ │ │ └─Linear: 5-94 16,781,312
│ │ │ │ └─Dropout: 5-95 --
│ │ │ └─GPTNeoXMLP: 4-84 --
│ │ │ │ └─Linear: 5-96 67,125,248
│ │ │ │ └─Linear: 5-97 67,112,960
│ │ │ │ └─GELUActivation: 5-98 --
│ │ └─GPTNeoXLayer: 3-15 --
│ │ │ └─LayerNorm: 4-85 8,192
│ │ │ └─LayerNorm: 4-86 8,192
│ │ │ └─Dropout: 4-87 --
│ │ │ └─Dropout: 4-88 --
│ │ │ └─GPTNeoXAttention: 4-89 --
│ │ │ │ └─GPTNeoXRotaryEmbedding: 5-99 --
│ │ │ │ └─Linear: 5-100 50,343,936
│ │ │ │ └─Linear: 5-101 16,781,312
│ │ │ │ └─Dropout: 5-102 --
│ │ │ └─GPTNeoXMLP: 4-90 --
│ │ │ │ └─Linear: 5-103 67,125,248
│ │ │ │ └─Linear: 5-104 67,112,960
│ │ │ │ └─GELUActivation: 5-105 --
│ │ └─GPTNeoXLayer: 3-16 --
│ │ │ └─LayerNorm: 4-91 8,192
│ │ │ └─LayerNorm: 4-92 8,192
│ │ │ └─Dropout: 4-93 --
│ │ │ └─Dropout: 4-94 --
│ │ │ └─GPTNeoXAttention: 4-95 --
│ │ │ │ └─GPTNeoXRotaryEmbedding: 5-106 --
│ │ │ │ └─Linear: 5-107 50,343,936
│ │ │ │ └─Linear: 5-108 16,781,312
│ │ │ │ └─Dropout: 5-109 --
│ │ │ └─GPTNeoXMLP: 4-96 --
│ │ │ │ └─Linear: 5-110 67,125,248
│ │ │ │ └─Linear: 5-111 67,112,960
│ │ │ │ └─GELUActivation: 5-112 --
│ │ └─GPTNeoXLayer: 3-17 --
│ │ │ └─LayerNorm: 4-97 8,192
│ │ │ └─LayerNorm: 4-98 8,192
│ │ │ └─Dropout: 4-99 --
│ │ │ └─Dropout: 4-100 --
│ │ │ └─GPTNeoXAttention: 4-101 --
│ │ │ │ └─GPTNeoXRotaryEmbedding: 5-113 --
│ │ │ │ └─Linear: 5-114 50,343,936
│ │ │ │ └─Linear: 5-115 16,781,312
│ │ │ │ └─Dropout: 5-116 --
│ │ │ └─GPTNeoXMLP: 4-102 --
│ │ │ │ └─Linear: 5-117 67,125,248
│ │ │ │ └─Linear: 5-118 67,112,960
│ │ │ │ └─GELUActivation: 5-119 --
│ │ └─GPTNeoXLayer: 3-18 --
│ │ │ └─LayerNorm: 4-103 8,192
│ │ │ └─LayerNorm: 4-104 8,192
│ │ │ └─Dropout: 4-105 --
│ │ │ └─Dropout: 4-106 --
│ │ │ └─GPTNeoXAttention: 4-107 --
│ │ │ │ └─GPTNeoXRotaryEmbedding: 5-120 --
│ │ │ │ └─Linear: 5-121 50,343,936
│ │ │ │ └─Linear: 5-122 16,781,312
│ │ │ │ └─Dropout: 5-123 --
│ │ │ └─GPTNeoXMLP: 4-108 --
│ │ │ │ └─Linear: 5-124 67,125,248
│ │ │ │ └─Linear: 5-125 67,112,960
│ │ │ │ └─GELUActivation: 5-126 --
│ │ └─GPTNeoXLayer: 3-19 --
│ │ │ └─LayerNorm: 4-109 8,192
│ │ │ └─LayerNorm: 4-110 8,192
│ │ │ └─Dropout: 4-111 --
│ │ │ └─Dropout: 4-112 --
│ │ │ └─GPTNeoXAttention: 4-113 --
│ │ │ │ └─GPTNeoXRotaryEmbedding: 5-127 --
│ │ │ │ └─Linear: 5-128 50,343,936
│ │ │ │ └─Linear: 5-129 16,781,312
│ │ │ │ └─Dropout: 5-130 --
│ │ │ └─GPTNeoXMLP: 4-114 --
│ │ │ │ └─Linear: 5-131 67,125,248
│ │ │ │ └─Linear: 5-132 67,112,960
│ │ │ │ └─GELUActivation: 5-133 --
│ │ └─GPTNeoXLayer: 3-20 --
│ │ │ └─LayerNorm: 4-115 8,192
│ │ │ └─LayerNorm: 4-116 8,192
│ │ │ └─Dropout: 4-117 --
│ │ │ └─Dropout: 4-118 --
│ │ │ └─GPTNeoXAttention: 4-119 --
│ │ │ │ └─GPTNeoXRotaryEmbedding: 5-134 --
│ │ │ │ └─Linear: 5-135 50,343,936
│ │ │ │ └─Linear: 5-136 16,781,312
│ │ │ │ └─Dropout: 5-137 --
│ │ │ └─GPTNeoXMLP: 4-120 --
│ │ │ │ └─Linear: 5-138 67,125,248
│ │ │ │ └─Linear: 5-139 67,112,960
│ │ │ │ └─GELUActivation: 5-140 --
│ │ └─GPTNeoXLayer: 3-21 --
│ │ │ └─LayerNorm: 4-121 8,192
│ │ │ └─LayerNorm: 4-122 8,192
│ │ │ └─Dropout: 4-123 --
│ │ │ └─Dropout: 4-124 --
│ │ │ └─GPTNeoXAttention: 4-125 --
│ │ │ │ └─GPTNeoXRotaryEmbedding: 5-141 --
│ │ │ │ └─Linear: 5-142 50,343,936
│ │ │ │ └─Linear: 5-143 16,781,312
│ │ │ │ └─Dropout: 5-144 --
│ │ │ └─GPTNeoXMLP: 4-126 --
│ │ │ │ └─Linear: 5-145 67,125,248
│ │ │ │ └─Linear: 5-146 67,112,960
│ │ │ │ └─GELUActivation: 5-147 --
│ │ └─GPTNeoXLayer: 3-22 --
│ │ │ └─LayerNorm: 4-127 8,192
│ │ │ └─LayerNorm: 4-128 8,192
│ │ │ └─Dropout: 4-129 --
│ │ │ └─Dropout: 4-130 --
│ │ │ └─GPTNeoXAttention: 4-131 --
│ │ │ │ └─GPTNeoXRotaryEmbedding: 5-148 --
│ │ │ │ └─Linear: 5-149 50,343,936
│ │ │ │ └─Linear: 5-150 16,781,312
│ │ │ │ └─Dropout: 5-151 --
│ │ │ └─GPTNeoXMLP: 4-132 --
│ │ │ │ └─Linear: 5-152 67,125,248
│ │ │ │ └─Linear: 5-153 67,112,960
│ │ │ │ └─GELUActivation: 5-154 --
│ │ └─GPTNeoXLayer: 3-23 --
│ │ │ └─LayerNorm: 4-133 8,192
│ │ │ └─LayerNorm: 4-134 8,192
│ │ │ └─Dropout: 4-135 --
│ │ │ └─Dropout: 4-136 --
│ │ │ └─GPTNeoXAttention: 4-137 --
│ │ │ │ └─GPTNeoXRotaryEmbedding: 5-155 --
│ │ │ │ └─Linear: 5-156 50,343,936
│ │ │ │ └─Linear: 5-157 16,781,312
│ │ │ │ └─Dropout: 5-158 --
│ │ │ └─GPTNeoXMLP: 4-138 --
│ │ │ │ └─Linear: 5-159 67,125,248
│ │ │ │ └─Linear: 5-160 67,112,960
│ │ │ │ └─GELUActivation: 5-161 --
│ │ └─GPTNeoXLayer: 3-24 --
│ │ │ └─LayerNorm: 4-139 8,192
│ │ │ └─LayerNorm: 4-140 8,192
│ │ │ └─Dropout: 4-141 --
│ │ │ └─Dropout: 4-142 --
│ │ │ └─GPTNeoXAttention: 4-143 --
│ │ │ │ └─GPTNeoXRotaryEmbedding: 5-162 --
│ │ │ │ └─Linear: 5-163 50,343,936
│ │ │ │ └─Linear: 5-164 16,781,312
│ │ │ │ └─Dropout: 5-165 --
│ │ │ └─GPTNeoXMLP: 4-144 --
│ │ │ │ └─Linear: 5-166 67,125,248
│ │ │ │ └─Linear: 5-167 67,112,960
│ │ │ │ └─GELUActivation: 5-168 --
│ │ └─GPTNeoXLayer: 3-25 --
│ │ │ └─LayerNorm: 4-145 8,192
│ │ │ └─LayerNorm: 4-146 8,192
│ │ │ └─Dropout: 4-147 --
│ │ │ └─Dropout: 4-148 --
│ │ │ └─GPTNeoXAttention: 4-149 --
│ │ │ │ └─GPTNeoXRotaryEmbedding: 5-169 --
│ │ │ │ └─Linear: 5-170 50,343,936
│ │ │ │ └─Linear: 5-171 16,781,312
│ │ │ │ └─Dropout: 5-172 --
│ │ │ └─GPTNeoXMLP: 4-150 --
│ │ │ │ └─Linear: 5-173 67,125,248
│ │ │ │ └─Linear: 5-174 67,112,960
│ │ │ │ └─GELUActivation: 5-175 --
│ │ └─GPTNeoXLayer: 3-26 --
│ │ │ └─LayerNorm: 4-151 8,192
│ │ │ └─LayerNorm: 4-152 8,192
│ │ │ └─Dropout: 4-153 --
│ │ │ └─Dropout: 4-154 --
│ │ │ └─GPTNeoXAttention: 4-155 --
│ │ │ │ └─GPTNeoXRotaryEmbedding: 5-176 --
│ │ │ │ └─Linear: 5-177 50,343,936
│ │ │ │ └─Linear: 5-178 16,781,312
│ │ │ │ └─Dropout: 5-179 --
│ │ │ └─GPTNeoXMLP: 4-156 --
│ │ │ │ └─Linear: 5-180 67,125,248
│ │ │ │ └─Linear: 5-181 67,112,960
│ │ │ │ └─GELUActivation: 5-182 --
│ │ └─GPTNeoXLayer: 3-27 --
│ │ │ └─LayerNorm: 4-157 8,192
│ │ │ └─LayerNorm: 4-158 8,192
│ │ │ └─Dropout: 4-159 --
│ │ │ └─Dropout: 4-160 --
│ │ │ └─GPTNeoXAttention: 4-161 --
│ │ │ │ └─GPTNeoXRotaryEmbedding: 5-183 --
│ │ │ │ └─Linear: 5-184 50,343,936
│ │ │ │ └─Linear: 5-185 16,781,312
│ │ │ │ └─Dropout: 5-186 --
│ │ │ └─GPTNeoXMLP: 4-162 --
│ │ │ │ └─Linear: 5-187 67,125,248
│ │ │ │ └─Linear: 5-188 67,112,960
│ │ │ │ └─GELUActivation: 5-189 --
│ │ └─GPTNeoXLayer: 3-28 --
│ │ │ └─LayerNorm: 4-163 8,192
│ │ │ └─LayerNorm: 4-164 8,192
│ │ │ └─Dropout: 4-165 --
│ │ │ └─Dropout: 4-166 --
│ │ │ └─GPTNeoXAttention: 4-167 --
│ │ │ │ └─GPTNeoXRotaryEmbedding: 5-190 --
│ │ │ │ └─Linear: 5-191 50,343,936
│ │ │ │ └─Linear: 5-192 16,781,312
│ │ │ │ └─Dropout: 5-193 --
│ │ │ └─GPTNeoXMLP: 4-168 --
│ │ │ │ └─Linear: 5-194 67,125,248
│ │ │ │ └─Linear: 5-195 67,112,960
│ │ │ │ └─GELUActivation: 5-196 --
│ │ └─GPTNeoXLayer: 3-29 --
│ │ │ └─LayerNorm: 4-169 8,192
│ │ │ └─LayerNorm: 4-170 8,192
│ │ │ └─Dropout: 4-171 --
│ │ │ └─Dropout: 4-172 --
│ │ │ └─GPTNeoXAttention: 4-173 --
│ │ │ │ └─GPTNeoXRotaryEmbedding: 5-197 --
│ │ │ │ └─Linear: 5-198 50,343,936
│ │ │ │ └─Linear: 5-199 16,781,312
│ │ │ │ └─Dropout: 5-200 --
│ │ │ └─GPTNeoXMLP: 4-174 --
│ │ │ │ └─Linear: 5-201 67,125,248
│ │ │ │ └─Linear: 5-202 67,112,960
│ │ │ │ └─GELUActivation: 5-203 --
│ │ └─GPTNeoXLayer: 3-30 --
│ │ │ └─LayerNorm: 4-175 8,192
│ │ │ └─LayerNorm: 4-176 8,192
│ │ │ └─Dropout: 4-177 --
│ │ │ └─Dropout: 4-178 --
│ │ │ └─GPTNeoXAttention: 4-179 --
│ │ │ │ └─GPTNeoXRotaryEmbedding: 5-204 --
│ │ │ │ └─Linear: 5-205 50,343,936
│ │ │ │ └─Linear: 5-206 16,781,312
│ │ │ │ └─Dropout: 5-207 --
│ │ │ └─GPTNeoXMLP: 4-180 --
│ │ │ │ └─Linear: 5-208 67,125,248
│ │ │ │ └─Linear: 5-209 67,112,960
│ │ │ │ └─GELUActivation: 5-210 --
│ │ └─GPTNeoXLayer: 3-31 --
│ │ │ └─LayerNorm: 4-181 8,192
│ │ │ └─LayerNorm: 4-182 8,192
│ │ │ └─Dropout: 4-183 --
│ │ │ └─Dropout: 4-184 --
│ │ │ └─GPTNeoXAttention: 4-185 --
│ │ │ │ └─GPTNeoXRotaryEmbedding: 5-211 --
│ │ │ │ └─Linear: 5-212 50,343,936
│ │ │ │ └─Linear: 5-213 16,781,312
│ │ │ │ └─Dropout: 5-214 --
│ │ │ └─GPTNeoXMLP: 4-186 --
│ │ │ │ └─Linear: 5-215 67,125,248
│ │ │ │ └─Linear: 5-216 67,112,960
│ │ │ │ └─GELUActivation: 5-217 --
│ │ └─GPTNeoXLayer: 3-32 --
│ │ │ └─LayerNorm: 4-187 8,192
│ │ │ └─LayerNorm: 4-188 8,192
│ │ │ └─Dropout: 4-189 --
│ │ │ └─Dropout: 4-190 --
│ │ │ └─GPTNeoXAttention: 4-191 --
│ │ │ │ └─GPTNeoXRotaryEmbedding: 5-218 --
│ │ │ │ └─Linear: 5-219 50,343,936
│ │ │ │ └─Linear: 5-220 16,781,312
│ │ │ │ └─Dropout: 5-221 --
│ │ │ └─GPTNeoXMLP: 4-192 --
│ │ │ │ └─Linear: 5-222 67,125,248
│ │ │ │ └─Linear: 5-223 67,112,960
│ │ │ │ └─GELUActivation: 5-224 --
│ └─LayerNorm: 2-4 8,192
├─Linear: 1-2 213,909,504
================================================================================
Total params: 6,871,982,080
Trainable params: 6,871,982,080
Non-trainable params: 0
================================================================================
GPTNeoXForCausalLMというアーキテクチャが使われているようです。
68億のパラメータであることもわかります。
Total params: 6,871,982,080
AutoModelForCausalLMについて見ていきます。
auto_class_updateがクラスを生成しているようです。
サンプルコードで叩くのはここでセットしているクラスメソッドです。
cls.from_pretrained = classmethod(from_pretrained)
親クラスのメソッドをコピーしてから調整しているようです。
from_pretrained = copy_func(_BaseAutoModelClass.from_pretrained)
from_pretrainedは学習済みのモデルを返すので重要です。
pretrained_model_name_or_path
には"cyberagent/open-calm-7b"が入ります。
結果このブロックから学習済みモデルを返しています。
elif type(config) in cls._model_mapping.keys():
model_class = _get_model_class(config, cls._model_mapping)
return model_class.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
)
model_class.from_pretrained
は以下のメソッドになります。
長すぎて読みきれないです。
ただ結局返しているのはmodel(GPTNeoXForCausalLM)になります。
重みは
※load_state_dictはpytorchの仕組みでモデルに対して重みをロードするもの。
AutoTokenizer.from_pretrained
tokenizerを生成するコードについてみていきます。
tokenizerは文字列をTokenに分割し、それぞれのIDを返してくれるものになります。
tokenizer = AutoTokenizer.from_pretrained("cyberagent/open-calm-7b")
tokenizerをprintすると以下になります。
GPTNeoXTokenizerFast(name_or_path='cyberagent/open-calm-7b', vocab_size=52000, model_max_length=1000000000000000019884624838656, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>', 'pad_token': '<|padding|>'}, clean_up_tokenization_spaces=True)
文字列をトークン化すると
inputs = tokenizer("AIによって私達の暮らしは、", return_tensors="pt").to(model.device)
print(inputs)
以下のように変換されます。
{'input_ids': tensor([[ 4215, 930, 18030, 16205, 257, 245]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1]], device='cuda:0')}
AutoTokenizer.from_pretrainedを使用してtokenizerを取得すると、
以下のメソッドが実行されるようです。こちらはあまり掘り下げないでおきます。
with torch.no_grad():
こちらはpytorchの機能で、ブロック内では勾配の計算をしないというものになります。
学習しないときはこのブロックに入れることで処理が軽くなります。
model.generate
ここで入力されたトークンに対して答えとなるトークンを推論しています。
tokens = model.generate(
**inputs,
max_new_tokens=64,
do_sample=True,
temperature=0.7,
top_p=0.9,
repetition_penalty=1.05,
pad_token_id=tokenizer.pad_token_id,
)
普通はpytorchだと以下のように呼び出すのがお決まりなのですが、transformers経由だと少し勝手が違うみたいです。
outputs = model(inputs)
generateコードを見てみます。こちらも長いです。
最終的にこのブロックに入ります。
elif is_sample_gen_mode:
# 11. prepare logits warper
logits_warper = self._get_logits_warper(generation_config)
# 12. expand input_ids with `num_return_sequences` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
input_ids=input_ids,
expand_size=generation_config.num_return_sequences,
is_encoder_decoder=self.config.is_encoder_decoder,
**model_kwargs,
)
# 13. run sample
return self.sample(
input_ids,
logits_processor=logits_processor,
logits_warper=logits_warper,
stopping_criteria=stopping_criteria,
pad_token_id=generation_config.pad_token_id,
eos_token_id=generation_config.eos_token_id,
output_scores=generation_config.output_scores,
return_dict_in_generate=generation_config.return_dict_in_generate,
synced_gpus=synced_gpus,
streamer=streamer,
**model_kwargs,
)
実際のモデルに推論させているのはここです。最大64回通ります。
max_new_tokens=64に対応しているのだと思います。それより短い場合もありそうです。
tokenizer.decode
最後は推論されたトークンのIDリストを日本語に変換して終了です。
output = tokenizer.decode(tokens[0], skip_special_tokens=True)
まとめ
transformersのコードは結構マッチョで読みづらかったです。
今後はGPTNeoXForCausalLMの中身、
fine-turning時にどのように動作しているのか等を調査したいと思います。
Discussion