🖖

vLLMで独自実装モデルを高速推論させる

2024/12/04に公開

はじめに

チューリング生成AIチームの荒居です。
この記事は生成AIアドベントカレンダー2024の4日目の記事です。

この記事では、動画生成モデルを題材に、vLLMを用いて独自のマルチモーダルモデルを推論させる方法について解説します。vLLMはLLMの高速推論・サービングのライブラリで、LlamaやQwenなどの有名なモデルについてはサポートされているため非常に簡単に利用することが可能です。一方で、独自に実装したモデルを組み込む方法については公式ドキュメント以外ではほとんど情報がなく、かつ公式ドキュメントも情報が豊富とは言えないためなかなか手が出しづらい状況になっています。

そこで、本記事では、公式ドキュメントには書かれていない内容も含めて、独自のマルチモーダルモデルをvLLMに組み込む具体的な方法を丁寧に解説します。特に、動画生成モデルを例に、Hugging Face Transformersライブラリで実装されたモデルをどのようにvLLMに適合させるか、その手順を詳しく説明します。また、vLLMを活用することで得られる推論速度の向上についても実際の結果を交えて紹介します。

この記事で解説している内容:

  • vLLMで独自モデルを実装する方法
  • vLLMで独自のマルチモーダルデータを扱う方法

vLLMとは

vLLMは、大規模言語モデル(LLM)の推論を高速化し、モデルのサービングを簡単かつ効率的に行えるようにするためのオープンソースライブラリです。近年では、LLMの学習や実験にHugging FaceTransformersライブラリが広く利用されていますが、Transformersの実装では推論時にいくつかの非効率性が課題として挙げられています。その中でも特に指摘されているのが、Key-Value Cacheの管理における非効率性です。

Key-Value Cacheは、Transformerモデルが推論時に過去のコンテキスト情報を効率的に再利用するために使用される仕組みです。このキャッシュは、過去のトークンのアクティベーション結果を保存し、新しいトークンの生成時に再利用されます。しかし、Transformersライブラリの標準実装では、このキャッシュの管理が無駄なメモリ消費や処理遅延を引き起こしやすい構造になっており、LLM推論の速度を制約する要因となっています。

vLLMはこの課題に対し、独自のPagedAttentionアルゴリズムを採用することで解決を図っています。PagedAttentionは、Key-Value Cacheの割り当てと管理を最適化する仕組みで、キャッシュのメモリ無駄を大幅に削減するとともに、高速なアクセスと更新を実現します。この結果、vLLMを使用すると、従来のライブラリと比較して大幅に高い推論スループットを達成できます。


vLLMのPagedAttentionの図解。AttentionのKeyとValueはブロックという単位に切り分けられて保存されていることによりメモリを効率的に使うことができる。vLLMのブログより引用。

また、LLMサービングの際に効果があるContinuous Batchingや、量子化 (GPTQ、AWQ、INT4/8、FP8)対応なども実装されており、様々な高速化の取り組みがなされています。

vLLMはHugging Face Transformersとの高い互換性を持ち、有名なモデルについて言えば特別な作業を必要としません。そのため、開発スピードを重視するプロジェクトや、プロトタイピング段階では特に有用とされています。

vLLMにおけるマルチモーダルモデルの扱い

vLLMはLLMの高速推論・サービングのライブラリですが、画像や音声などの入力を言語と一緒に扱うマルチモーダルモデルについても推論の高速化を実現することができます。さらにいえば自己回帰のTransformerモデルであればどのようなモデルでも適切に実装すれば高速化することができます。

vLLMは、vllm.multimodalパッケージを通じてマルチモーダルモデルをサポートしています。ユーザーは、テキストやトークンのプロンプトとともに、vllm.inputs.PromptTypemulti_modal_dataフィールドを使用してマルチモーダル入力をモデルに渡すことができます。現在、vLLMは画像・動画データの組み込みサポートを提供していますが、他のモダリティを処理するために拡張することも可能です。

vLLMでは、マルチモーダルモデルを使用したオフライン推論の例も提供されています。例えば、単一の画像入力を用いた推論や、複数の画像を組み合わせた推論のサンプルコードが公式ドキュメントで紹介されています。

独自動画生成モデルの実装

このブログではチューリングで開発した自己回帰Transformerを用いている動画生成のモデル"Terra"をvLLMを用いて高速化することを考えます。

このモデルは、動画の各画像フレームをImage Tokenizerを用いて離散的なトークンの列に変換し、画像列を表しているトークンの列を入力として、未来の画像列の離散トークン列を予測します。予測された離散トークン列はその後Decoderを用いて画像列に変換され動画が生成されます。


動画生成モデルの概要図


動画生成モデルの入出力。条件付けの画像列と軌跡データを受け取り動画を生成する

また、条件付けのためにアクションと呼ばれるベクトル列を入れることができます。このベクトル列は3次元のベクトルが6個連なった3 x 6の行列となっており、各画像フレームの間に挿入されるようになっています。1枚の画像フレームを表している離散トークンは576個あるため、推論の際は画像の離散トークン576個ごとに6トークン分ベクトルが挿入されるような仕組みです。

Hugging FaceのTransformersにおける実装は以下のようになっています。このモデルはLlamaアーキテクチャのLLMモデルをベースとしていますが、アクションベクトルを扱うための仕組みを入れているところと、Positional Encodingとして学習可能な特殊なPositional Encodingを入れているところが差分としてあります。

Transformersを用いた実装
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn

from transformers import LlamaConfig, LlamaForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast

from ..positional_embedding import LearnableFactorizedSpatioTemporalPositionalEmbedding

class LlamaActionConfig(LlamaConfig):
    model_type = "llama_action"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.num_spatio_embeddings = kwargs.get("num_spatio_embeddings", 582)
        self.num_temporal_embeddings = kwargs.get("num_temporal_embeddings", 25)
        self.num_action_embeddings = kwargs.get("num_action_tokens", 6)
        self.num_image_patches = kwargs.get("num_image_patches", 576)
        self.action_dim = kwargs.get("action_dim", 3)


