Zenn
👌

llama3.1 405B モデルをA100 * 4で推論してみた

2024/07/25に公開

現状の検証事項

推論ベースの検証結果です。

モデル名 検証状況
Meta-Llama-3.1-405B-Instruct ⭕️ bitsandbytesで4bit
Meta-Llama-3.1-405B-Instruct-FP8 ❌ FP8で回すためには、A100ではなくH100が必要
hugging-quants/Meta-Llama-3.1-405B-Instruct-GPTQ-INT4 ⭕️

利用時に関して

今回はHFモデルを使用します。利用にはHFで申請が必要です。

特徴

コンテキスト長が128K tokensに対応しています。

引用: https://scontent-nrt1-2.xx.fbcdn.net/v/t39.2365-6/452387774_1036916434819166_4173978747091533306_n.pdf?_nc_cat=104&ccb=1-7&_nc_sid=3c67a6&_nc_ohc=t6egZJ8QdI4Q7kNvgHY23Q4&_nc_ht=scontent-nrt1-2.xx&oh=00_AYC62w8L7D3YeXR3KpK9_EAO_6EHAT3VetM1pQtodk2XCA&oe=66A60A8D

ベンチマーク

論文のベンチマークの結果も高く出ています。

モデルアーキテクチャー

モデルのダウンロード

LOCAL_PATH = "保存する場所"
HF_TOKEN = "トークンをここに入れる。"

download_path = snapshot_download(repo_id="meta-llama/Meta-Llama-3.1-405B-Instruct",
                                  local_dir=LOCAL_PATH,
                                  token=HF_TOKEN,
                                  local_dir_use_symlinks=False)

環境

今回は(1node) A100 4枚の環境で試してみます。さすがに A100 4枚では乗り切らないので4bit量子化しました。

筆者のライブラリ/ハード環境

databricks環境(実際はnotebookですが、通常にpythonを動かすように記載しています。)

cuda121
torch 2.3.0+cu121
lash-attn 2.5.8
transformers 4.43.1

読み込み

pip install --upgrade pip
pip install bitsandbytes
pip install --upgrade transformers
import torch
from transformers import AutoTokenizer,AutoModelForCausalLM,BitsAndBytesConfig

LOCAL_PATH = "/Path/to/llama3/405B/405B/" # 保存した場所を指定


bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

tokenizer = AutoTokenizer.from_pretrained(LOCAL_PATH)
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(LOCAL_PATH,
                                             quantization_config=bnb_config,
                                             torch_dtype=torch.bfloat16,
                                             device_map='auto')

シングルノードなので、device_map='auto'でモデルをパラレルします。

↓下記に気をつける。

Inference: For inference, bnb_4bit_quant_type does not have a huge impact on the performance. However for consistency with the model's weights, make sure you use the same bnb_4bit_compute_dtype and torch_dtype arguments.

モデルの構造

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 16384)
    (layers): ModuleList(
      (0-125): 126 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear4bit(in_features=16384, out_features=16384, bias=False)
          (k_proj): Linear4bit(in_features=16384, out_features=2048, bias=False)
          (v_proj): Linear4bit(in_features=16384, out_features=2048, bias=False)
          (o_proj): Linear4bit(in_features=16384, out_features=16384, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear4bit(in_features=16384, out_features=53248, bias=False)
          (up_proj): Linear4bit(in_features=16384, out_features=53248, bias=False)
          (down_proj): Linear4bit(in_features=53248, out_features=16384, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
    (rotary_emb): LlamaRotaryEmbedding()
  )
  (lm_head): Linear(in_features=16384, out_features=128256, bias=False)
)

推論

prompt = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\nuse Japanese.<|eot_id|><|start_header_id|>user<|end_header_id|>\n3 + 3はいくらになりますか?<|eot_id|><|start_header_id|>assistant<|end_header_id|>"

inputs = tokenizer(prompt, return_tensors="pt").to("cuda")


with torch.no_grad():

    pre_generate_ids = model.generate(do_sample=False,
                                      input_ids=inputs["input_ids"].to("cuda"),
                                      max_length= 256,
                                      eos_token_id=tokenizer.convert_tokens_to_ids("<|eot_id|>"))
    
    pre_returned = tokenizer.batch_decode(pre_generate_ids, skip_special_tokens = False)[0]


    print(pre_returned)

結果

<|begin_of_text|><|start_header_id|>system<|end_header_id|>
use Japanese.<|eot_id|><|start_header_id|>user<|end_header_id|>はじましてこんにちは<|eot_id|><|start_header_id|>assistant<|end_header_id|>

こんにちは!はじめまして!お元気ですか?何かお話したいことありますか?<|eot_id|>

うまく読み込めました!!

meta-llama/Meta-Llama-3.1-405B-Instruct-FP8

公式が出しているFP8で量子化されたモデルを検証してみます。

結論: A100ではダメでした..

FP8 quantized models is only supported on GPUs with compute capability >= 9.0 (e.g H100)
pip install fbgemm-gpu

下記のエラーが出たので、accelerateをアップグレードします
FP8 quantized model requires accelerate > 0.32.1 (pip install --upgrade accelerate)

pip install --upgrade accelerate

hugging-quants/Meta-Llama-3.1-405B-Instruct-GPTQ-INT4

GPTQでint4量子化されていたものがあったので今回はこれを利用してみます。

pip install optimum
pip install auto-gptq
import torch
from transformers import AutoTokenizer,AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(LOCAL_PATH,
                                             torch_dtype=torch.bfloat16,
                                             device_map='auto')

modelの構造

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 16384)
    (layers): ModuleList(
      (0-125): 126 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (rotary_emb): LlamaRotaryEmbedding()
          (k_proj): QuantLinear()
          (o_proj): QuantLinear()
          (q_proj): QuantLinear()
          (v_proj): QuantLinear()
        )
        (mlp): LlamaMLP(
          (act_fn): SiLU()
          (down_proj): QuantLinear()
          (gate_proj): QuantLinear()
          (up_proj): QuantLinear()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
    (rotary_emb): LlamaRotaryEmbedding()
  )
  (lm_head): Linear(in_features=16384, out_features=128256, bias=False)
)

GPU使用率

|   0  NVIDIA A100 80GB PCIe          Off | 00000001:00:00.0 Off |                    0 |
| N/A   51C    P0              74W / 300W |  56141MiB / 81920MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA A100 80GB PCIe          Off | 00000002:00:00.0 Off |                    0 |
| N/A   55C    P0              86W / 300W |  61731MiB / 81920MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   2  NVIDIA A100 80GB PCIe          Off | 00000003:00:00.0 Off |                    0 |
| N/A   56C    P0             314W / 300W |  61731MiB / 81920MiB |    100%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   3  NVIDIA A100 80GB PCIe          Off | 00000004:00:00.0 Off |                    0 |
| N/A   52C    P0              85W / 300W |  63349MiB / 81920MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+

実行結果

<|begin_of_text|><|start_header_id|>system<|end_header_id|>
use Japanese.<|eot_id|><|start_header_id|>user<|end_header_id|>はじましてこんにちは<|eot_id|><|start_header_id|>assistant<|end_header_id|>

こんにちは!はじめまして!お元気ですか?<|eot_id|>

エラーが発生した場合の対処

(1)

ValueError: rope_scaling must be a dictionary with two fields, type and factor, got {'factor': 8.0, 'low_freq_factor': 1.0, 'high_freq_factor': 4.0, 'original_max_position_embeddings': 8192, 'rope_type': 'llama3'}

→ transformersのupgradeを実施

pip install --upgrade transformers

transformers 4.43.1 で解決

Discussion

ログインするとコメントできます