🎃

Llama3 7bモデルで推論

2024/04/19に公開2

llama3が登場しましたね。
8Bのモデルと70Bのモデルが登場していますね。

事前学習を行なったもの→ https://huggingface.co/meta-llama/Meta-Llama-3-8B
Instruction Tuning→ https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct

https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama

注意事項

Llama 3 ベースのモデルは名前の先頭に "Llama 3" を含めないといけないようです。

機能面


8Bのモデルでもなかなかのベンチマークの結果が出ているようですね。

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(

wte(word token embedding)はかなり強化されていそうです。

利用時に関して

huggingface版のモデルを使用したい場合は、https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct
ページにて、申請する必要があります。申請するとメールが届きます。

ローカルに保存する。

LOCAL_PATH = "保存する場所"
HF_TOKEN = "トークンをここに入れる。"

download_path = snapshot_download(repo_id="meta-llama/Meta-Llama-3-8B-Instruct",
                                  local_dir=LOCAL_PATH,
                                  token=HF_TOKEN,
                                  local_dir_use_symlinks=False)

モデルとtokenizerのロード

from transformers import AutoTokenizer,AutoModelForCausalLM

LOCAL_PATH = "保存してる場所"

tokenizer = AutoTokenizer.from_pretrained(LOCAL_PATH)
model = AutoModelForCausalLM.from_pretrained(LOCAL_PATH, 
                                             device_map='auto',
                                             torch_dtype="auto")

推論

プロンプトの与え方は下記に記載があります。
https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-3

<|begin_of_text|><|start_header_id|>system<|end_header_id|>
{{ system_prompt }}<|eot_id|><|start_header_id|>user<|end_header_id|>
{{ user_message }}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

こんな感じで与えてくださいと書いてある。

prompt = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\nuse Japanese.<|eot_id|><|start_header_id|>user<|end_header_id|>\n3 + 3はいくらになりますか?<|eot_id|><|start_header_id|>assistant<|end_header_id|>"

inputs = tokenizer(prompt, return_tensors="pt").to("cuda")


with torch.no_grad():

    pre_generate_ids = model.generate(do_sample=True,
                                      temperature=0.7,
                                      top_p=0.9,
                                      input_ids=inputs["input_ids"].to("cuda"),
                                      max_length= 256,
                                      eos_token_id=tokenizer.convert_tokens_to_ids("<|eot_id|>"))
    
    pre_returned = tokenizer.batch_decode(pre_generate_ids, skip_special_tokens = False)[0]


    print(pre_returned)
回答
3 + 3は6になります!<|eot_id|>

Discussion

yoichiiyoichii

記事ありがとうございます! prompt の与え方について

Newlines (0x0A) are part of the prompt format

なので改行を含めた方が良いと思います

timonekotimoneko

ご指摘ありがとうございます。prompt には'\n'を追加しました。
結果は一応同じでしたが、従うのが無難ですね。

<|begin_of_text|><|start_header_id|>system<|end_header_id|>
use Japanese.<|eot_id|><|start_header_id|>user<|end_header_id|>
3 + 3はいくらになりますか?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

3 + 3は6になります!<|eot_id|>