💬

カメラ情報を用いてリアルタイムで状況をテキストに起こしたい (Ros2) part3

2023/09/28に公開

あれから、数回動作確認していて思いました。1フレームしか処理していないじゃないか!と... videoBLIPで学習させたモデルを使っているのにもったいない。ということで、しっかり動画像に反映させていきます。

勉強せねば...

ということで、さっそくvideoBLIPをより深く理解していこう思います。

demoを実行&&考察

それでは、まずはapp.pyを見ていきます。app.pyの16行目にある通り、model: VideoBlipForConditionalGenerationという、Blip2VisionModelを継承したクラスを用いています。Blip2VisionModel のシンプルな拡張版として使っているようです。基本的にBlip2VisionModelと使い方は同じようです。
次にapp.pyの38行目にprocess(processor, video=frames, text=context).to(model.device)という関数があります。この関数はgenerate関数のinputに格納する変数の設定をしています。
ここで重要なのは、videoパラメータのTensorの次元数だと思います。入力するビデオの型は、

a tensor of shape (batch, channel, time, height, width) or (channel, time, height, width)

となっているため、入力する前にエンコードする必要があります。
とりあえず、以上の点を踏まえてコードを書いていきます。

コード実装

まず、今回はさきほどのapp.pyの形式とほぼ同じように組んでみます。video_blip.modelから持ってきているclassや関数は、直接引っ張ってくることにします(そこまで多くない&複雑でないため)
ということで以下をそのまま引っ張ってきます。

コード詳細
import torch
import torch.nn as nn
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    BatchEncoding,
    Blip2Config,
    Blip2ForConditionalGeneration,
    Blip2Processor,
    Blip2QFormerModel,
    Blip2VisionModel,
)
from transformers.modeling_outputs import BaseModelOutputWithPooling


def process(
    processor: Blip2Processor,
    video: torch.Tensor | None = None,
    text: str | list[str] | None = None,
) -> BatchEncoding:
    """Process videos and texts for VideoBLIP.

    :param images: a tensor of shape (batch, channel, time, height, width) or
        (channel, time, height, width)
    """
    if video is not None:
        if video.dim() == 4:
            video = video.unsqueeze(0)
        batch, channel, time, _, _ = video.size()
        video = video.permute(0, 2, 1, 3, 4).flatten(end_dim=1)
    inputs = processor(images=video, text=text, return_tensors="pt")
    if video is not None:
        _, _, height, weight = inputs.pixel_values.size()
        inputs["pixel_values"] = inputs.pixel_values.view(
            batch, time, channel, height, weight
        ).permute(0, 2, 1, 3, 4)
    return inputs


class VideoBlipVisionModel(Blip2VisionModel):
    """A simple, augmented version of Blip2VisionModel to handle videos."""

    def forward(
        self,
        pixel_values: torch.FloatTensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
    ) -> tuple | BaseModelOutputWithPooling:
        """Flatten `pixel_values` along the batch and time dimension, pass it
        through the original vision model, then unflatten it back.

        :param pixel_values: a tensor of shape (batch, channel, time, height, width)

        :returns:
            last_hidden_state: a tensor of shape (batch, time * seq_len, hidden_size)
            pooler_output: a tensor of shape (batch, time, hidden_size)
            hidden_states:
                a tuple of tensors of shape (batch, time * seq_len, hidden_size),
                one for the output of the embeddings + one for each layer
            attentions:
                a tuple of tensors of shape (batch, time, num_heads, seq_len, seq_len),
                one for each layer
        """
        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        batch, _, time, _, _ = pixel_values.size()

        # flatten along the batch and time dimension to create a tensor of shape
        # (batch * time, channel, height, width)
        flat_pixel_values = pixel_values.permute(0, 2, 1, 3, 4).flatten(end_dim=1)

        vision_outputs: BaseModelOutputWithPooling = super().forward(
            pixel_values=flat_pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        # now restore the original dimensions
        # vision_outputs.last_hidden_state is of shape
        # (batch * time, seq_len, hidden_size)
        seq_len = vision_outputs.last_hidden_state.size(1)
        last_hidden_state = vision_outputs.last_hidden_state.view(
            batch, time * seq_len, -1
        )
        # vision_outputs.pooler_output is of shape
        # (batch * time, hidden_size)
        pooler_output = vision_outputs.pooler_output.view(batch, time, -1)
        # hidden_states is a tuple of tensors of shape
        # (batch * time, seq_len, hidden_size)
        hidden_states = (
            tuple(
                hidden.view(batch, time * seq_len, -1)
                for hidden in vision_outputs.hidden_states
            )
            if vision_outputs.hidden_states is not None
            else None
        )
        # attentions is a tuple of tensors of shape
        # (batch * time, num_heads, seq_len, seq_len)
        attentions = (
            tuple(
                hidden.view(batch, time, -1, seq_len, seq_len)
                for hidden in vision_outputs.attentions
            )
            if vision_outputs.attentions is not None
            else None
        )
        if return_dict:
            return BaseModelOutputWithPooling(
                last_hidden_state=last_hidden_state,
                pooler_output=pooler_output,
                hidden_states=hidden_states,
                attentions=attentions,
            )
        return (last_hidden_state, pooler_output, hidden_states, attentions)


class VideoBlipForConditionalGeneration(Blip2ForConditionalGeneration):
    def __init__(self, config: Blip2Config) -> None:
        # HACK: we call the grandparent super().__init__() to bypass
        # Blip2ForConditionalGeneration.__init__() so we can replace
        # self.vision_model
        super(Blip2ForConditionalGeneration, self).__init__(config)

        self.vision_model = VideoBlipVisionModel(config.vision_config)

        self.query_tokens = nn.Parameter(
            torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size)
        )
        self.qformer = Blip2QFormerModel(config.qformer_config)

        self.language_projection = nn.Linear(
            config.qformer_config.hidden_size, config.text_config.hidden_size
        )
        if config.use_decoder_only_language_model:
            language_model = AutoModelForCausalLM.from_config(config.text_config)
        else:
            language_model = AutoModelForSeq2SeqLM.from_config(config.text_config)
        self.language_model = language_model

        # Initialize weights and apply final processing
        self.post_init()

