🤖

Japanese MiniGPT-4: rinna 3.6bとBLIP-2を組み合わせてマルチモーダルチャットのモデルを作る

2023/07/27に公開

はじめに

LLMの応用先の一つに,テキストに加えて画像や音声といった複数のモーダルの入出力を行うマルチモーダル情報処理があります.例えば,2023年3月に発表されたGPT-4の論文では,テキストと画像から構成されるプロンプトを入力することで,画像の内容に関して高度な対話を実現できることが報告されています.GPT-4のように,テキスト以外の情報を考慮して対話を行うタスクはマルチモーダルチャットと呼ばれています.

マルチモーダルチャットを実現する方法として,テキストの情報のみで事前学習されたLLMを改良し,マルチモーダル情報を扱えるようにする手法が多数提案されています.例えば,テキストと画像を入力可能なタスクであれば,画像データで学習された画像のエンコーダとなるモデルをLLMに接続することで,画像とテキストを同じ枠組みで処理する手法(BLIP-2MiniGPT-4)が提案されています.

今回の記事では,rinnaで行われているjapanese-gpt-neox-3.6bを用いたマルチモーダルチャットの取り組みについて紹介します.現在は上述のMiniGPT-4の手法をベースに,japanese-gpt-neox-3.6bとBLIP-2を組み合わせたJapanese MiniGPT-4を作成しています.以降,MiniGPT-4の概要,MiniGPT-4でjapanese-gpt-neox-3.6bを利用する方法,学習データの概要,学習ずみのモデルの動作例について説明します.この記事の内容を参考にMiniGPT-4とjapanese-gpt-neox-3.6bを組み合わせて任意のデータで学習していただくことで,日本語のマルチモーダルチャットのモデルを構築することが可能になります.

MiniGPT-4


MiniGPT-4のモデル構造(https://arxiv.org/abs/2304.10592

MiniGPT-4のモデル構造を上記に示します.MiniGPT-4は,パラメタを固定したLLM(Vicuna),および,パラメタを固定した画像のエンコーダ(BLIP-2ViTQ-Former)を組み合わせたモデルです.モデルの特徴は,LLMと画像のエンコーダを接続する部分に,画像の埋め込み表現を言語の埋め込み表現に変換する単一のLinear Layerをアダプタのように使用している点です.単一のLinear Layerのみを学習させることで,学習にかかるコストを大幅に抑えることが可能です.

学習は二段階で実施されており,大規模な画像キャプションのデータで一段階目の学習を行った後,高品質な小規模の画像キャプションのデータで学習を行っています.一段階目のデータとしては,Conceptual 12M(CC12M)等のデータセットが利用されており,そのサイズ(画像とキャプションのペア数)は5Mサンプルです.二段階目のデータとしては,一段階目のデータで学習されたモデルが画像に対して生成したキャプションをChatGPTで修正したものが使用されており,そのサイズは3.5Kサンプルです.

学習方法として,一段階目の学習には4枚のA100(80GB)が10時間使用されており,二段階目の学習には1枚のA100(80GB)が7分間使用されています.LLMや画像のエンコーダのパラメタを固定して学習することで,大規模なデータにも関わらず比較的短時間で学習が実施できていることがわかります.なお,二段階目の学習を実施した理由は,一段階目の学習のみでは,モデルの生成結果に繰り返しや情報の欠落等の問題があったためと報告されています.

使用されているプロンプトの例を下記に示します(改行を挿入していますが,実際は1行のテキストです).プロンプトは大きく###Human:###Assistant:で構成されています.画像は<Img></Img>のタグで囲われており,<ImageFeature>はLinear Layerから出力された画像の埋め込み表現(一つの画像につき,32個のベクトルで表現)に置き換えられます.

###Human: <Img><ImageFeature></Img> Describe this image in detail.
Give as many details as possible. Say everything you see.
###Assistant:

発現したとされている能力として下記が報告されており,Linear Layerのパラメタを学習するのみでも,マルチモーダルチャットが実現可能であることが確認できます.

  • 画像の概要を生成
  • 手書き画像からウェブサイトを作成
  • 画像からの物語やポエムを生成
  • 問題を示す画像(例:洗濯機から泡が溢れてしまっている画像)に対して解決策を生成
  • 料理の画像から手順を生成

LLMの差し替え

MiniGPT-4はソースコードが公開されていますので,オリジナルのモデルで利用されているVicunaを異なるLLMに差し替えて学習させることができます.しかし問題として,MiniGPT-4で用いられているBLIP-2のQ-Former(画像のエンコーダ)は英語のキャプションを生成するように学習されており,英語以外のLLMに差し替えたときにうまく動くかはわかりませんでした.今回,日本語のLLMとしてjapanese-gpt-neox-3.6bを利用し,MiniGPT-4と同等の学習を行ったところ,日本語でのキャプション生成や質問応答が可能なことが確認できました.ここではその実装方法について簡単に説明します.

MiniGPT-4のLLMを差し替えるためには,モデルの初期化,および,モデルのforwardについてコードの変更が必要です.まず,モデルの初期化については,元のモデルでVicunaのトークナイザとモデル(LLaMA)が読み込まれていますので,MiniGPT-4/minigpt4/models/mini_gpt4.py#L85-L104のコードを,GPT-NeoXのものに差し替えます.

from transformers import AutoTokenizer


llm_model = "rinna/japanese-gpt-neox-3.6b"

# 中略

print('Loading LLM', flush=True)
self.llm_tokenizer = AutoTokenizer.from_pretrained(llm_model, use_fast=False)

if self.low_resource:
    self.llm_model = CustomizedGPTNeoXForCausalLM.from_pretrained(
        llm_model,
        torch_dtype=torch.float16,
        load_in_8bit=True,
        device_map={'': device_8bit}
    )
else:
    self.llm_model = CustomizedGPTNeoXForCausalLM.from_pretrained(
        llm_model,
        torch_dtype=torch.float16,
    )

for name, param in self.llm_model.named_parameters():
    param.requires_grad = False
print('Loading LLM Done')

ここでは,GPTNeoXForCausalLMではなく,少し修正を加えた後述のCustomizedGPTNeoXForCausalLMを利用しています.decoder-onlyのモデルであるGPT-NeoXの現在の実装では,forward関数の引数に埋め込み表現を直接指定するinput_embedsがサポートされていないため,このプルリクエストを参考に,下記のようにinput_embedsを指定できるように変更しています.

from transformers.models.gpt_neox import GPTNeoXForCausalLM


class CustomizedGPTNeoXForCausalLM(GPTNeoXForCausalLM):
    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs):
        input_shape = input_ids.shape

        # cut decoder_input_ids if past is used
        if past_key_values and past_key_values[0] is not None:
            input_ids = input_ids[:, -1:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1 
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -1].unsqueeze(-1)

        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
        if attention_mask is None:
            attention_mask = input_ids.new_ones(input_shape)

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "attention_mask": attention_mask,
                "position_ids": position_ids,
                "past_key_values": past_key_values,
            }
        )
        return model_inputs