class LlamaActionForCausalLM(LlamaForCausalLM):
    config_class = LlamaActionConfig

    def __init__(self, config: LlamaActionConfig):
        super().__init__(config)

        self.num_spatio_embeddings = config.num_spatio_embeddings
        self.num_temporal_embeddings = config.num_temporal_embeddings
        self.num_image_patches = config.num_image_patches
        self.num_action_embeddings = config.num_action_embeddings

        self.pos_embedding_spatio_temporal = LearnableFactorizedSpatioTemporalPositionalEmbedding(
            config.num_spatio_embeddings, config.num_temporal_embeddings, config.hidden_size,
        )

        self.action_projection = nn.Linear(config.action_dim, config.hidden_size)

        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        actions: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.Tensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithPast]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if labels is not None:
            use_cache = False

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time"
            )
        elif input_ids is not None:
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        inputs_embeds = self.model.get_input_embeddings()(input_ids)
        if past_key_values is None:
            inputs_embeds_list = torch.split(
                inputs_embeds,
                split_size_or_sections=self.num_image_patches,
                dim=1
            )
            actions_list = torch.split(
                actions,
                split_size_or_sections=self.num_action_embeddings,
                dim=1
            )

            embeddings = []
            if len(inputs_embeds_list) == len(actions_list):
                # 学習時はこちらが使われるが推論時はこうなることはほぼないはず
                for inputs_embeds, action_embeds in zip(inputs_embeds_list, actions_list):
                    action_features = self.action_projection(action_embeds)
                    embeddings.append(inputs_embeds)
                    embeddings.append(action_features)
            elif len(inputs_embeds_list) < len(actions_list):
                # 推論時は生成の途中であることが多いはずなのでこちらが使われることが多い
                for i, inputs_embeds in enumerate(inputs_embeds_list):
                    embeddings.append(inputs_embeds)
                    if i < len(inputs_embeds_list) - 1:
                        # 最後のフレームは生成途中の画像トークン列の可能性があるのでアクションEmbeddingを追加しない
                        action_embeds = self.action_projection(actions_list[i])
                        embeddings.append(action_embeds)
                if inputs_embeds_list[-1].size(1) == self.num_image_patches:
                    # 画像のトークンがちょうど1フレーム分出ている場合はアクションEmbeddingを加えた上で次のフレーム用のテキストトークンをさらに追加する
                    action_embeds = self.action_projection(actions_list[len(inputs_embeds_list) - 1])
                    embeddings.append(action_embeds)
        else:
            past_key_values_length = past_key_values[0][0].size(2)
            embeddings = []
            # image, image, ..., image, action, action, ..., actionのような並びで入力を作る
            # image tokenのみ生成を行うため、1フレーム分の生成が終わったタイミングでaction tokenを追加してあげる
            if past_key_values_length % self.num_spatio_embeddings == (self.num_spatio_embeddings - self.num_action_embeddings):
                seq_index = past_key_values_length // self.num_spatio_embeddings + 1
                actions_list = torch.split(
                    actions,
                    split_size_or_sections=self.num_action_embeddings,
                    dim=1
                )
                action_features = self.action_projection(actions_list[seq_index - 1])
                embeddings.append(action_features)
                embeddings.append(inputs_embeds)
            else:
                pass

        if len(embeddings) > 0:
            inputs_embeds = torch.cat(embeddings, dim=1)

        # Spatio Temporal Positional Embeddingの挿入
        past_key_values_length = past_key_values[0][0].size(2) if past_key_values is not None else 0
        inputs_embeds += self.pos_embedding_spatio_temporal(inputs_embeds, past_key_values_length)

        outputs = self.model(
            input_ids=None,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        logits = self.lm_head(sequence_output).contiguous()

        loss = None
        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        use_cache=None,
        **kwargs):
        batch_size = input_ids.size(0)
        seq_length = input_ids.size(1)
        n_frames = seq_length // self.num_image_patches
        attention_mask_length = n_frames * (self.num_image_patches + self.num_action_embeddings)
        if seq_length % self.num_image_patches != 0:
            n_last_frame_tokens = seq_length % self.num_image_patches
            attention_mask_length += n_last_frame_tokens
        else:
            print(f"attempting to generate new frame - frame no: {n_frames + 1}")
        attention_mask = torch.ones((batch_size, attention_mask_length), device=input_ids.device, dtype=torch.long)

        # cut decoder_input_ids if past_key_values is used
        if past_key_values is not None:
            past_length = past_key_values[0][0].size(2)
            if input_ids.size(1) > past_length:
                remove_prefix_length = past_length
            else:
                remove_prefix_length = input_ids.size(1) - 1
            input_ids = input_ids[:, remove_prefix_length:]
            seq_length = input_ids.size(1)
            past_key_values_length = past_key_values[0][0].size(2)
            mask_seq_length = seq_length + past_key_values_length
            if past_key_values_length % self.num_spatio_embeddings == (self.num_spatio_embeddings - self.num_action_embeddings):
                mask_seq_length += self.num_action_embeddings
            attention_mask = torch.ones((batch_size, mask_seq_length), device=input_ids.device, dtype=torch.long)

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "actions": kwargs.get("actions"),
            "past_key_values": past_key_values,
            "use_cache": use_cache,
        }
Positional Encodingの詳細

時刻ステップtの画像フレーム\mathbf{i}_tを表現するトークンをi_t^l~(l \in \{1, 2, \dots, 576\})とし、時刻ステップtのアクションベクトル列\mathbf{a}_tの一つ一つのベクトルをa_t^k~(k\in \{1,2,3,4,5,6\})とすると、トークン列の並びは、

i_1^1,i_1^2,\dots,i_1^{576},a_1^1,a_1^2,\dots,a_1^6,i_2^1,\dots,i_2^{576},a_2^1,\dots,a_2^6,\dots

のようになります。今回用いているPositional Encodingは同じ時間ステップ内での位置を指定するSpatial Positional Encodingと、時間ステップについて位置を表現したTemporal Positional Encodingを分解したものになっています。時間ステップが同じトークンは画像トークンが576でアクショントークンが6あるため582個あります。したがってSpatial Positional Encodingは同じ時間ステップ内の582個の位置を表現したものになっています。一方で今回のモデルでは最大25フレーム分までを一度に扱えるようにしているためTemporal Positional Encodingは25個の位置を表現したものになっています。

