カメラ情報を用いてリアルタイムで状況をテキストに起こしたい (Ros2) part3
あれから、数回動作確認していて思いました。1フレームしか処理していないじゃないか!と... videoBLIPで学習させたモデルを使っているのにもったいない。ということで、しっかり動画像に反映させていきます。
勉強せねば...
ということで、さっそくvideoBLIPをより深く理解していこう思います。
demoを実行&&考察
それでは、まずはapp.pyを見ていきます。app.pyの16行目にある通り、model: VideoBlipForConditionalGeneration
という、Blip2VisionModel
を継承したクラスを用いています。Blip2VisionModel のシンプルな拡張版として使っているようです。基本的にBlip2VisionModel
と使い方は同じようです。
次にapp.pyの38行目にprocess(processor, video=frames, text=context).to(model.device)
という関数があります。この関数はgenerate関数のinputに格納する変数の設定をしています。
ここで重要なのは、video
パラメータのTensorの次元数だと思います。入力するビデオの型は、
a tensor of shape (batch, channel, time, height, width) or (channel, time, height, width)
となっているため、入力する前にエンコードする必要があります。
とりあえず、以上の点を踏まえてコードを書いていきます。
コード実装
まず、今回はさきほどのapp.pyの形式とほぼ同じように組んでみます。video_blip.modelから持ってきているclassや関数は、直接引っ張ってくることにします(そこまで多くない&複雑でないため)
ということで以下をそのまま引っ張ってきます。
コード詳細
import torch
import torch.nn as nn
from transformers import (
AutoModelForCausalLM,
AutoModelForSeq2SeqLM,
BatchEncoding,
Blip2Config,
Blip2ForConditionalGeneration,
Blip2Processor,
Blip2QFormerModel,
Blip2VisionModel,
)
from transformers.modeling_outputs import BaseModelOutputWithPooling
def process(
processor: Blip2Processor,
video: torch.Tensor | None = None,
text: str | list[str] | None = None,
) -> BatchEncoding:
"""Process videos and texts for VideoBLIP.
:param images: a tensor of shape (batch, channel, time, height, width) or
(channel, time, height, width)
"""
if video is not None:
if video.dim() == 4:
video = video.unsqueeze(0)
batch, channel, time, _, _ = video.size()
video = video.permute(0, 2, 1, 3, 4).flatten(end_dim=1)
inputs = processor(images=video, text=text, return_tensors="pt")
if video is not None:
_, _, height, weight = inputs.pixel_values.size()
inputs["pixel_values"] = inputs.pixel_values.view(
batch, time, channel, height, weight
).permute(0, 2, 1, 3, 4)
return inputs
class VideoBlipVisionModel(Blip2VisionModel):
"""A simple, augmented version of Blip2VisionModel to handle videos."""
def forward(
self,
pixel_values: torch.FloatTensor | None = None,
output_attentions: bool | None = None,
output_hidden_states: bool | None = None,
return_dict: bool | None = None,
) -> tuple | BaseModelOutputWithPooling:
"""Flatten `pixel_values` along the batch and time dimension, pass it
through the original vision model, then unflatten it back.
:param pixel_values: a tensor of shape (batch, channel, time, height, width)
:returns:
last_hidden_state: a tensor of shape (batch, time * seq_len, hidden_size)
pooler_output: a tensor of shape (batch, time, hidden_size)
hidden_states:
a tuple of tensors of shape (batch, time * seq_len, hidden_size),
one for the output of the embeddings + one for each layer
attentions:
a tuple of tensors of shape (batch, time, num_heads, seq_len, seq_len),
one for each layer
"""
if pixel_values is None:
raise ValueError("You have to specify pixel_values")
batch, _, time, _, _ = pixel_values.size()
# flatten along the batch and time dimension to create a tensor of shape
# (batch * time, channel, height, width)
flat_pixel_values = pixel_values.permute(0, 2, 1, 3, 4).flatten(end_dim=1)
vision_outputs: BaseModelOutputWithPooling = super().forward(
pixel_values=flat_pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
)
# now restore the original dimensions
# vision_outputs.last_hidden_state is of shape
# (batch * time, seq_len, hidden_size)
seq_len = vision_outputs.last_hidden_state.size(1)
last_hidden_state = vision_outputs.last_hidden_state.view(
batch, time * seq_len, -1
)
# vision_outputs.pooler_output is of shape
# (batch * time, hidden_size)
pooler_output = vision_outputs.pooler_output.view(batch, time, -1)
# hidden_states is a tuple of tensors of shape
# (batch * time, seq_len, hidden_size)
hidden_states = (
tuple(
hidden.view(batch, time * seq_len, -1)
for hidden in vision_outputs.hidden_states
)
if vision_outputs.hidden_states is not None
else None
)
# attentions is a tuple of tensors of shape
# (batch * time, num_heads, seq_len, seq_len)
attentions = (
tuple(
hidden.view(batch, time, -1, seq_len, seq_len)
for hidden in vision_outputs.attentions
)
if vision_outputs.attentions is not None
else None
)
if return_dict:
return BaseModelOutputWithPooling(
last_hidden_state=last_hidden_state,
pooler_output=pooler_output,
hidden_states=hidden_states,
attentions=attentions,
)
return (last_hidden_state, pooler_output, hidden_states, attentions)
class VideoBlipForConditionalGeneration(Blip2ForConditionalGeneration):
def __init__(self, config: Blip2Config) -> None:
# HACK: we call the grandparent super().__init__() to bypass
# Blip2ForConditionalGeneration.__init__() so we can replace
# self.vision_model
super(Blip2ForConditionalGeneration, self).__init__(config)
self.vision_model = VideoBlipVisionModel(config.vision_config)
self.query_tokens = nn.Parameter(
torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size)
)
self.qformer = Blip2QFormerModel(config.qformer_config)
self.language_projection = nn.Linear(
config.qformer_config.hidden_size, config.text_config.hidden_size
)
if config.use_decoder_only_language_model:
language_model = AutoModelForCausalLM.from_config(config.text_config)
else:
language_model = AutoModelForSeq2SeqLM.from_config(config.text_config)
self.language_model = language_model
# Initialize weights and apply final processing
self.post_init()
前回、作成したNode classは以下のような感じです。
前回コード
class VideoBlipNode(Node):
def __init__(self):
super().__init__('video_blip_node')
# rosparam initialization
self.declare_parameter('image_topic_name', '/image_raw')
self.declare_parameter('output_text_topic', '/blip/data')
self.declare_parameter('model_name', 'kpyu/video-blip-opt-2.7b-ego4d')
self.declare_parameter('question', '')
# read params
self.image_topic = self.get_parameter('image_topic_name').get_parameter_value().string_value
self.output_topic = self.get_parameter('output_text_topic').get_parameter_value().string_value
self.model_name = self.get_parameter('model_name').get_parameter_value().string_value
self.prompt = self.get_parameter('question').get_parameter_value().string_value
# pub sub
self.image_subscription = self.create_subscription(Sensor_Image, self.image_topic, self.image_callback, qos.qos_profile_sensor_data)
self.blip_publisher = self.create_publisher(String, self.output_topic, 10)
#other params
self.runnimg = False
self.processor = None
self.blip_model = None
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.bridge = CvBridge()
def load_model(self):
'''
Load BLIP2 model
'''
self.get_logger().info('Loading model')
processor = Blip2Processor.from_pretrained(self.model_name)
model = Blip2ForConditionalGeneration.from_pretrained(
self.model_name, torch_dtype=torch.float16
)
model.to(self.device)
self.processor = processor
self.blip_model = model
self.get_logger().info('Loading end')
def sensor_msg_convert_PIL(self, input_image: Sensor_Image):
'''
convert
sensor Image -> PIL Image
'''
# self.get_logger().info('convert')
self.runnimg = True
try:
cv_image = self.bridge.imgmsg_to_cv2(input_image, "bgr8")
except CvBridgeError as e:
print(e)
pil_image = cv_image[:, :, ::-1]
return pil_image
def process_blip(self, image: PIL_Image):
'''
process blip and generate text
'''
self.get_logger().info('process')
inputs = self.processor(images=image, text=self.prompt, return_tensors="pt").to(self.device, torch.float16)
generated_ids = self.blip_model.generate(
**inputs
)
generated_text = self.processor.batch_decode(
generated_ids,
skip_special_tokens=True)[0].strip()
return generated_text
def image_callback(self, msg):
self.get_logger().info('Subscription image')
if not self.runnimg:
pil_image = self.sensor_msg_convert_PIL(msg)
get_text = self.process_blip(pil_image)
pub_msg = String()
pub_msg.data = get_text
self.blip_publisher.publish(pub_msg)
self.runnimg = False
def main():
rclpy.init()
node = VideoBlipNode()
node.load_model()
rclpy.spin(node)
node.destroy_node()
rclpy.shutdown
このクラスの上にさきほどのvideo_blip.modelから持ってきたclassをそのまま貼り付けます。
次にsensor_msgをPILに変換したこの関数を変更していきます。
def sensor_msg_convert_PIL(self, input_image: Sensor_Image):
'''
convert
sensor Image -> PIL Image
'''
# self.get_logger().info('convert')
self.runnimg = True
try:
cv_image = self.bridge.imgmsg_to_cv2(input_image, "bgr8")
except CvBridgeError as e:
print(e)
pil_image = cv_image[:, :, ::-1]
return pil_image
ここで、returnしたpil_imageのtensor sizeを確認します。
すると[H,W,C]
のような順番になっていました。ここでは、H:高さ,W:横幅,C:チャンネル数となります。この構造から[C,T,H,W]
,T:動画のフレーム分という構造に変化せていきます。
まずは画像を格納する時間配列を生成します。
input_images = torch.zeros(camera_fps, 3, camera_height, camera_width)
camera_fps
: 格納するフレーム枚数になります。30fpsで30と入力すると動画がちょうど1秒になるといった感じです。
第二引数はchannel数
camera_height
: 画像の縦size
camera_width
: 画面の横size
といった感じになっています。
torch.zeros
で作成すると真っ暗な動画像が生成されるので、そこにキューの要領で1フレームずつ画像を挿入していきます。以下がそのコードになります。
# スライス操作をして次元を移動,配列の先頭から最後の一つ前までを配列の1つ次から最後までに置き換える
input_images[0:-1] = input_images[1:].clone()
# 末にpermuteで次元を入れ替えた画像を代入する
input_images[-1] = torch.from_numpy(cv_image[:, :, ::-1].copy()).permute(2,0,1)
return input_image
実際作成したもの
これからたくさん評価していき、もう少し調整したいと考えています。
Discussion