🐣

OpenCALM-7Bのコードリーディング(基本編)

2023/08/02に公開

概要

CyberAgentさんが公開してくれているLLMモデルであるOpenCALMを動かして単純な質問をした際、どの様なコードが動いているのか読んでみたいと思います。

実行例

google colab proで実行します
※無料版ではメモリが足りずに動きませんでした

ライブラリのインストール

!pip install torch transformers accelerate

コード

以下は
https://huggingface.co/cyberagent/open-calm-7b
に記載されているサンプルコードになります。

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("cyberagent/open-calm-7b", device_map="auto", torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained("cyberagent/open-calm-7b")

inputs = tokenizer("AIによって私達の暮らしは、", return_tensors="pt").to(model.device)
with torch.no_grad():
    tokens = model.generate(
        **inputs,
        max_new_tokens=64,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        repetition_penalty=1.05,
        pad_token_id=tokenizer.pad_token_id,
    )
    
output = tokenizer.decode(tokens[0], skip_special_tokens=True)
print(output)

以下の様な出力がされます。
※途中で切れているように見えるのはmax_new_tokens=64のためです。

AIによって私達の暮らしは、日々大きく変化しています。しかし、人工知能やロボット技術が発達しても、人は人の手で物を創り出したい欲求を捨てられないのではないでしょうか?そして、それはきっとこの先も変わらないと思います。
そのような背景から、当社では製品の企画・開発・販売という一連の流れを、一貫して手掛けることで、よりスピーディーに製品をお客さまにお届け

コードを読む

AutoModelForCausalLM.from_pretrained

transformersを使ってモデルを生成するコードについてみていきます。
transformersとはhuggingface(機械学習モデルのgithubみたいなサイト)からモデルをダウンロードするためのライブラリです。

model = AutoModelForCausalLM.from_pretrained("cyberagent/open-calm-7b", device_map="auto", torch_dtype=torch.float16)

コードを読む前にモデルのサマリーを出力してみます。

from torchinfo import summary
summary(model=model, depth=100)
結果
==========================================================================
Layer (type:depth-idx)                                  Param #
================================================================================
GPTNeoXForCausalLM                                      --
├─GPTNeoXModel: 1-1                                     --
│    └─Embedding: 2-1                                   213,909,504
│    └─Dropout: 2-2                                     --
│    └─ModuleList: 2-3                                  --
│    │    └─GPTNeoXLayer: 3-1                           --
│    │    │    └─LayerNorm: 4-1                         8,192
│    │    │    └─LayerNorm: 4-2                         8,192
│    │    │    └─Dropout: 4-3                           --
│    │    │    └─Dropout: 4-4                           --
│    │    │    └─GPTNeoXAttention: 4-5                  --
│    │    │    │    └─GPTNeoXRotaryEmbedding: 5-1       --
│    │    │    │    └─Linear: 5-2                       50,343,936
│    │    │    │    └─Linear: 5-3                       16,781,312
│    │    │    │    └─Dropout: 5-4                      --
│    │    │    └─GPTNeoXMLP: 4-6                        --
│    │    │    │    └─Linear: 5-5                       67,125,248
│    │    │    │    └─Linear: 5-6                       67,112,960
│    │    │    │    └─GELUActivation: 5-7               --
│    │    └─GPTNeoXLayer: 3-2                           --
│    │    │    └─LayerNorm: 4-7                         8,192
│    │    │    └─LayerNorm: 4-8                         8,192
│    │    │    └─Dropout: 4-9                           --
│    │    │    └─Dropout: 4-10                          --
│    │    │    └─GPTNeoXAttention: 4-11                 --
│    │    │    │    └─GPTNeoXRotaryEmbedding: 5-8       --
│    │    │    │    └─Linear: 5-9                       50,343,936
│    │    │    │    └─Linear: 5-10                      16,781,312
│    │    │    │    └─Dropout: 5-11                     --
│    │    │    └─GPTNeoXMLP: 4-12                       --
│    │    │    │    └─Linear: 5-12                      67,125,248
│    │    │    │    └─Linear: 5-13                      67,112,960
│    │    │    │    └─GELUActivation: 5-14              --
│    │    └─GPTNeoXLayer: 3-3                           --
│    │    │    └─LayerNorm: 4-13                        8,192
│    │    │    └─LayerNorm: 4-14                        8,192
│    │    │    └─Dropout: 4-15                          --
│    │    │    └─Dropout: 4-16                          --
│    │    │    └─GPTNeoXAttention: 4-17                 --
│    │    │    │    └─GPTNeoXRotaryEmbedding: 5-15      --
│    │    │    │    └─Linear: 5-16                      50,343,936
│    │    │    │    └─Linear: 5-17                      16,781,312
│    │    │    │    └─Dropout: 5-18                     --
│    │    │    └─GPTNeoXMLP: 4-18                       --
│    │    │    │    └─Linear: 5-19                      67,125,248
│    │    │    │    └─Linear: 5-20                      67,112,960
│    │    │    │    └─GELUActivation: 5-21              --
│    │    └─GPTNeoXLayer: 3-4                           --
│    │    │    └─LayerNorm: 4-19                        8,192
│    │    │    └─LayerNorm: 4-20                        8,192
│    │    │    └─Dropout: 4-21                          --
│    │    │    └─Dropout: 4-22                          --
│    │    │    └─GPTNeoXAttention: 4-23                 --
│    │    │    │    └─GPTNeoXRotaryEmbedding: 5-22      --
│    │    │    │    └─Linear: 5-23                      50,343,936
│    │    │    │    └─Linear: 5-24                      16,781,312
│    │    │    │    └─Dropout: 5-25                     --
│    │    │    └─GPTNeoXMLP: 4-24                       --
│    │    │    │    └─Linear: 5-26                      67,125,248
│    │    │    │    └─Linear: 5-27                      67,112,960
│    │    │    │    └─GELUActivation: 5-28              --
│    │    └─GPTNeoXLayer: 3-5                           --
│    │    │    └─LayerNorm: 4-25                        8,192
│    │    │    └─LayerNorm: 4-26                        8,192
│    │    │    └─Dropout: 4-27                          --
│    │    │    └─Dropout: 4-28                          --
│    │    │    └─GPTNeoXAttention: 4-29                 --
│    │    │    │    └─GPTNeoXRotaryEmbedding: 5-29      --
│    │    │    │    └─Linear: 5-30                      50,343,936
│    │    │    │    └─Linear: 5-31                      16,781,312
│    │    │    │    └─Dropout: 5-32                     --
│    │    │    └─GPTNeoXMLP: 4-30                       --
│    │    │    │    └─Linear: 5-33                      67,125,248
│    │    │    │    └─Linear: 5-34                      67,112,960
│    │    │    │    └─GELUActivation: 5-35              --
│    │    └─GPTNeoXLayer: 3-6                           --
│    │    │    └─LayerNorm: 4-31                        8,192
│    │    │    └─LayerNorm: 4-32                        8,192
│    │    │    └─Dropout: 4-33                          --
│    │    │    └─Dropout: 4-34                          --
│    │    │    └─GPTNeoXAttention: 4-35                 --
│    │    │    │    └─GPTNeoXRotaryEmbedding: 5-36      --
│    │    │    │    └─Linear: 5-37                      50,343,936
│    │    │    │    └─Linear: 5-38                      16,781,312
│    │    │    │    └─Dropout: 5-39                     --
│    │    │    └─GPTNeoXMLP: 4-36                       --
│    │    │    │    └─Linear: 5-40                      67,125,248
│    │    │    │    └─Linear: 5-41                      67,112,960
│    │    │    │    └─GELUActivation: 5-42              --
│    │    └─GPTNeoXLayer: 3-7                           --
│    │    │    └─LayerNorm: 4-37                        8,192
│    │    │    └─LayerNorm: 4-38                        8,192
│    │    │    └─Dropout: 4-39                          --
│    │    │    └─Dropout: 4-40                          --
│    │    │    └─GPTNeoXAttention: 4-41                 --
│    │    │    │    └─GPTNeoXRotaryEmbedding: 5-43      --
│    │    │    │    └─Linear: 5-44                      50,343,936
│    │    │    │    └─Linear: 5-45                      16,781,312
│    │    │    │    └─Dropout: 5-46                     --
│    │    │    └─GPTNeoXMLP: 4-42                       --
│    │    │    │    └─Linear: 5-47                      67,125,248
│    │    │    │    └─Linear: 5-48                      67,112,960
│    │    │    │    └─GELUActivation: 5-49              --
│    │    └─GPTNeoXLayer: 3-8                           --
│    │    │    └─LayerNorm: 4-43                        8,192
│    │    │    └─LayerNorm: 4-44                        8,192
│    │    │    └─Dropout: 4-45                          --
│    │    │    └─Dropout: 4-46                          --
│    │    │    └─GPTNeoXAttention: 4-47                 --
│    │    │    │    └─GPTNeoXRotaryEmbedding: 5-50      --
│    │    │    │    └─Linear: 5-51                      50,343,936
│    │    │    │    └─Linear: 5-52                      16,781,312
│    │    │    │    └─Dropout: 5-53                     --
│    │    │    └─GPTNeoXMLP: 4-48                       --
│    │    │    │    └─Linear: 5-54                      67,125,248
│    │    │    │    └─Linear: 5-55                      67,112,960
│    │    │    │    └─GELUActivation: 5-56              --
│    │    └─GPTNeoXLayer: 3-9                           --
│    │    │    └─LayerNorm: 4-49                        8,192
│    │    │    └─LayerNorm: 4-50                        8,192
│    │    │    └─Dropout: 4-51                          --
│    │    │    └─Dropout: 4-52                          --
│    │    │    └─GPTNeoXAttention: 4-53                 --
│    │    │    │    └─GPTNeoXRotaryEmbedding: 5-57      --
│    │    │    │    └─Linear: 5-58                      50,343,936
│    │    │    │    └─Linear: 5-59                      16,781,312
│    │    │    │    └─Dropout: 5-60                     --
│    │    │    └─GPTNeoXMLP: 4-54                       --
│    │    │    │    └─Linear: 5-61                      67,125,248
│    │    │    │    └─Linear: 5-62                      67,112,960
│    │    │    │    └─GELUActivation: 5-63              --
│    │    └─GPTNeoXLayer: 3-10                          --
│    │    │    └─LayerNorm: 4-55                        8,192
│    │    │    └─LayerNorm: 4-56                        8,192
│    │    │    └─Dropout: 4-57                          --
│    │    │    └─Dropout: 4-58                          --
│    │    │    └─GPTNeoXAttention: 4-59                 --
│    │    │    │    └─GPTNeoXRotaryEmbedding: 5-64      --
│    │    │    │    └─Linear: 5-65                      50,343,936
│    │    │    │    └─Linear: 5-66                      16,781,312
│    │    │    │    └─Dropout: 5-67                     --
│    │    │    └─GPTNeoXMLP: 4-60                       --
│    │    │    │    └─Linear: 5-68                      67,125,248
│    │    │    │    └─Linear: 5-69                      67,112,960
│    │    │    │    └─GELUActivation: 5-70              --
│    │    └─GPTNeoXLayer: 3-11                          --
│    │    │    └─LayerNorm: 4-61                        8,192
│    │    │    └─LayerNorm: 4-62                        8,192
│    │    │    └─Dropout: 4-63                          --
│    │    │    └─Dropout: 4-64                          --
│    │    │    └─GPTNeoXAttention: 4-65                 --
│    │    │    │    └─GPTNeoXRotaryEmbedding: 5-71      --
│    │    │    │    └─Linear: 5-72                      50,343,936
│    │    │    │    └─Linear: 5-73                      16,781,312
│    │    │    │    └─Dropout: 5-74                     --
│    │    │    └─GPTNeoXMLP: 4-66                       --
│    │    │    │    └─Linear: 5-75                      67,125,248
│    │    │    │    └─Linear: 5-76                      67,112,960
│    │    │    │    └─GELUActivation: 5-77              --
│    │    └─GPTNeoXLayer: 3-12                          --
│    │    │    └─LayerNorm: 4-67                        8,192
│    │    │    └─LayerNorm: 4-68                        8,192
│    │    │    └─Dropout: 4-69                          --
│    │    │    └─Dropout: 4-70                          --
│    │    │    └─GPTNeoXAttention: 4-71                 --
│    │    │    │    └─GPTNeoXRotaryEmbedding: 5-78      --
│    │    │    │    └─Linear: 5-79                      50,343,936
│    │    │    │    └─Linear: 5-80                      16,781,312
│    │    │    │    └─Dropout: 5-81                     --
│    │    │    └─GPTNeoXMLP: 4-72                       --
│    │    │    │    └─Linear: 5-82                      67,125,248
│    │    │    │    └─Linear: 5-83                      67,112,960
│    │    │    │    └─GELUActivation: 5-84              --
│    │    └─GPTNeoXLayer: 3-13                          --
│    │    │    └─LayerNorm: 4-73                        8,192
│    │    │    └─LayerNorm: 4-74                        8,192
│    │    │    └─Dropout: 4-75                          --
│    │    │    └─Dropout: 4-76                          --
│    │    │    └─GPTNeoXAttention: 4-77                 --
│    │    │    │    └─GPTNeoXRotaryEmbedding: 5-85      --
│    │    │    │    └─Linear: 5-86                      50,343,936
│    │    │    │    └─Linear: 5-87                      16,781,312
│    │    │    │    └─Dropout: 5-88                     --
│    │    │    └─GPTNeoXMLP: 4-78                       --
│    │    │    │    └─Linear: 5-89                      67,125,248
│    │    │    │    └─Linear: 5-90                      67,112,960
│    │    │    │    └─GELUActivation: 5-91              --
│    │    └─GPTNeoXLayer: 3-14                          --
│    │    │    └─LayerNorm: 4-79                        8,192
│    │    │    └─LayerNorm: 4-80                        8,192
│    │    │    └─Dropout: 4-81                          --
│    │    │    └─Dropout: 4-82                          --
│    │    │    └─GPTNeoXAttention: 4-83                 --
│    │    │    │    └─GPTNeoXRotaryEmbedding: 5-92      --
│    │    │    │    └─Linear: 5-93                      50,343,936
│    │    │    │    └─Linear: 5-94                      16,781,312
│    │    │    │    └─Dropout: 5-95                     --
│    │    │    └─GPTNeoXMLP: 4-84                       --
│    │    │    │    └─Linear: 5-96                      67,125,248
│    │    │    │    └─Linear: 5-97                      67,112,960
│    │    │    │    └─GELUActivation: 5-98              --
│    │    └─GPTNeoXLayer: 3-15                          --
│    │    │    └─LayerNorm: 4-85                        8,192
│    │    │    └─LayerNorm: 4-86                        8,192
│    │    │    └─Dropout: 4-87                          --
│    │    │    └─Dropout: 4-88                          --
│    │    │    └─GPTNeoXAttention: 4-89                 --
│    │    │    │    └─GPTNeoXRotaryEmbedding: 5-99      --
│    │    │    │    └─Linear: 5-100                     50,343,936
│    │    │    │    └─Linear: 5-101                     16,781,312
│    │    │    │    └─Dropout: 5-102                    --
│    │    │    └─GPTNeoXMLP: 4-90                       --
│    │    │    │    └─Linear: 5-103                     67,125,248
│    │    │    │    └─Linear: 5-104                     67,112,960
│    │    │    │    └─GELUActivation: 5-105             --
│    │    └─GPTNeoXLayer: 3-16                          --
│    │    │    └─LayerNorm: 4-91                        8,192
│    │    │    └─LayerNorm: 4-92                        8,192
│    │    │    └─Dropout: 4-93                          --
│    │    │    └─Dropout: 4-94                          --
│    │    │    └─GPTNeoXAttention: 4-95                 --
│    │    │    │    └─GPTNeoXRotaryEmbedding: 5-106     --
│    │    │    │    └─Linear: 5-107                     50,343,936
│    │    │    │    └─Linear: 5-108                     16,781,312
│    │    │    │    └─Dropout: 5-109                    --
│    │    │    └─GPTNeoXMLP: 4-96                       --
│    │    │    │    └─Linear: 5-110                     67,125,248
│    │    │    │    └─Linear: 5-111                     67,112,960
│    │    │    │    └─GELUActivation: 5-112             --
│    │    └─GPTNeoXLayer: 3-17                          --
│    │    │    └─LayerNorm: 4-97                        8,192
│    │    │    └─LayerNorm: 4-98                        8,192
│    │    │    └─Dropout: 4-99                          --
│    │    │    └─Dropout: 4-100                         --
│    │    │    └─GPTNeoXAttention: 4-101                --
│    │    │    │    └─GPTNeoXRotaryEmbedding: 5-113     --
│    │    │    │    └─Linear: 5-114                     50,343,936
│    │    │    │    └─Linear: 5-115                     16,781,312
│    │    │    │    └─Dropout: 5-116                    --
│    │    │    └─GPTNeoXMLP: 4-102                      --
│    │    │    │    └─Linear: 5-117                     67,125,248
│    │    │    │    └─Linear: 5-118                     67,112,960
│    │    │    │    └─GELUActivation: 5-119             --
│    │    └─GPTNeoXLayer: 3-18                          --
│    │    │    └─LayerNorm: 4-103                       8,192
│    │    │    └─LayerNorm: 4-104                       8,192
│    │    │    └─Dropout: 4-105                         --
│    │    │    └─Dropout: 4-106                         --
│    │    │    └─GPTNeoXAttention: 4-107                --
│    │    │    │    └─GPTNeoXRotaryEmbedding: 5-120     --
│    │    │    │    └─Linear: 5-121                     50,343,936
│    │    │    │    └─Linear: 5-122                     16,781,312
│    │    │    │    └─Dropout: 5-123                    --
│    │    │    └─GPTNeoXMLP: 4-108                      --
│    │    │    │    └─Linear: 5-124                     67,125,248
│    │    │    │    └─Linear: 5-125                     67,112,960
│    │    │    │    └─GELUActivation: 5-126             --
│    │    └─GPTNeoXLayer: 3-19                          --
│    │    │    └─LayerNorm: 4-109                       8,192
│    │    │    └─LayerNorm: 4-110                       8,192
│    │    │    └─Dropout: 4-111                         --
│    │    │    └─Dropout: 4-112                         --
│    │    │    └─GPTNeoXAttention: 4-113                --
│    │    │    │    └─GPTNeoXRotaryEmbedding: 5-127     --
│    │    │    │    └─Linear: 5-128                     50,343,936
│    │    │    │    └─Linear: 5-129                     16,781,312
│    │    │    │    └─Dropout: 5-130                    --
│    │    │    └─GPTNeoXMLP: 4-114                      --
│    │    │    │    └─Linear: 5-131                     67,125,248
│    │    │    │    └─Linear: 5-132                     67,112,960
│    │    │    │    └─GELUActivation: 5-133             --
│    │    └─GPTNeoXLayer: 3-20                          --
│    │    │    └─LayerNorm: 4-115                       8,192
│    │    │    └─LayerNorm: 4-116                       8,192
│    │    │    └─Dropout: 4-117                         --
│    │    │    └─Dropout: 4-118                         --
│    │    │    └─GPTNeoXAttention: 4-119                --
│    │    │    │    └─GPTNeoXRotaryEmbedding: 5-134     --
│    │    │    │    └─Linear: 5-135                     50,343,936
│    │    │    │    └─Linear: 5-136                     16,781,312
│    │    │    │    └─Dropout: 5-137                    --
│    │    │    └─GPTNeoXMLP: 4-120                      --
│    │    │    │    └─Linear: 5-138                     67,125,248
│    │    │    │    └─Linear: 5-139                     67,112,960
│    │    │    │    └─GELUActivation: 5-140             --
│    │    └─GPTNeoXLayer: 3-21                          --
│    │    │    └─LayerNorm: 4-121                       8,192
│    │    │    └─LayerNorm: 4-122                       8,192
│    │    │    └─Dropout: 4-123                         --
│    │    │    └─Dropout: 4-124                         --
│    │    │    └─GPTNeoXAttention: 4-125                --
│    │    │    │    └─GPTNeoXRotaryEmbedding: 5-141     --
│    │    │    │    └─Linear: 5-142                     50,343,936
│    │    │    │    └─Linear: 5-143                     16,781,312
│    │    │    │    └─Dropout: 5-144                    --
│    │    │    └─GPTNeoXMLP: 4-126                      --
│    │    │    │    └─Linear: 5-145                     67,125,248
│    │    │    │    └─Linear: 5-146                     67,112,960
│    │    │    │    └─GELUActivation: 5-147             --
│    │    └─GPTNeoXLayer: 3-22                          --
│    │    │    └─LayerNorm: 4-127                       8,192
│    │    │    └─LayerNorm: 4-128                       8,192
│    │    │    └─Dropout: 4-129                         --
│    │    │    └─Dropout: 4-130                         --
│    │    │    └─GPTNeoXAttention: 4-131                --
│    │    │    │    └─GPTNeoXRotaryEmbedding: 5-148     --
│    │    │    │    └─Linear: 5-149                     50,343,936
│    │    │    │    └─Linear: 5-150                     16,781,312
│    │    │    │    └─Dropout: 5-151                    --
│    │    │    └─GPTNeoXMLP: 4-132                      --
│    │    │    │    └─Linear: 5-152                     67,125,248
│    │    │    │    └─Linear: 5-153                     67,112,960
│    │    │    │    └─GELUActivation: 5-154             --
│    │    └─GPTNeoXLayer: 3-23                          --
│    │    │    └─LayerNorm: 4-133                       8,192
│    │    │    └─LayerNorm: 4-134                       8,192
│    │    │    └─Dropout: 4-135                         --
│    │    │    └─Dropout: 4-136                         --
│    │    │    └─GPTNeoXAttention: 4-137                --
│    │    │    │    └─GPTNeoXRotaryEmbedding: 5-155     --
│    │    │    │    └─Linear: 5-156                     50,343,936
│    │    │    │    └─Linear: 5-157                     16,781,312
│    │    │    │    └─Dropout: 5-158                    --
│    │    │    └─GPTNeoXMLP: 4-138                      --
│    │    │    │    └─Linear: 5-159                     67,125,248
│    │    │    │    └─Linear: 5-160                     67,112,960
│    │    │    │    └─GELUActivation: 5-161             --
│    │    └─GPTNeoXLayer: 3-24                          --
│    │    │    └─LayerNorm: 4-139                       8,192
│    │    │    └─LayerNorm: 4-140                       8,192
│    │    │    └─Dropout: 4-141                         --
│    │    │    └─Dropout: 4-142                         --
│    │    │    └─GPTNeoXAttention: 4-143                --
│    │    │    │    └─GPTNeoXRotaryEmbedding: 5-162     --
│    │    │    │    └─Linear: 5-163                     50,343,936
│    │    │    │    └─Linear: 5-164                     16,781,312
│    │    │    │    └─Dropout: 5-165                    --
│    │    │    └─GPTNeoXMLP: 4-144                      --
│    │    │    │    └─Linear: 5-166                     67,125,248
│    │    │    │    └─Linear: 5-167                     67,112,960
│    │    │    │    └─GELUActivation: 5-168             --
│    │    └─GPTNeoXLayer: 3-25                          --
│    │    │    └─LayerNorm: 4-145                       8,192
│    │    │    └─LayerNorm: 4-146                       8,192
│    │    │    └─Dropout: 4-147                         --
│    │    │    └─Dropout: 4-148                         --
│    │    │    └─GPTNeoXAttention: 4-149                --
│    │    │    │    └─GPTNeoXRotaryEmbedding: 5-169     --
│    │    │    │    └─Linear: 5-170                     50,343,936
│    │    │    │    └─Linear: 5-171                     16,781,312
│    │    │    │    └─Dropout: 5-172                    --
│    │    │    └─GPTNeoXMLP: 4-150                      --
│    │    │    │    └─Linear: 5-173                     67,125,248
│    │    │    │    └─Linear: 5-174                     67,112,960
│    │    │    │    └─GELUActivation: 5-175             --
│    │    └─GPTNeoXLayer: 3-26                          --
│    │    │    └─LayerNorm: 4-151                       8,192
│    │    │    └─LayerNorm: 4-152                       8,192
│    │    │    └─Dropout: 4-153                         --
│    │    │    └─Dropout: 4-154                         --
│    │    │    └─GPTNeoXAttention: 4-155                --
│    │    │    │    └─GPTNeoXRotaryEmbedding: 5-176     --
│    │    │    │    └─Linear: 5-177                     50,343,936
│    │    │    │    └─Linear: 5-178                     16,781,312
│    │    │    │    └─Dropout: 5-179                    --
│    │    │    └─GPTNeoXMLP: 4-156                      --
│    │    │    │    └─Linear: 5-180                     67,125,248
│    │    │    │    └─Linear: 5-181                     67,112,960
│    │    │    │    └─GELUActivation: 5-182             --
│    │    └─GPTNeoXLayer: 3-27                          --
│    │    │    └─LayerNorm: 4-157                       8,192
│    │    │    └─LayerNorm: 4-158                       8,192
│    │    │    └─Dropout: 4-159                         --
│    │    │    └─Dropout: 4-160                         --
│    │    │    └─GPTNeoXAttention: 4-161                --
│    │    │    │    └─GPTNeoXRotaryEmbedding: 5-183     --
│    │    │    │    └─Linear: 5-184                     50,343,936
│    │    │    │    └─Linear: 5-185                     16,781,312
│    │    │    │    └─Dropout: 5-186                    --
│    │    │    └─GPTNeoXMLP: 4-162                      --
│    │    │    │    └─Linear: 5-187                     67,125,248
│    │    │    │    └─Linear: 5-188                     67,112,960
│    │    │    │    └─GELUActivation: 5-189             --
│    │    └─GPTNeoXLayer: 3-28                          --
│    │    │    └─LayerNorm: 4-163                       8,192
│    │    │    └─LayerNorm: 4-164                       8,192
│    │    │    └─Dropout: 4-165                         --
│    │    │    └─Dropout: 4-166                         --
│    │    │    └─GPTNeoXAttention: 4-167                --
│    │    │    │    └─GPTNeoXRotaryEmbedding: 5-190     --
│    │    │    │    └─Linear: 5-191                     50,343,936
│    │    │    │    └─Linear: 5-192                     16,781,312
│    │    │    │    └─Dropout: 5-193                    --
│    │    │    └─GPTNeoXMLP: 4-168                      --
│    │    │    │    └─Linear: 5-194                     67,125,248
│    │    │    │    └─Linear: 5-195                     67,112,960
│    │    │    │    └─GELUActivation: 5-196             --
│    │    └─GPTNeoXLayer: 3-29                          --
│    │    │    └─LayerNorm: 4-169                       8,192
│    │    │    └─LayerNorm: 4-170                       8,192
│    │    │    └─Dropout: 4-171                         --
│    │    │    └─Dropout: 4-172                         --
│    │    │    └─GPTNeoXAttention: 4-173                --
│    │    │    │    └─GPTNeoXRotaryEmbedding: 5-197     --
│    │    │    │    └─Linear: 5-198                     50,343,936
│    │    │    │    └─Linear: 5-199                     16,781,312
│    │    │    │    └─Dropout: 5-200                    --
│    │    │    └─GPTNeoXMLP: 4-174                      --
│    │    │    │    └─Linear: 5-201                     67,125,248
│    │    │    │    └─Linear: 5-202                     67,112,960
│    │    │    │    └─GELUActivation: 5-203             --
│    │    └─GPTNeoXLayer: 3-30                          --
│    │    │    └─LayerNorm: 4-175                       8,192
│    │    │    └─LayerNorm: 4-176                       8,192
│    │    │    └─Dropout: 4-177                         --
│    │    │    └─Dropout: 4-178                         --
│    │    │    └─GPTNeoXAttention: 4-179                --
│    │    │    │    └─GPTNeoXRotaryEmbedding: 5-204     --
│    │    │    │    └─Linear: 5-205                     50,343,936
│    │    │    │    └─Linear: 5-206                     16,781,312
│    │    │    │    └─Dropout: 5-207                    --
│    │    │    └─GPTNeoXMLP: 4-180                      --
│    │    │    │    └─Linear: 5-208                     67,125,248
│    │    │    │    └─Linear: 5-209                     67,112,960
│    │    │    │    └─GELUActivation: 5-210             --
│    │    └─GPTNeoXLayer: 3-31                          --
│    │    │    └─LayerNorm: 4-181                       8,192
│    │    │    └─LayerNorm: 4-182                       8,192
│    │    │    └─Dropout: 4-183                         --
│    │    │    └─Dropout: 4-184                         --
│    │    │    └─GPTNeoXAttention: 4-185                --
│    │    │    │    └─GPTNeoXRotaryEmbedding: 5-211     --
│    │    │    │    └─Linear: 5-212                     50,343,936
│    │    │    │    └─Linear: 5-213                     16,781,312
│    │    │    │    └─Dropout: 5-214                    --
│    │    │    └─GPTNeoXMLP: 4-186                      --
│    │    │    │    └─Linear: 5-215                     67,125,248
│    │    │    │    └─Linear: 5-216                     67,112,960
│    │    │    │    └─GELUActivation: 5-217             --
│    │    └─GPTNeoXLayer: 3-32                          --
│    │    │    └─LayerNorm: 4-187                       8,192
│    │    │    └─LayerNorm: 4-188                       8,192
│    │    │    └─Dropout: 4-189                         --
│    │    │    └─Dropout: 4-190                         --
│    │    │    └─GPTNeoXAttention: 4-191                --
│    │    │    │    └─GPTNeoXRotaryEmbedding: 5-218     --
│    │    │    │    └─Linear: 5-219                     50,343,936
│    │    │    │    └─Linear: 5-220                     16,781,312
│    │    │    │    └─Dropout: 5-221                    --
│    │    │    └─GPTNeoXMLP: 4-192                      --
│    │    │    │    └─Linear: 5-222                     67,125,248
│    │    │    │    └─Linear: 5-223                     67,112,960
│    │    │    │    └─GELUActivation: 5-224             --
│    └─LayerNorm: 2-4                                   8,192
├─Linear: 1-2                                           213,909,504
================================================================================
Total params: 6,871,982,080
Trainable params: 6,871,982,080
Non-trainable params: 0
================================================================================

GPTNeoXForCausalLMというアーキテクチャが使われているようです。
https://huggingface.co/docs/transformers/model_doc/gpt_neox

68億のパラメータであることもわかります。

Total params: 6,871,982,080

AutoModelForCausalLMについて見ていきます。
https://github.com/huggingface/transformers/blob/main/src/transformers/models/auto/modeling_auto.py#L1192-L1196

auto_class_updateがクラスを生成しているようです。
https://github.com/huggingface/transformers/blob/main/src/transformers/models/auto/auto_factory.py#L592-L624

サンプルコードで叩くのはここでセットしているクラスメソッドです。

cls.from_pretrained = classmethod(from_pretrained)

親クラスのメソッドをコピーしてから調整しているようです。

from_pretrained = copy_func(_BaseAutoModelClass.from_pretrained)

from_pretrainedは学習済みのモデルを返すので重要です。
pretrained_model_name_or_path
には"cyberagent/open-calm-7b"が入ります。
https://github.com/huggingface/transformers/blob/main/src/transformers/models/auto/auto_factory.py#L438-L517

結果このブロックから学習済みモデルを返しています。

        elif type(config) in cls._model_mapping.keys():
            model_class = _get_model_class(config, cls._model_mapping)
            return model_class.from_pretrained(
                pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
            )

model_class.from_pretrained
は以下のメソッドになります。
長すぎて読みきれないです。
https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L1957-L3008

ただ結局返しているのはmodel(GPTNeoXForCausalLM)になります。

重みは
https://huggingface.co/cyberagent/open-calm-7b/tree/main
からダウンロードしてモデルにload_state_dictされているようです。
※load_state_dictはpytorchの仕組みでモデルに対して重みをロードするもの。

AutoTokenizer.from_pretrained

tokenizerを生成するコードについてみていきます。
tokenizerは文字列をTokenに分割し、それぞれのIDを返してくれるものになります。

tokenizer = AutoTokenizer.from_pretrained("cyberagent/open-calm-7b")

tokenizerをprintすると以下になります。

GPTNeoXTokenizerFast(name_or_path='cyberagent/open-calm-7b', vocab_size=52000, model_max_length=1000000000000000019884624838656, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>', 'pad_token': '<|padding|>'}, clean_up_tokenization_spaces=True)

文字列をトークン化すると

inputs = tokenizer("AIによって私達の暮らしは、", return_tensors="pt").to(model.device)
print(inputs)

以下のように変換されます。

{'input_ids': tensor([[ 4215,   930, 18030, 16205,   257,   245]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1]], device='cuda:0')}

AutoTokenizer.from_pretrainedを使用してtokenizerを取得すると、
以下のメソッドが実行されるようです。こちらはあまり掘り下げないでおきます。
https://github.com/huggingface/transformers/blob/main/src/transformers/models/auto/tokenization_auto.py#L547-L755

with torch.no_grad():

こちらはpytorchの機能で、ブロック内では勾配の計算をしないというものになります。
学習しないときはこのブロックに入れることで処理が軽くなります。

model.generate

ここで入力されたトークンに対して答えとなるトークンを推論しています。

    tokens = model.generate(
        **inputs,
        max_new_tokens=64,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        repetition_penalty=1.05,
        pad_token_id=tokenizer.pad_token_id,
    )

普通はpytorchだと以下のように呼び出すのがお決まりなのですが、transformers経由だと少し勝手が違うみたいです。

outputs = model(inputs)

generateコードを見てみます。こちらも長いです。
https://github.com/huggingface/transformers/blob/main/src/transformers/generation/utils.py#L1186-L1843

最終的にこのブロックに入ります。

        elif is_sample_gen_mode:
            # 11. prepare logits warper
            logits_warper = self._get_logits_warper(generation_config)

            # 12. expand input_ids with `num_return_sequences` additional sequences per batch
            input_ids, model_kwargs = self._expand_inputs_for_generation(
                input_ids=input_ids,
                expand_size=generation_config.num_return_sequences,
                is_encoder_decoder=self.config.is_encoder_decoder,
                **model_kwargs,
            )

            # 13. run sample
            return self.sample(
                input_ids,
                logits_processor=logits_processor,
                logits_warper=logits_warper,
                stopping_criteria=stopping_criteria,
                pad_token_id=generation_config.pad_token_id,
                eos_token_id=generation_config.eos_token_id,
                output_scores=generation_config.output_scores,
                return_dict_in_generate=generation_config.return_dict_in_generate,
                synced_gpus=synced_gpus,
                streamer=streamer,
                **model_kwargs,
            )

実際のモデルに推論させているのはここです。最大64回通ります。
max_new_tokens=64に対応しているのだと思います。それより短い場合もありそうです。
https://github.com/huggingface/transformers/blob/main/src/transformers/generation/utils.py#L2736-L2742

tokenizer.decode

最後は推論されたトークンのIDリストを日本語に変換して終了です。

output = tokenizer.decode(tokens[0], skip_special_tokens=True)

まとめ

transformersのコードは結構マッチョで読みづらかったです。

今後はGPTNeoXForCausalLMの中身、
fine-turning時にどのように動作しているのか等を調査したいと思います。

Discussion