🔀

8B の LLM を Lambda で頑張って動かしてみる

2025/01/16に公開1

こんにちは、初めましての方は初めまして。今年こそはダイエットするぞと思ってはや二週間、体重はまだ 0.1kg も減っていませんが「まあまだ 11 ヵ月あるし…」と逃げてばかりの 2025 年です。

突然ですが、LLM をサーバレスで動かしたくないですか? 僕はしたくなったので、前回 llama.cpp で LLM を AWS Lambda で動かしてみるという記事を書きました。この記事では量子化することで Lambda で 3B のモデルを動かしているのですが、まとめにも書いているように 7B のモデルでは Lambda でうまく動きませんでした。そのため 3B より大きいモデルについては動かすのを諦めていたのですが、先日夢の中で急な閃きがあり、「各層を分割すればいいじゃないか」ということに気付きました。気付いたからには試してみたくなり、実際に試して動かせることを確認したので記事を書きます。

GitHub にて一応実装を公開しています。
https://github.com/yukikawara/swallow-on-lambda

この記事の概要

  • 大きいモデルを動かしたいときはそれぞれの層をモジュールに分割することで動かせるようになる
  • 4 トークン出力するのに 30 分かかる

LLM について


https://en.wikipedia.org/wiki/Generative_pre-trained_transformer より引用

最近の LLM は GPT ベースのモデルが多く、GPT は図の左側のようにサブモジュールが連結しています。モジュール間では前のモジュールの計算結果が渡されています。この図を見ていたら何か見えてきませんか? そう、StepFunctions のフロー図とそっくりですね! そう見えてくると、各サブモジュールは Lambda に見えてきます。もちろん、各モジュールは全体のモデルよりもサイズが小さいため、メモリの上限が 10GB の Lambda でも何とか動きそうな気がしてきます。そこまで見えてくればあとは分かりますね? サブモジュールごとに重みを保存して、Lambda を用意して、Step Functions で動かすだけです! それでは以下で実際に実装をしていきましょう。

実装

今回は tokyotech-llm/Llama-3.1-Swallow-8B-Instruct-v0.3 を対象に実装します。

重みの保存

まず始めにモデルの構造を確認してみましょう。以下の実装でモデル構造を確認できます。

from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = "tokyotech-llm/Llama-3.1-Swallow-8B-Instruct-v0.3"
model = AutoModelForCausalLM.from_pretrained(model_name)
print(model)

# >> Output
# LlamaForCausalLM(
#   (model): LlamaModel(
#     (embed_tokens): Embedding(128256, 4096)
#     (layers): ModuleList(
#       (0-31): 32 x LlamaDecoderLayer(
#         (self_attn): LlamaAttention(
#           (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
#           (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
#           (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
#           (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
#         )
#         (mlp): LlamaMLP(
#           (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
#           (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
#           (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
#           (act_fn): SiLU()
#         )
#         (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
#         (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
#       )
#     )
#     (norm): LlamaRMSNorm((4096,), eps=1e-05)
#     (rotary_emb): LlamaRotaryEmbedding()
#   )
#   (lm_head): Linear(in_features=4096, out_features=128256, bias=False)
# )

この構造を見てどのように保存するか考えましょう。今回は embed_tokens、32 層の layersnormrotary_emblm_head をそれぞれ Lambda を動かすことを考えてモデルを保存します。

torch.save(model.model.embed_tokens, "./embedder.pt")
torch.save(model.model.norm, "norm.pt")
torch.save(model.model.rotary_emb, "rotary_emb.pt")

for i, layer in enumerate(model.model.layers):
    gc.collect()
    torch.cuda.empty_cache()
    torch.save(layer, f"decoder_layer_{i}.pt")

torch.save(model.lm_head, "lm_head.pt")

各サブモジュールへは . でアクセスできるため、必要なサブモジュールをそれぞれ保存しています。これだけで 30GB 近くの容量が必要になるため、余裕を持って実行しましょう。

生成部分の実装


Step Functions 全体図

Step Functions の全体は上図のようになっています。流れは以下のようになっています。

  1. TokenizeEncoder で自然言語を番号列へと変換します。
  2. Embedding でエンベッディングを取得します。
  3. 層数分 DecoderLayer に通すことで最終的な出力を得ます。
  4. LMHead で各単語のスコアへと変換します。
  5. TokenizerDecoder で次の単語の番号を出力し、EOS トークンではない、もしくは最大単語数でなければ 2 へと戻ります。
  6. 得られた番号列を自然言語にして終了します。

これは実際のモデル内の処理にかなり近い挙動をしていると思います。それでは実際に実装について以下で説明していきます。

プロンプトの生成

公式ドキュメントに従って、ユーザーの入力からモデルへの入力を生成します。

from transformers import AutoTokenizer

model_name = "tokyotech-llm/Llama-3.1-Swallow-8B-Instruct-v0.3"
tokenizer = AutoTokenizer.from_pretrained(model_name)

