💬

MPT-7B 日本語味見

2023/05/06に公開

https://note.com/npaka/n/nf442fc9f9c8d

npaka 先生, ありがとうございます.

weight サイズは fp16 で 13 GB くらいです.
Tesla P100(16 GB) に収まります.

Ryzen9 3950X + Tesla P100 で動かしてみます.

import torch
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM

device = "cuda"

tokenizer = AutoTokenizer.from_pretrained(
    "mosaicml/mpt-7b-chat"
)

print("model...")
model = AutoModelForCausalLM.from_pretrained(
    "mosaicml/mpt-7b-chat",
    torch_dtype=torch.float16,
    trust_remote_code=True
).to(device)
print("model done...")

prompt = "<human>: 東京について教えて。\n<bot>:"

print("tokenize...")
inputs = tokenizer(prompt, return_tensors='pt').to(model.device)
print("tokenize done...")
input_length = inputs.input_ids.shape[1]
print("input_len", input_length)

print("run")
outputs = model.generate(
    **inputs,
    max_new_tokens=128,
    do_sample=True,
    temperature=0.7,
    top_p=0.7,
    top_k=50,
    return_dict_in_generate=True
)
token = outputs.sequences[0, input_length:]
output_str = tokenizer.decode(token)

print("output :", output_str)
output :  東京は、日本の首都です。日本は、1947年に国立国会が設立された後、国際的に重要な地位を取った。東京は、世界の中心的な都市で、人口は3700000人、都市面積は2300平方キロメートル、文化、経済、政治、仕事のための都市です
。東京は、世界の最大の金属市場、世界の

ほー, とりあえず日本語はある程度いけそうっぽね.
(中身はいろいろ間違っているけど)

weight 変換おそい?

model のロードと (fp16)変換がめちゃ遅で, Ryzen9 3950x で動かしても 5~10 分くらいかかりました...

attention.py:148: UserWarning: Using `attn_impl: torch`. If your model does not use `alibi` or `prefix_lm` we recommend using `attn_impl: flash` otherwise we recommend using `attn_impl: triton`.
  warnings.warn('Using `attn_impl: torch`. If your model does not use `alibi` or ' + '`prefix_lm` we recommend using `attn_impl: flash` otherwise ' + 'we recommend using `attn_impl: triton`.')

Python プロセスは single thread で動いていたので, なにか attention 周りのモデル変換に時間かかっているのかしらん.
(他の transformers で fp16 変換しつつロードする LLM model も, ロードと変換に遅かった気もしなくはないが...)

Tesla P100 なので torch.compile も使えず...

fp16 化したあとの weight を保存するようにするか, そのうち MLC-LLM や CTranslate2 あたりで precompile して高速化に期待でしょうか.

3090 で検証

Threadripper + 3090 で, bfloat16 で再度試して計測しました.

model load. secs =  215.182288646698

output :  東京は、日本の首都で、世界的に重要な都市です。東京は、古くから知られた地方で、日本の文化と伝統を蕎明らしています。東京は、世界の文化的な中心となっています。東京は、文化的に豊かで、食事的に豊富で、人文的に豊かで、自然的に豊かです。東京
run done. secs =  34.78921914100647

model load は 4 分くらい, 実行は 34 秒でした.

config.attn_config['attn_impl'] = 'triton'

は, flash-attn がビルドできなかった(CUDA version mismatch)であきらめ.
flash-attn がいけるともうちょっと高速化されるかもです.

max_new_tokens=128 なので, 入力 prompt 19 tokens も考慮すると

(128+19) tokens/ 34 sec = 4.32 tokens/sec

なんとか許容できるレベルでしょうか...

license あやしい?

Apache 2.0 ですが, 特に story writer で, それええんか? との疑惑は出てます.

https://twitter.com/alexjc/status/1654752379533684736

しばし様子見したほうがよさそうです.

Discussion