MPT-7B 日本語味見
npaka 先生, ありがとうございます.
weight サイズは fp16 で 13 GB くらいです.
Tesla P100(16 GB) に収まります.
Ryzen9 3950X + Tesla P100 で動かしてみます.
import torch
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM
device = "cuda"
tokenizer = AutoTokenizer.from_pretrained(
"mosaicml/mpt-7b-chat"
)
print("model...")
model = AutoModelForCausalLM.from_pretrained(
"mosaicml/mpt-7b-chat",
torch_dtype=torch.float16,
trust_remote_code=True
).to(device)
print("model done...")
prompt = "<human>: 東京について教えて。\n<bot>:"
print("tokenize...")
inputs = tokenizer(prompt, return_tensors='pt').to(model.device)
print("tokenize done...")
input_length = inputs.input_ids.shape[1]
print("input_len", input_length)
print("run")
outputs = model.generate(
**inputs,
max_new_tokens=128,
do_sample=True,
temperature=0.7,
top_p=0.7,
top_k=50,
return_dict_in_generate=True
)
token = outputs.sequences[0, input_length:]
output_str = tokenizer.decode(token)
print("output :", output_str)
output : 東京は、日本の首都です。日本は、1947年に国立国会が設立された後、国際的に重要な地位を取った。東京は、世界の中心的な都市で、人口は3700000人、都市面積は2300平方キロメートル、文化、経済、政治、仕事のための都市です
。東京は、世界の最大の金属市場、世界の
ほー, とりあえず日本語はある程度いけそうっぽね.
(中身はいろいろ間違っているけど)
weight 変換おそい?
model のロードと (fp16)変換がめちゃ遅で, Ryzen9 3950x で動かしても 5~10 分くらいかかりました...
attention.py:148: UserWarning: Using `attn_impl: torch`. If your model does not use `alibi` or `prefix_lm` we recommend using `attn_impl: flash` otherwise we recommend using `attn_impl: triton`.
warnings.warn('Using `attn_impl: torch`. If your model does not use `alibi` or ' + '`prefix_lm` we recommend using `attn_impl: flash` otherwise ' + 'we recommend using `attn_impl: triton`.')
Python プロセスは single thread で動いていたので, なにか attention 周りのモデル変換に時間かかっているのかしらん.
(他の transformers で fp16 変換しつつロードする LLM model も, ロードと変換に遅かった気もしなくはないが...)
Tesla P100 なので torch.compile
も使えず...
fp16 化したあとの weight を保存するようにするか, そのうち MLC-LLM や CTranslate2 あたりで precompile して高速化に期待でしょうか.
3090 で検証
Threadripper + 3090 で, bfloat16 で再度試して計測しました.
model load. secs = 215.182288646698
output : 東京は、日本の首都で、世界的に重要な都市です。東京は、古くから知られた地方で、日本の文化と伝統を蕎明らしています。東京は、世界の文化的な中心となっています。東京は、文化的に豊かで、食事的に豊富で、人文的に豊かで、自然的に豊かです。東京
run done. secs = 34.78921914100647
model load は 4 分くらい, 実行は 34 秒でした.
config.attn_config['attn_impl'] = 'triton'
は, flash-attn
がビルドできなかった(CUDA version mismatch)であきらめ.
flash-attn
がいけるともうちょっと高速化されるかもです.
max_new_tokens=128
なので, 入力 prompt 19 tokens も考慮すると
(128+19) tokens/ 34 sec = 4.32 tokens/sec
なんとか許容できるレベルでしょうか...
license あやしい?
Apache 2.0 ですが, 特に story writer で, それええんか? との疑惑は出てます.
しばし様子見したほうがよさそうです.
Discussion