forwardについても,同様にLLMをjapanese-gpt-neox-3.6bに差し替えるための変更を行います.下記はその実装例として,MiniGPT-4/minigpt4/models/mini_gpt4.py#L169-L209に相当する部分です.画像のみを入力し,キャプションを生成するためのinput_embedsの作成が実装されています.具体的には,プロンプトの内容,および,トークナイザのパディングや終端記号をjapanese-gpt-neox-3.6bに対応するものに変更ししています.

# template
PROMPT_START = "ユーザー: <IMG>"
PROMPT_AFTER = "</IMG><NL>システム: "

# create templates
p_start_encodings = self.llm_tokenizer(PROMPT_START, return_tensors="pt", add_special_tokens=False).to(image.device) # 1 x 6
p_after_encodings = self.llm_tokenizer(PROMPT_AFTER, return_tensors="pt", add_special_tokens=False).to(image.device) # 1 x 13

# create inputs
input_texts = [t + self.llm_tokenizer.eos_token for t in samples["target"]]
encodings = self.llm_tokenizer(input_texts, return_tensors="pt",
    padding="longest", truncation=True, max_length=self.max_txt_len, add_special_tokens=False
).to(image.device)

# create targets
targets = encodings.input_ids.masked_fill(encodings.input_ids == self.llm_tokenizer.pad_token_id, -100)
empty_targets = torch.ones(
    [batch_size, 1 + p_start_encodings.input_ids.size(1) + image_feature_len + p_after_encodings.input_ids.size(1)],
    dtype=torch.long
).to(image.device).fill_(-100)
targets = torch.cat([empty_targets, targets], dim=1)

# create bos embeddings
bos = torch.ones([batch_size, 1], dtype=p_start_encodings.input_ids.dtype,
    device=p_start_encodings.input_ids.device
) * self.llm_tokenizer.bos_token_id
bos_embeds = self.llm_model.gpt_neox.embed_in(bos)
atts_bos = torch.ones(bos.shape, device=image.device)

# create prompt_before image embeddings
p_start = p_start_encodings.input_ids.expand(batch_size, -1)
p_start_embeds = self.llm_model.gpt_neox.embed_in(p_start)
atts_p_start = torch.ones(p_start.shape, device=image.device)

# create prompt_after image embeddings
p_after = p_after_encodings.input_ids.expand(batch_size, -1)
p_after_embeds = self.llm_model.gpt_neox.embed_in(p_after)
atts_p_after = torch.ones(p_after.shape, device=image.device)

# create inputs embeddings
inputs_embeds = self.llm_model.gpt_neox.embed_in(encodings.input_ids)

# merge embeddings
inputs_embeds = torch.cat([bos_embeds, p_start_embeds, img_embeds, p_after_embeds, inputs_embeds], dim=1)
attention_mask = torch.cat([atts_bos, atts_p_start, atts_img, atts_p_after, encodings.attention_mask], dim=1)

学習データ

MiniGPT-4の論文では,学習データとして5Mペアの画像と英語のキャプションが利用されており,高品質なモデルを構築するためには,オリジナルのMiniGPT-4と同様に,画像と日本語のキャプションがペアになった大規模なデータセットを学習に用いることが理想的です.現在rinnaでは,独自に構築した大規模データセット,および,オープンソースのデータセットを組み合わせることで,マルチモーダルチャットのモデルを学習しています.