import torch
import torch.nn as nn


class LearnableFactorizedSpatioTemporalPositionalEmbedding(nn.Module):
    def __init__(self, num_spatio_embeddings: int, num_temporal_embeddings: int, embedding_dim: int):
        super().__init__()
        self.spatio_embeddings = nn.Embedding(num_spatio_embeddings, embedding_dim)
        self.temporal_embeddings = nn.Embedding(num_temporal_embeddings, embedding_dim)
        self.num_spatio_embeddings = num_spatio_embeddings
        self.num_temporal_embeddings = num_temporal_embeddings

    def forward(self, attention_mask: torch.LongTensor, past_key_values_length):
        seq_length = attention_mask.size(1)
        batch_size = attention_mask.size(0)

        if past_key_values_length == 0:
            # [0, 1, 2, ..., num_spatio_embeddings-1, 0, 1, 2, ..., num_spatio_embeddings-1, ...]という形のテンソルを作成
            spatio_indices = torch.arange(
                self.num_spatio_embeddings,
                device=attention_mask.device
            ).repeat(self.num_temporal_embeddings).unsqueeze(0).repeat((batch_size, 1))

            # [0, 0, 0, ..., 1, 1, 1, ..., 2, 2, 2, ...]という形のテンソルを作成
            temporal_indices = torch.arange(
                self.num_temporal_embeddings,
                device=attention_mask.device
            ).repeat_interleave(self.num_spatio_embeddings).unsqueeze(0).repeat((batch_size, 1))

            spatio_indices = spatio_indices[:, :seq_length]
            temporal_indices = temporal_indices[:, :seq_length]
            
        else:
            temporal_index = past_key_values_length // self.num_spatio_embeddings
            spatio_index = past_key_values_length % self.num_spatio_embeddings
            spatio_indices = torch.tensor([[spatio_index]], device=attention_mask.device).repeat((batch_size, 1))
            temporal_indices = torch.tensor([[temporal_index]], device=attention_mask.device).repeat((batch_size, 1))

        return self.spatio_embeddings(spatio_indices) + self.temporal_embeddings(temporal_indices)

vLLMモデルへの変換の方針

今回扱うモデルでは通常のLLMにおいて言語のトークン列として扱われる部分が画像を表現している離散トークンということになり、マルチモーダル入力にあたるものがアクションベクトル列にあたります。

また、アクションベクトル列はフレーム数分だけ存在するため、複数のマルチモーダル入力があることになります。これは、vLLMのドキュメントにある、複数の画像を入力とするVLMの推論のケースに似ています。

上記のように今回のケースではvLLMの利用事例としてやや一般的ではない部分がいくつかあるためその点に注意しながらvLLMで実装する方針を考えていきます。

なお、vLLMのデータ処理の流れは Input Processing Pipelineで紹介されているため、適宜参照すると理解が進みやすいです。

入力トークン列の用意

vLLMのモデルには[2, 45641, 35, 653, 18, 5, 1383, 9, 5, 2274, 116, 31652, 35]のような離散トークン列が入力されることになるため、今回のケースでは画像列は事前に離散トークンの列に変換されている必要があります。

また、マルチモーダル入力が入る部分については事前にプレースホルダーで埋めておく必要があります。例えば、上記の離散トークン列の先頭に画像のEmbeddingを32個挿入することがわかっている場合、プレースホルダー用の特殊トークン(例えば-1など)を32個分入れた入力を入れる必要があります。従って上の例で挙げている離散トークン列は最終的には[-1, -1, ..., -1, 2, 45641, 35, 653, 18, 5, 1383, 9, 5, 2274, 116, 31652, 35]のような形式で入力されることになります。

LLaVAなどの有名なモデルはプレースホルダーをどうやって準備している?

プレースホルダーを準備するために推奨されている方法は、input_processorを実装するやり方です。例えばLLaVA-1.5のモデルではinput_processor_for_llava()という関数を実装しその中で、入力されたプロンプトの中の特殊なトークン部分を展開してLLaVA-1.5の画像入力に使われるトークン数分のプレースホルダーIDを挿入するようになっています。

ただしこのやり方ではデータの流れが秘匿されて分かりづらくなるため、今回紹介する実装では明示的に入力するトークン列にプレースホルダー分を確保しておき、input_processorは実装しないようにしています。

複数のマルチモーダル入力の入れ方

今回のモデルでは複数のマルチモーダル入力が入ってくることになるため、入力トークン列においては、マルチモーダル入力が入ってくる部分それぞれについて特殊なトークンでプレースホルダーを作る必要があります。

また、Hugging Face実装の際は576トークン生成するごとに6トークン分アクションベクトル列を付け加える操作をforwardメソッド内で行い、generate()メソッドを一回呼び出せば画像フレーム列が生成できるようになっていましたが、vLLM実装ではコンテキストとなる画像フレーム列の離散トークン列とアクションベクトル列を与えると次の画像フレームの576トークン分のみを生成するようにし、generate()メソッドを生成したい画像フレームの枚数分呼び出すようにします。

アクションベクトル列の入力方法

vLLMでは、離散トークン列またはテキスト以外の入力は全てマルチモーダルデータとして扱われ、以下のように離散トークン列/テキストとは分けて入力されます。

from PIL import Image
from vllm import LLM, SamplingParams


inputs = {
    "prompt": "Describe the image. Picture 1: <img></img>\n",
    "multi_modal_data": {
        "image":  Image.open("/path/to/image").convert("RGB")
    }
}

llm = LLM(
    model="Qwen/Qwen-VL",
    trust_remote_code=True,
    max_model_len=1024,
    max_num_seqs=2
)
sampling_params = SamplingParams(temperature=0.2, max_tokens=64, stop_token_ids=None)
outputs = llm.generate(inputs, sampling_params=sampling_params)
generated_text = outputs[0].outputs[0].text
print(generated_text)

上の例では画像がマルチモーダルデータとして与えられており、imageというキーが画像というモダリティのデータであることを示しているため、vllm.multimodal.image.ImagePluginによってPIL.Imageの状態で入力された画像データはPyTorchのTensorへと変換されて最終的にはモデルに入力されます。