前回、作成したNode classは以下のような感じです。

前回コード
class VideoBlipNode(Node):
    
    def __init__(self):
        super().__init__('video_blip_node')
        # rosparam initialization
        self.declare_parameter('image_topic_name', '/image_raw')
        self.declare_parameter('output_text_topic', '/blip/data')
        self.declare_parameter('model_name', 'kpyu/video-blip-opt-2.7b-ego4d')
        self.declare_parameter('question', '')

        # read params
        self.image_topic = self.get_parameter('image_topic_name').get_parameter_value().string_value
        self.output_topic = self.get_parameter('output_text_topic').get_parameter_value().string_value
        self.model_name = self.get_parameter('model_name').get_parameter_value().string_value
        self.prompt = self.get_parameter('question').get_parameter_value().string_value


        # pub sub
        self.image_subscription = self.create_subscription(Sensor_Image, self.image_topic, self.image_callback, qos.qos_profile_sensor_data)
        self.blip_publisher = self.create_publisher(String, self.output_topic, 10)

        #other params
        self.runnimg = False
        self.processor = None
        self.blip_model = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.bridge = CvBridge()

    def load_model(self):
        '''
        Load BLIP2 model
        '''
        self.get_logger().info('Loading model')

        processor = Blip2Processor.from_pretrained(self.model_name)
        model = Blip2ForConditionalGeneration.from_pretrained(
            self.model_name, torch_dtype=torch.float16
        )
        model.to(self.device)

        self.processor = processor
        self.blip_model = model
        self.get_logger().info('Loading end')

    def sensor_msg_convert_PIL(self, input_image: Sensor_Image):
        '''
        convert
        sensor Image -> PIL Image
        '''
        # self.get_logger().info('convert')
        self.runnimg = True
        try:
            cv_image = self.bridge.imgmsg_to_cv2(input_image, "bgr8")
        except CvBridgeError as e:
            print(e)

        pil_image = cv_image[:, :, ::-1]

        return pil_image
    
    def process_blip(self, image: PIL_Image):
        '''
        process blip and generate text
        '''
        self.get_logger().info('process')
        inputs = self.processor(images=image, text=self.prompt, return_tensors="pt").to(self.device, torch.float16)
        generated_ids = self.blip_model.generate(
                **inputs
            )
        generated_text = self.processor.batch_decode(
            generated_ids, 
            skip_special_tokens=True)[0].strip()
        
        return generated_text
    
    def image_callback(self, msg):
        self.get_logger().info('Subscription image')
        if not self.runnimg:
            pil_image = self.sensor_msg_convert_PIL(msg)
            get_text = self.process_blip(pil_image)
            pub_msg = String()
            pub_msg.data = get_text
            self.blip_publisher.publish(pub_msg)
            self.runnimg = False



def main():
    rclpy.init()
    node = VideoBlipNode()
    node.load_model()
    rclpy.spin(node)
    node.destroy_node()
    rclpy.shutdown

このクラスの上にさきほどのvideo_blip.modelから持ってきたclassをそのまま貼り付けます。
次にsensor_msgをPILに変換したこの関数を変更していきます。

    def sensor_msg_convert_PIL(self, input_image: Sensor_Image):
        '''
        convert
        sensor Image -> PIL Image
        '''
        # self.get_logger().info('convert')
        self.runnimg = True
        try:
            cv_image = self.bridge.imgmsg_to_cv2(input_image, "bgr8")
        except CvBridgeError as e:
            print(e)

        pil_image = cv_image[:, :, ::-1]

        return pil_image

ここで、returnしたpil_imageのtensor sizeを確認します。
すると[H,W,C]のような順番になっていました。ここでは、H:高さ,W:横幅,C:チャンネル数となります。この構造から[C,T,H,W],T:動画のフレーム分という構造に変化せていきます。
まずは画像を格納する時間配列を生成します。

input_images = torch.zeros(camera_fps, 3, camera_height, camera_width)

camera_fps: 格納するフレーム枚数になります。30fpsで30と入力すると動画がちょうど1秒になるといった感じです。
第二引数はchannel数
camera_height: 画像の縦size
camera_width: 画面の横size
といった感じになっています。
torch.zerosで作成すると真っ暗な動画像が生成されるので、そこにキューの要領で1フレームずつ画像を挿入していきます。以下がそのコードになります。

# スライス操作をして次元を移動,配列の先頭から最後の一つ前までを配列の1つ次から最後までに置き換える
input_images[0:-1] = input_images[1:].clone()
# 末にpermuteで次元を入れ替えた画像を代入する
input_images[-1] = torch.from_numpy(cv_image[:, :, ::-1].copy()).permute(2,0,1)
return input_image

実際作成したもの

これからたくさん評価していき、もう少し調整したいと考えています。

前回の記事

岐阜大学アレックス研究室

Discussion