message = [
    {"role": "system", "content": "あなたは誠実で優秀な日本人のアシスタントです。"},
    {
        "role": "user",
        "content": "東京の紅葉した公園で、東京タワーと高層ビルを背景に、空を舞うツバメと草地に佇むラマが出会う温かな物語を書いてください。",
    },
]
prompt = tokenizer.apply_chat_template(
    message, tokenize=False, add_generation_prompt=True
)
print(prompt)
# >> Output
# <|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nあなたは誠実で優秀な日本人のアシスタントです。<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n東京の紅葉した公園で、東京タワーと高層ビルを背景に、空を舞うツバメと草地に佇むラマが出会う温かな物語を書いてください。<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n

実際には "東京の..." の部分はユーザーの入力になります。寡聞にして知らなかったのですが、apply_chat_template でモデルへの入力用のフォーマットに変換してくれるのはめちゃくちゃ便利ですね。これでユーザーへの入力の用意は出来ました。

実際にモデルへ入力する際は自然言語ではなく各単語に対応した番号の列を入力します。これは tokenizerencode 関数を使用することで変換できます(tokenizer.encode(prompt)

エンベッディング部分の実装

番号列へと変換出来たら、エンベッディングを取得します。エンベッディングモデルは事前に保存していたものを読み込んで使っていきます。

model_path = "/tmp/embedder.pt"
embedder = torch.load(model_path)
with torch.no_grad():
    embed_inputs = embedder(input_ids)

色々省いていますが、上記のように番号列をそのまま引数として渡すことでエンベッディングが取得できます。

DecoderLayer の実装

この部分は Huggingface の LlamaModel の実装を参考にし(というかパクリ)ました。

https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L517

from transformers.cache_utils import DynamicCache


model_path = f"/tmp/decoder_layer_{layer_num}.pt"
rotary_path = "/tmp/rotary_emb.pt"

decoder_layer = torch.load(model_path)
rotary_embedder = torch.load(rotary_path)

past_key_values = DynamicCache()
past_seen_tokens = past_key_values.get_seq_length()
cache_position = torch.arange(
    past_seen_tokens, past_seen_tokens + hidden_states.shape[1]
)
position_ids = cache_position.unsqueeze(0)
causal_mask = None
position_embeddings = rotary_embedder(hidden_states, position_ids)
with torch.no_grad():
    layer_outputs = decoder_layer(
        hidden_states,
        attention_mask=causal_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        output_attenions=False,
        use_cache=True,
        cache_position=cache_position,
        position_embeddings=position_embeddings,
    )
hidden_states = layer_outputs[0]

基本的にはこんな感じで position_embeddings を生成して前の層の出力 (hidden_states) をモデルへと渡すだけです。これを層数分(今回は 32 回)繰り返します。

次単語の予測

ここでは一旦 RMSNorm に通した後に最後の単語に対する出力を使って単語ごとのスコアへと変換します。実装は以下です。

norm_path = "/tmp/norm.pt"
lm_head_path = "/tmp/lm_head.pt"

norm = torch.load(norm_path)
lm_head = torch.load(lm_head_path)

with torch.no_grad():
    hidden_states = norm(hidden_states)
    logits = lm_head(hidden_states[:, -1, :])

これで logits に単語分の長さを持った配列が格納されます。配列の各要素には単語のスコアが格納されています。

次に、生成するべき単語の計算をします。top_p などを用いて多様性を持たせたり、ビームサーチで出力する文の探索を行ったり出来ますが、今回は簡単のために一番スコアが大きい単語を選びます。

next_tokens = torch.argmax(torch.tensor(logits), dim=-1)
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)

これで input_ids に次の単語の番号が付与されます。この番号が EOS トークンか、事前に設定していたトークン数に達した場合、最終的に tokenizer.batch_decode 関数で自然言語にもどせば完成です。

実際に Step Functions で生成させてみる

ユーザーの入力を「東京の紅葉した公園で、東京タワーと高層ビルを背景に、空を舞うツバメと草地に佇むラマが出会う温かな物語を書いてください。」として、4 トークン生成させてみました。


Step Functions の実行結果


生成されたテキスト

「期間」のところを見ると分かるのですが、4 トークン生成させるだけで 30 分ほどかかっています。めちゃくちゃ遅いです。生成されたテキストを見ると「秋風がそ」まで生成されており、紅葉から秋の話を出力しようとしていることが分かります。本当はもっとたくさんトークンを生成させて出力を確かめてみたいのですが、とりあえず動くことは分かったので記事としては一旦終わりにします。

まとめ

この記事では 8B の LLM を Lambda で動かしてみました。実行自体はかなり遅く、実用にはほど遠い気がしますが、サーバレスでも(無理やり)LLM を動かすことが出来ました。ついでに実装も公開しているので、もし興味がある方は試してみてください(ついでに推論を速く出来るようにしてもらえると嬉しいな…と他力本願なことを思っています)

最後に宣伝になりますが、機械学習でビジネスの成長を加速するために、Fusic の機械学習チームがお手伝いたします。機械学習のPoCから運用まで、すべての場面でサポートした実績があります。もし、困っている方がいましたら、ぜひ Fusic にご相談ください。お問い合わせからでも気軽にご連絡いただけます。また Twitter の DM でのメッセージも大歓迎です。

GitHubで編集を提案
Fusic 技術ブログ

Discussion

ShinShin

すごいですね。勉強になります!