これと同じようにして、独自のモダリティのデータを入力したい場合は以下のように

inputs = {
    "prompt": "...",
    "multi_modal_data": {
        "actions": torch.tensor([[0.0, 0.0, 0.0],
                                [0.0, 1.8, 0.5]])
    }
}

<modality>: <data>の形式でデータを用意します。ここで、データを処理するPluginは2024年12月現在はimagevideoしか用意されていないため、独自に実装する必要があります。 Pluginの実装については後述します。

出力の取り扱い

一般的なLLMの推論の場合、vLLMではテキストを入力すると自動でトークナイザを適用して離散トークン列に変換し、それをTransformerに入力してトークン列の続きを生成し、最後に生成された部分をデトークナイズして文章の形にして出力します。

今回のモデルは入力は事前に画像トークナイザで離散化したトークン列になり、出力は別途後段で行うため、トークナイザを適用する処理をオフにして行う必要があります。この方法についても後述します。

vLLMモデルの実装

上記で検討した方針に基づき、vLLMを利用して実装したのが以下になります。

vLLMを用いた実装
import sys
from array import array
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, Mapping
from pathlib import Path

import numpy as np
import torch
from torch import nn
from transformers import LlamaConfig, AutoConfig, AutoModel

from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
from vllm.inputs import InputContext, INPUT_REGISTRY
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
    get_compressed_tensors_cache_scale)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs, MultiModalPlugin
from vllm.sequence import IntermediateTensors, SequenceData
from vllm.utils import is_hip

from vllm.model_executor.models.interfaces import SupportsMultiModal, SupportsPP
from vllm.model_executor.models.utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers)

# HF実装の方にパスを通す
sys.path.append(str(Path(__file__).parent.parent))
# HF実装の方をimportしておく
from models.llama_action import LlamaActionConfig, LlamaActionForCausalLM

# HF実装のモデルをHFに登録しておく
AutoConfig.register("llama_action", LlamaActionConfig)
AutoModel.register(LlamaActionConfig, LlamaActionForCausalLM)


# 独自実装の"action"モダリティ用のプラグイン。実態としては入力されたデータを素通しする。
class ActionsPlugin(MultiModalPlugin):
    def get_data_key(self) -> str:
        return "actions"

    def _default_input_mapper(self, ctx: InputContext, data: object | List[object], **mm_processor_kwargs) -> MultiModalInputs:
        return MultiModalInputs({"actions": data})

    def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
        return 4096


MULTIMODAL_REGISTRY.register_plugin(ActionsPlugin())


# 推論用のPositional Encoding実装
class LearnableFactorizedSpatioTemporalPositionalEmbedding(nn.Module):
    def __init__(self, num_spatio_embeddings: int, num_temporal_embeddings: int, embedding_dim: int):
        super().__init__()
        self.spatio_embeddings = nn.Embedding(num_spatio_embeddings, embedding_dim)
        self.temporal_embeddings = nn.Embedding(num_temporal_embeddings, embedding_dim)
        self.num_spatio_embeddings = num_spatio_embeddings
        self.num_temporal_embeddings = num_temporal_embeddings

    def forward(self, positions: torch.Tensor):
        spatio_indices = positions % self.num_spatio_embeddings
        temporal_indices = positions // self.num_spatio_embeddings
        return self.spatio_embeddings(spatio_indices) + self.temporal_embeddings(temporal_indices)


# 事前の動作チェックに使われる関数
def get_max_action_tokens(ctx: InputContext):
    hf_config = ctx.get_hf_config(LlamaActionConfig)
    num_action_tokens = hf_config.num_action_embeddings
    num_frames = hf_config.num_temporal_embeddings - 1
    return num_action_tokens * num_frames


# 事前の動作チェックに使われる関数
def create_dummy_data(ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int]):
    hf_config = ctx.get_hf_config(LlamaActionConfig)

    num_frames = hf_config.num_temporal_embeddings
    vocab_size = hf_config.vocab_size
    num_action_tokens = hf_config.num_action_embeddings
    num_image_tokens = hf_config.num_image_patches
    dummy_seq = []
    np.random.seed(0)
    for i in range(num_frames - 1):
        dummy_image_tokens = np.random.randint(0, vocab_size, num_image_tokens).tolist()
        dummy_seq.extend(dummy_image_tokens)
        dummy_action_tokens = [-3] * num_action_tokens
        dummy_seq.extend(dummy_action_tokens)
    seq_data = SequenceData(array("l", dummy_seq))

    action = torch.tensor([
        [0.0, 0.0, 0.0],
        [0.0, 2.0, 0.5],
        [0.0, 4.0, 1.0],
        [0.0, 6.0, 1.5],
        [0.0, 8.0, 2.0],
        [0.0, 10.0, 2.5],
        [0.0, 12.0, 3.0],
        [0.0, 14.0, 3.5],
        [0.0, 16.0, 4.0],
    ])
    actions = []
    for _ in range(num_frames - 1):
        actions.append(action[:num_action_tokens])
    actions = torch.cat(actions, dim=0)
    mm_data = {"actions": actions}
    return seq_data, mm_data