画像に日本語のテキストが付与されたオープンソースのデータセットには代表的なものとして,MS COCO 2014に人手でキャプションを付与したSTAIR Captions(画像数:118K,キャプション数:590K)や,Visual Genome Datasetに日本語の質問応答を付与したJapanese Visual Genome VQA dataset(画像数:99K,QA数:793K)などがあります.

MiniGPT-4の論文では,一段階目の学習のみを経たモデルには,モデルの生成結果に繰り返しや情報の欠落等の問題があったと報告されています.しかしながら,japanese-gpt-neox-3.6bと上記のような日本語データセットを用いて実験している範囲では,論文で報告されているような問題は見られていません.そのため,この問題は利用しているLLMに依存して発生しているか,もしくは,元のMiniGPT-4の実装に起因して発生している可能性があると考えられます.

動作例

ここでは,現在開発中のマルチモーダルチャットの動作例を紹介します.入力画像として,BLIPのデモで利用されている画像(入力画像1)と,いらすとやの画像(入力画像2)の2枚を使用してみます.

まず,下記に入力画像1と共にいくつかプロンプトを入力したときのモデルの入出力を示します.


サンプルとして使用する入力画像1(https://github.com/salesforce/BLIP/blob/main/demo.ipynb

prompt1: ユーザー: <IMG><ImageHere></IMG><NL>システム: 
output1: 砂浜に座って犬と女性が遊んでいる

prompt2: ユーザー: <IMG><ImageHere></IMG> 何が写ってる?<NL>システム: 
output2: 犬が写っています

prompt3: ユーザー: <IMG><ImageHere></IMG> 時間帯は?<NL>システム: 
output3: 夕方

prompt4: ユーザー: <IMG><ImageHere></IMG> 犬の犬種は?<NL>システム: 
output4: ゴールデンレトリバー

prompt1では,プロンプトとして画像のみを<ImageHere>を置き換える形で入力しており,画像のキャプションとして,正しく「砂浜に座って犬と女性が遊んでいる」が出力されていることがわかります.prompt2では,<ImageHere>の画像に加えて「何が写ってる?」という質問を入力しており,応答として,「犬が写っています」と出力されています.また,prompt3では「時間帯は?」という質問に対して「夕方」と応答しており,prompt4では,「犬の犬種は?」という質問に対して,「ゴールデンレトリバー」と応答しています.

続いて,下記に入力画像2と共にいくつかプロンプトを入力したときのモデルの入出力を示します.


サンプルとして使用する入力画像2(https://www.irasutoya.com/2023/04/blog-post_527.html

prompt1: ユーザー: <IMG><ImageHere></IMG><NL>システム: 
output1: カエルの被り物

prompt2: ユーザー: <IMG><ImageHere></IMG> どんな表情をしてる?<NL>システム: 
output2: 困った表情

prompt3: ユーザー: <IMG><ImageHere></IMG> どんなテイストのイラストですか?<NL>システム: 
output3: かわいい系

prompt1では,画像のキャプションとして正しく「カエルの被り物」が出力されていることがわかります.prompt2では,「どんな表情をしてる?」という質問に対して「困った表情」と応答しており,prompt3では,「どんなテイストのイラストですか?」という質問に対して「かわいい系」と応答しており,どちらも悪くない応答です.この結果から,今回学習したモデルは,入力画像1のような写真だけでなく,入力画像2のようなイラストについても対応できていることが確認できました.

上記の二つの例からわかる通り,今回の記事で作成しているモデルは,英語のキャプションのデータで学習されたBLIP-2と日本語のLLMであるjapanese-gpt-neox-3.6bを組み合わせたものであるものの,入力された画像に対して,日本語のキャプションを生成したり,画像と共に入力された日本語の質問に対してある程度妥当な応答が出力できるということが明らかになりました.

まだ完全な応答とは言えないものもあり,例えば入力画像1において,prompt2では,画像の一部(犬)のみが言及されていたり,prompt4では,犬の犬種は正確にはラブラドール・レトリバーであり内容が間違っていたりということがあります.これらの課題への対応については,例えばプロンプトを工夫したり,学習データを増やしたりすることで改善していきたいと考えています.

まとめ

今回の記事では,rinnaで取り組まれているjapanese-gpt-neox-3.6bを用いたマルチモーダルチャットのモデル構築(Japanese MiniGPT-4)について紹介しました.MiniGPT-4のLLMをjapanese-gpt-neox-3.6bに差し替えて学習することで,日本語でもマルチモーダルチャットが実現できることが確認できました.今後は,現在構築中のモデルをChararuに導入し,テキストだけでなく画像も交えたチャットの機能をリリースする予定です.また,画像のようなマルチモーダルの情報を扱うことで,テキストの情報に留まらない多様な情報を考慮した対話が実現できるため,引き続き画像を含むさまざまなマルチモーダル情報を考慮した対話技術の研究開発を行っていきます.

Discussion