@MULTIMODAL_REGISTRY.register_input_mapper(data_type_key="actions")
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens("actions", get_max_action_tokens)
@INPUT_REGISTRY.register_dummy_data(create_dummy_data)
class VLLMLlamaActionForCausalLM(nn.Module, SupportsMultiModal):
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"]
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens",
        "lm_head"
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings"
    }
    embedding_padding_modules = ["lm_head"]

    # BitandBytes specific attributes
    default_bitsandbytes_target_modules = [
        ".gate_proj.",
        ".down_proj.",
        ".up_proj.",
        ".q_proj.",
        ".k_proj.",
        ".v_proj.",
        ".o_proj.",
    ]
    # in TP, these weights are partitioned along the column dimension (dim=-1)
    column_parallel_weights_modules = [".down_proj.", ".o_proj."]
    bitsandbytes_stacked_params_mapping = {
        # shard_name, weight_name, index
        "q_proj": ("qkv_proj", 0),
        "k_proj": ("qkv_proj", 1),
        "v_proj": ("qkv_proj", 2),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }

    # Mistral/Llama models can also be loaded with --load-format mistral
    # from consolidated.safetensors checkpoints
    mistral_mapping = {
        "layers": "model.layers",
        "attention": "self_attn",
        "wq": "q_proj",
        "wk": "k_proj",
        "wv": "v_proj",
        "wo": "o_proj",
        "attention_norm": "input_layernorm",
        "feed_forward": "mlp",
        "w1": "gate_proj",
        "w2": "down_proj",
        "w3": "up_proj",
        "ffn_norm": "post_attention_layernorm",
        "tok_embeddings": "model.embed_tokens",
        "output": "lm_head",
        "norm": "model.norm"
    }

    def __init__(
        self,
        config: LlamaActionConfig,
        multimodal_config: MultiModalConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.multimodal_config = multimodal_config

        self.num_spatio_embeddings = config.num_spatio_embeddings
        self.num_temporal_embeddings = config.num_temporal_embeddings
        self.num_image_patches = config.num_image_patches
        self.num_action_embeddings = config.num_action_embeddings

        self.pos_embedding_spatio_temporal = LearnableFactorizedSpatioTemporalPositionalEmbedding(
            num_spatio_embeddings=self.num_spatio_embeddings,
            num_temporal_embeddings=self.num_temporal_embeddings,
            embedding_dim=config.hidden_size,
        )

        self.action_projection = nn.Linear(config.action_dim, config.hidden_size)

        self.model = LlamaModel(config,
                                cache_config,
                                quant_config,
                                lora_config=None,
                                prefix="model")
        if get_pp_group().is_last_rank:
            self.unpadded_vocab_size = config.vocab_size
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                padding_size=DEFAULT_VOCAB_PADDING_SIZE,
                quant_config=quant_config,
            )
            if config.tie_word_embeddings:
                self.lm_head = self.lm_head.tie_weights(
                    self.model.embed_tokens)

            logit_scale = getattr(config, "logit_scale", 1.0)
            self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                    config.vocab_size,
                                                    logit_scale)
            self.sampler = Sampler()
        else:
            self.lm_head = PPMissingLayer()
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Forward pass for the model.
        input_ids already accounts for the positions of the to-be-inserted action embeddings.
    
        action tokens are represetnted by -3.
        example: [1287, 3342, ..., 6571, -3, ..., -3]
        """
        if intermediate_tensors is not None:
            input_ids = None
            inputs_embeds = None
        else:
            action_token_indices = (input_ids == -3).nonzero(as_tuple=True)[0]
            image_token_indices = (input_ids > 0).nonzero(as_tuple=True)[0]

            image_tokens = input_ids[image_token_indices]
            image_token_embeddings = self.model.get_input_embeddings(image_tokens)

            inputs_embeds = torch.zeros(
                (input_ids.size(0), image_token_embeddings.size(1)), 
                device=input_ids.device, dtype=image_token_embeddings.dtype
            )
            inputs_embeds[image_token_indices] = image_token_embeddings

            actions = kwargs.pop("actions", None)
            if actions is not None:
                assert len(action_token_indices) == actions.size(0) * actions.size(1), "actions must have the same length as the number of action tokens"
                actions = actions.to(dtype=self.action_projection.weight.dtype)
                action_embeddings = self.action_projection(actions)
                inputs_embeds[action_token_indices] = action_embeddings.view(-1, action_embeddings.size(-1))
            input_ids = None
            inputs_embeds += self.pos_embedding_spatio_temporal(positions)
        hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata, intermediate_tensors, inputs_embeds=inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(self, logits: torch.Tensor,
               sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
        self.model.load_kv_cache_scales(quantization_param_path)

    # This function is used to remap the mistral format as
    # used by Mistral and Llama <=2
    def maybe_remap_mistral(
        self,
        name: str,
        loaded_weight: torch.Tensor,
    ) -> Tuple[str, torch.Tensor]:
        def permute(w: torch.Tensor, n_heads: int):
            attn_in = self.config.head_dim * n_heads
            attn_out = self.config.hidden_size

            return w.view(n_heads, attn_in // n_heads // 2, 2,
                          attn_out).transpose(1, 2).reshape(attn_in, attn_out)

        mapping = self.mistral_mapping
        modules = name.split(".")

        # rotary embeds should be sliced
        if "wk" in modules:
            loaded_weight = permute(loaded_weight,
                                    self.config.num_key_value_heads)
        elif "wq" in modules:
            loaded_weight = permute(loaded_weight,
                                    self.config.num_attention_heads)

        for item in modules:
            if item in mapping and mapping[item] not in name:
                name = name.replace(item, mapping[item])

        return name, loaded_weight

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."]
                           if self.config.tie_word_embeddings else None),
        )
        loader.load_weights(
            self.maybe_remap_mistral(name, loaded_weight)
            for name, loaded_weight in weights)

以下で重要なポイントについて解説します。

HF実装を登録する

vLLMは重みを読み込む際にHugging Faceの実装を参照するため、Hugging Face版の独自モデルの実装をHugging Face Transformersに登録しておく必要があります。

以下の箇所でHugging Face Transformersに独自実装のモデルを登録しています。

import sys

# HF実装の方にパスを通す
sys.path.append(str(Path(__file__).parent.parent))
# HF実装の方をimportしておく
from models.llama_action import LlamaActionConfig, LlamaActionForCausalLM

# HF実装のモデルをHFに登録しておく
AutoConfig.register("llama_action", LlamaActionConfig)
AutoModel.register(LlamaActionConfig, LlamaActionForCausalLM)

独自のモダリティ用のPluginを用意する

先述のように、vLLM側で実装されているimagevideo以外のモダリティを扱うためには、独自実装のPluginを用意する必要があります。

以下の箇所で独自のモダリティである"actions"向けのPluginを実装しています。今回はmulti_modal_dataを用意する時点でモデルにそのまま入力できる形式にしてあるため、特に処理は施さずに返すような実装になっています。

from vllm.inputs import InputContext
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs, MultiModalPlugin

# 独自実装の"action"モダリティ用のプラグイン。実態としては入力されたデータを素通しする。
class ActionsPlugin(MultiModalPlugin):
    def get_data_key(self) -> str:
        return "actions"

    def _default_input_mapper(self, ctx: InputContext, data: object | List[object], **mm_processor_kwargs) -> MultiModalInputs:
        return MultiModalInputs({"actions": data})

    def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
        return 4096


MULTIMODAL_REGISTRY.register_plugin(ActionsPlugin())

Positional Encodingの実装変更

Hugging Face版の実装では学習時にも推論時にも対応できるようにするためPositional Encodingはどちらの場合でも対応できるような分岐が実装されていましたが、vLLM版は推論時のみを考慮すれば良いため分岐を廃したものになっています。

また、Hugging Face Transformersでは位置を表すIDが明示的に用いられていないのですが、vLLMでは各トークンの位置を示すpositionsというテンソルが自動的に生成されているため、これをそのままPositional Embeddingを作成する際に利用しています。

import torch
import torch.nn as nn


# 推論用のPositional Encoding実装
class LearnableFactorizedSpatioTemporalPositionalEmbedding(nn.Module):
    def __init__(self, num_spatio_embeddings: int, num_temporal_embeddings: int, embedding_dim: int):
        super().__init__()
        self.spatio_embeddings = nn.Embedding(num_spatio_embeddings, embedding_dim)
        self.temporal_embeddings = nn.Embedding(num_temporal_embeddings, embedding_dim)
        self.num_spatio_embeddings = num_spatio_embeddings
        self.num_temporal_embeddings = num_temporal_embeddings

    def forward(self, positions: torch.Tensor):
        spatio_indices = positions % self.num_spatio_embeddings
        temporal_indices = positions // self.num_spatio_embeddings
        return self.spatio_embeddings(spatio_indices) + self.temporal_embeddings(temporal_indices)

ダミー入力用の関数の実装

vLLMではモデルの初期化を行う際に、自動でダミーデータを生成してそれをモデルに通し正常に動作が行われるかをチェックする機構が実装されています。LLMモデルであればこのダミーデータは離散トークン列のみで良いのですが、マルチモーダルモデルの場合は入力されるマルチモーダルデータを模擬したデータも用意しておく必要があります。

動作チェックに用いられる関数で実装が必要なものはマルチモーダルデータの使用トークン数の最大値を返す関数と、ダミーデータを生成する関数の二つになります。

今回は、一つのアクションベクトル列が6トークン分に相当し、それが画像のフレームごとに存在するため、6\times (25 - 1) = 144トークン分になります。画像のフレーム数を-1しているのは最後の画像フレームに対応したアクションベクトル列は入力しないためです。

以下の実装では数字はハードコーディングせずconfigから読み出すようにしてマルチモーダルデータの使用トークン数の最大値を返す関数を実装しています。

from vllm.inputs import InputContext

from models.llama_action import LlamaActionConfig

# 事前の動作チェックに使われる関数
def get_max_action_tokens(ctx: InputContext):
    hf_config = ctx.get_hf_config(LlamaActionConfig)
    num_action_tokens = hf_config.num_action_embeddings
    num_frames = hf_config.num_temporal_embeddings - 1
    return num_action_tokens * num_frames

また、ダミーデータを返す関数では、ダミーデータとして離散トークン列とアクションベクトル列それぞれについてダミーデータを生成する実装を行なっています。なお、先述した通りvLLMでは離散トークン列にマルチモーダルデータで置換される部分についてのプレースホルダーを用意しておく必要があるため、今回は全て-3が入っている部分をプレースホルダーとして扱うように実装しました。

from typing import Mapping

import numpy as np
import torch
from vllm.inputs import InputContext
from vllm.sequence import SequenceData

from models.llama_action import LlamaActionConfig


def create_dummy_data(ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int]):
    hf_config = ctx.get_hf_config(LlamaActionConfig)

    num_frames = hf_config.num_temporal_embeddings
    vocab_size = hf_config.vocab_size
    num_action_tokens = hf_config.num_action_embeddings
    num_image_tokens = hf_config.num_image_patches
    dummy_seq = []
    np.random.seed(0)
    # 離散トークン列についてダミー入力を生成する
    for i in range(num_frames - 1):
        dummy_image_tokens = np.random.randint(0, vocab_size, num_image_tokens).tolist()  # 1枚の画像フレーム分のトークンをランダムに生成する
        dummy_seq.extend(dummy_image_tokens)
        dummy_action_tokens = [-3] * num_action_tokens  # アクションが入る部分のプレースホルダーを用意する
        dummy_seq.extend(dummy_action_tokens)
    seq_data = SequenceData(array("l", dummy_seq))

    # アクションベクトル列のテンプレート
    action = torch.tensor([
        [0.0, 0.0, 0.0],
        [0.0, 2.0, 0.5],
        [0.0, 4.0, 1.0],
        [0.0, 6.0, 1.5],
        [0.0, 8.0, 2.0],
        [0.0, 10.0, 2.5],
        [0.0, 12.0, 3.0],
        [0.0, 14.0, 3.5],
        [0.0, 16.0, 4.0],
    ])
    # アクションベクトル列についてダミー入力を生成する
    actions = []
    for _ in range(num_frames - 1):
        actions.append(action[:num_action_tokens])
    actions = torch.cat(actions, dim=0)
    mm_data = {"actions": actions}
    return seq_data, mm_data

自作のPluginやダミーデータ関連の関数の登録

vLLM版のLlamaActionForCausalLMの実装であるVLLMLlamaActionForCausalLMに入力されるデータが上で実装したActionsPluginを使えるようにしたり、ダミーデータ生成関連の関数をダミーデータ生成時に使うようにするために以下のようにデコレータを使用して関数やPluginの登録を行います。

from vllm.inputs import INPUT_REGISTRY
from vllm.multimodal import MULTIMODAL_REGISTRY


@MULTIMODAL_REGISTRY.register_input_mapper(data_type_key="actions")
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens("actions", get_max_action_tokens)
@INPUT_REGISTRY.register_dummy_data(create_dummy_data)
class VLLMLlamaActionForCausalLM(nn.Module, SupportsMultiModal):
    ... 

Multi-modalデータを使ってプレースホルダーを置き換える

入力のinput_idsの中で-3という特殊なトークンが割り当てられているところはプレースホルダーになっているため、Transformerに入力する前にアクションベクトルで置き換えてあげる必要があります。今回のvLLM実装ではforward()メソッドの中で実施しています。

以下のように-3が入っている箇所についてインデックスを取得し、その部分についてはアクションベクトル列をaction_projectionを通して出てきたベクトル列で置き換えます。また、それ以外の部分についてはinput_idsEmbedding層に通して得られたベクトル列で置き換えます。この実装は以下の箇所で実装されています。

def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    kv_caches: List[torch.Tensor],
    attn_metadata: AttentionMetadata,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    **kwargs: object,
) -> Union[torch.Tensor, IntermediateTensors]:
    """Forward pass for the model.
    input_ids already accounts for the positions of the to-be-inserted action embeddings.

    action tokens are represetnted by -3.
    example: [1287, 3342, ..., 6571, -3, ..., -3]
    """
    if intermediate_tensors is not None:
        # こちらは2回目以降にforward()が呼び出された時 = cacheが使われている時
        input_ids = None
        inputs_embeds = None
    else:
        # cacheが使われていない時はこちらの分岐が動く
        # プレースホルダーが入れられているインデックスを入手
        action_token_indices = (input_ids == -3).nonzero(as_tuple=True)[0]
        # プレースホルダー以外のトークンがあるインデックスを入手
        image_token_indices = (input_ids > 0).nonzero(as_tuple=True)[0]

        image_tokens = input_ids[image_token_indices]
        # プレースホルダー以外のトークンについてはEmbeddingに変える
        image_token_embeddings = self.model.get_input_embeddings(image_tokens)

        # Transformerの入力となるデータの配列を作成する
        inputs_embeds = torch.zeros(
            (input_ids.size(0), image_token_embeddings.size(1)), 
            device=input_ids.device, dtype=image_token_embeddings.dtype
        )
        # プレースホルダー以外の部分については画像トークンのEmbeddingで置き換える
        inputs_embeds[image_token_indices] = image_token_embeddings

        actions = kwargs.pop("actions", None)
        if actions is not None:
            assert len(action_token_indices) == actions.size(0) * actions.size(1), "actions must have the same length as the number of action tokens"
            actions = actions.to(dtype=self.action_projection.weight.dtype)
            action_embeddings = self.action_projection(actions)
            # プレースホルダー部分についてはアクションベクトルのEmbeddingで置き換える
            inputs_embeds[action_token_indices] = action_embeddings.view(-1, action_embeddings.size(-1))
        input_ids = None

学習済み重みの名前のマッピングの設定

vLLMでは線形層やAttentionの実装においてもvLLM独自のMergedColumnPrallelLinearQKVParallelLinearというクラスで実装されていたりすることがあり、HF版の実装とvLLM版の実装で重みにつけられた名前が異なる場合があります。この時、HF版で学習された重みをvLLM版の実装で読み込むために名前のマッピングを定義し、名前を読み替えてロードを行うようになっています。

この名前のマッピングは今回はベースとなっているvllm.model_executor.models.llama.LlamaForCausalLMのものをそのまま使っています。LlamaForCausalLMでは次のようにwqq_projと読み替えられたり、outputlm_headと読み替えられたりしています。

class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"]
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens",
        "lm_head"
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings"
    }
    embedding_padding_modules = ["lm_head"]

    # BitandBytes specific attributes
    bitsandbytes_stacked_params_mapping = {
        # shard_name, weight_name, index
        "q_proj": ("qkv_proj", 0),
        "k_proj": ("qkv_proj", 1),
        "v_proj": ("qkv_proj", 2),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }

    # Mistral/Llama models can also be loaded with --load-format mistral
    # from consolidated.safetensors checkpoints
    mistral_mapping = {
        "layers": "model.layers",
        "attention": "self_attn",
        "wq": "q_proj",
        "wk": "k_proj",
        "wv": "v_proj",
        "wo": "o_proj",
        "attention_norm": "input_layernorm",
        "feed_forward": "mlp",
        "w1": "gate_proj",
        "w2": "down_proj",
        "w3": "up_proj",
        "ffn_norm": "post_attention_layernorm",
        "tok_embeddings": "model.embed_tokens",
        "output": "lm_head",
        "norm": "model.norm"
    }

    (中略)
    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."]
                           if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(
            self.maybe_remap_mistral(name, loaded_weight)
            for name, loaded_weight in weights)

    # This function is used to remap the mistral format as
    # used by Mistral and Llama <=2
    def maybe_remap_mistral(
        self,
        name: str,
        loaded_weight: torch.Tensor,
    ) -> Tuple[str, torch.Tensor]:

        def permute(w: torch.Tensor, n_heads: int):
            attn_in = self.config.head_dim * n_heads
            attn_out = self.config.hidden_size

            return w.view(n_heads, attn_in // n_heads // 2, 2,
                          attn_out).transpose(1, 2).reshape(attn_in, attn_out)

        mapping = self.mistral_mapping
        modules = name.split(".")

        # rotary embeds should be sliced
        if "wk" in modules:
            loaded_weight = permute(loaded_weight,
                                    self.config.num_key_value_heads)
        elif "wq" in modules:
            loaded_weight = permute(loaded_weight,
                                    self.config.num_attention_heads)

        for item in modules:
            if item in mapping and mapping[item] not in name:
                name = name.replace(item, mapping[item])

        return name, loaded_weight

vLLMモデルでの推論

さて、ここまでの流れで推論用のモデルであるVLLMLlamaActionForCausalLMを実装してきました。これを使って推論を行う方法もHugging Faceのモデルの場合とやや異なるので解説します。

モデルの初期化

まずモデルの読み込みは以下のようにして行います。

import torch
from vllm import LLM, ModelRegistry

# 上で実装したモデルがこのmodeling_llama_action.pyに実装されているとする
from modeling_llama_action import VLLMLlamaActionForCausalLM


device = torch.device("cuda:0")
path = "/path/to/pretrained_weight"
model = LLM(
    model=path,
    skip_tokenizer_init=True,
    enforce_eager=True,
    max_num_seqs=5,
    device=device,
)

ここでポイントとなってくるのはskip_tokenizer_initenforce_eagerの引数です。

まず、skip_tokenizer_initについてです。普通のLLMはテキストをトークンに変えるTokenizerがセットになっていることが一般的なので、vLLMでは同梱のTokenizerを使ってテキストを自動でトークンに変える機能が備わっており、skip_tokenizer_init=False(デフォルトの設定)とすると、同梱のTokenizerをconfigファイルから読み込んでくれます。一方で、今回のモデルではテキストではなく画像を離散的なトークン列に変えるImage Tokenizerを用いているため、skip_tokenizer_init=TrueとしてTokenizerを読み込む機能をオフにしておかないとエラーになります。

また、enforce_eagerについてですが、これはvLLMの機能をオフライン推論に限定するために必要なオプションになります。vLLMではLLMのサービングのために、ストリーミング入力を受け取るCUDA Graphを用いることができるようになっていますが、ストリーミング入力を用いている場合一部のPyTorchの演算は実行できないようになっています。今回の例では、プレースホルダー用のtokenが使われている部分のインデックスを取得する演算がストリーミング入力に対応していないため、enforce_eager=Trueにしていないと最初の動作チェックのところでエラーになります。

action_token_indices = (input_ids == -3).nonzero(as_tuple=True)[0]

モデルでの推論

今回の実装では画像フレーム1枚分のトークンを生成し終わったら、生成されたトークンを入力したプロンプトの後ろに結合し、さらにアクションベクトル用のプレースホルダーをつけて再度モデルの入力にするような方式で画像フレーム列(の離散表現)を生成していきます。コードで示すと以下のようになります。

import torch
from vllm import SamplingParams


n_context_frames = 3
n_frames = 25
n_frames_to_generate = n_frames - n_context_frames
# 事前にImageTokenizerで画像をトークナイズした上でアクションベクトル分のプレースホルダーを挿入しておく
prompt_tokens = [23126, 12318, ..., 8997, -3, -3, -3, -3, -3, -3]
actions = torch.tensor([[0.2, 2.4, 0.5],
                        [0.4, 5.2, 1.0],
                        ...
                        [0.0, 20.2, 2.5]])  # (n_frames * 6, 3)
inputs = [
    {
        "prompt_token_ids": prompt_tokens,
        "multi_modal_data": {"actions": actions[:6 * n_context_frames]}
    }
]
sampling_params = SamplingParams(temperature=1.0, detokenize=False, max_tokens=576, stop_token_ids=None)
all_outputs = []
for step in range(n_frames_to_generate):
    outputs = model.generate(inputs, sampling_params=sampling_params)[0].outputs[0].token_ids
    all_outputs.append(outputs)
    # 前のステップで生成された内容とプレースホルダーをつけて新たなプロンプトを作成する
    prompt = torch.cat([prompt, torch.tensor(outputs), torch.ones(6, dtype=torch.long) * -3])
    inputs = [
        {
            "prompt_token_ids": prompt.tolist(),
            "multi_modal_data": {"actions": actions[:6 * (n_context_frames + step + 1)]}
        }
    ]

今回のケースではトークン列からテキストに直す必要がないためdetokenize=Falseに設定をしています。

速度計測

これでvLLM版の実装は終わりです。最後に速度計測をしてみましょう。PyTorchで計算の実行時間を計測する場合には計算が実行されている部分のコードブロックをtorch.cuda.synchronizeで囲ってあげる必要があるため、その処理だけ仕込んでHF版の実装とvLLM版の実装を比べてみました。


HF版実装での生成


vLLM版実装での生成

速度計測の結果は以下の表のようになりました。

vLLM版実装 HF版実装
2.2秒分の生成時間 93.53 s 242.68 s
1フレームあたりの時間 4.25 s 11.03 s

HF版実装では1フレームあたりの生成に11秒かかっているところをvLLM版の実装では4.25秒で生成できていることになるため、およそ2.6倍高速に生成ができていることになります。

また、以下の二つのgifがそれぞれHF版実装とvLLM版実装を用いて生成された動画ですが、ほとんど同じであることが見て取れるのではないでしょうか?


HF版実装で生成された動画


vLLM版実装で生成された動画

終わりに

この記事では、vLLMを用いて独自実装のマルチモーダルモデルを推論させる方法を、動画生成モデルを題材に紹介しました。この記事がvLLMを用いて独自モデルの推論を実践したい人の一助になれば幸いです。

最後に宣伝ですが、私の所属するチューリングの生成AIチームは、自動運転への活用を目指してVision & Language Models (VLMs)や動画生成モデルなどさまざまな生成AIの開発を行っています。この記事を読んで興味を持っていただけた方は、ぜひ一緒に働けることを楽しみにしています。
詳細については、X(Twitter)のDMにご連絡いただくか、チューリングの採用情報をご覧ください。

https://tur.ing/jobs

Appendix

HF実装の学習時の挙動

上記のTransformersの実装では、最初に画像のトークンが全てのフレーム分繋がった状態でinput_idsに入力されます。また、アクションベクトル列も全てのフレーム分繋がった状態で入力されます。input_idsはその後Embeddingレイヤにより2048次元のベクトルの系列に変換されたあと576トークンずつの塊に分割されます。

一方、アクションベクトル列は6トークンずつの塊に分割されたあと、それぞれがLinearレイヤによってより2048次元のベクトルの系列に変換されます。

input_idsから作られたベクトルの系列と、アクションベクトル列から作られたベクトルの系列はその後交互に結合されて最終的には一つのベクトル系列へと変えられ、Transformerへの入力になります。


学習時のforward()メソッド内でのデータ処理の流れ

HF実装の推論時の挙動

今回のモデルは、数フレーム程度の短い動画を入力として与えるとその続きを生成するようなモデルになっているため、推論時は最初に数フレーム程度の画像列を入力します。また、アクション指示については未来のフレームのものについても入力します。例えば、最初の3フレーム分を入れてその続きを生成させる場合も、アクション指示については25フレーム分入力するようになっています。

このとき、画像のトークンを576トークン生成するごとに、その画像と対応したアクションベクトル列を挿入するようにして推論が行われます。


推論時のforward()メソッド内でのデータ処理の流れ

Tech Blog - Turing

Discussion