🤹‍♂️

Multimodal RAG を実装してみる

2024/11/15に公開

昨日の記事の続き。
Multimodal RAG のアプローチのうち、マルチモーダル埋め込みを用いるもの(Multi-Vector Retriever for RAG on tables, text, and images のOption 1)の具体的な実装を考えてみる。

前提:非マルチモーダルRAG

通常のRAG でドキュメントを格納する場合、以下のようなコードを用いるのが一般的だと思う。

#ドキュメントを分割してチャンクにする
text_splitter = RecursiveCharacterTextSplitter(chunks_size=300, chunk_overlap=30)
chunks = text_splitter.split_documents(document)

#ベクトルインデックスを作成(Google Vertex AI)
vector_store = VectorSearchVectorStore.from_components(embedding=embedding_model)

#ベクトルインデックスにチャンクをベクトル化して格納
vector_store.add_documents(chunks)

VectorSearchVectorStore の実装を辿れば分かるが、add_documents() の処理では、ベクトルインデックス作成時に設定したベクトル埋め込みモデルembedding を使って入力テキストをベクトル化している。

Azure のベクトルインデックスAzureSearch 等でも同様である(ベクトル埋め込みモデルの引数名こそ異なるが)。

マルチモーダルRAG の実装

実装状況の問題

それならベクトルインデックスに画像を埋め込むのはadd_images() とか使えば良いんじゃない?と思ってしまうが、add_images()実装されていないベクトルインデックスも多い

よく使われそうなベクトルインデックスについて実装状況まとめたのが以下(2024/11 時点)。

各LangChain コミュニティの開発状況までは追えていないので、今後add_images() が実装されるのか、実装する気が無いのか(パブクラのオブジェクトストレージ接続機能を使えとか)は分からないが、とにかく実装されていない以上は利用できない。

代案:add_embeddings() を使う

add_embeddings() は、ベクトルを直接ベクトルインデックスに格納する処理である。これも未実装のベクトルインデックスも多いのだが、add_images() よりは実装されている模様。

add_embeddings() が実装されているベクトルインデックスには、以下の手順で参考情報の画像を格納できる。(埋め込みモデルが対応していれば、画像以外の形式でも同じ)

  1. ベクトルインデックスに設定した埋め込みモデルと同じモデルで、画像をベクトルに変換する。
  2. そのベクトルを、add_embeddings() でベクトルインデックスに格納する。

プログラム

今回は、add_embeddings() の実装があるFAISS でベクトルインデックスを作成してみる。
pip install が必要なPython パッケージや、Google 認証については省略。

埋め込みモデルを定義

Google のマルチモーダル埋め込みモデル を利用する。

from vertexai.vision_models import MultiModalEmbeddingModel
from langchain_google_vertexai import VertexAIEmbeddings

#参考情報の画像をベクトル化する用
model = MultiModalEmbeddingModel.from_pretrained("multimodalembedding@001")
#ベクトルインデックスのembedding_function として渡す用
embeddings = VertexAIEmbeddings(model_name="multimodalembedding@001")

ベクトルインデックスを作成

テスト目的なので、保存場所はローカルのメモリ内。multimodalembedding@001 が1408次元のベクトル空間に埋め込むことに注意。

from faiss import IndexFlatL2
from langchain_community.vectorstores import FAISS
from langchain_community.docstore.in_memory import InMemoryDocstore

vector_store = FAISS(
    embedding_function=embeddings, 
    index=IndexFlatL2(1408),
    docstore=InMemoryDocstore(),
    index_to_docstore_id={},
)

参考情報の画像をベクトル化

import glob
from vertexai.vision_models import Image

#画像を保存したパスを定義
IMAGE_FOLDER_PATH = "./docs/"

#IMAGE_FOLDER_PATH 配下のpng ファイルを取得
image_path_list = glob.glob(IMAGE_FOLDER_PATH + '*.png')
image_list = [Image.load_from_file(image_path) for image_path in image_path_list]

#画像をベクトルに変換
image_embedding_list = [model.get_embeddings(image=image).image_embedding for image in image_list]

ベクトルインデックスに格納、Retriever 作成

#画像のパスとベクトルを紐づけ
image_embedding_pairs = list(zip(image_path_list, image_embedding_list))

#ベクトルインデックスに格納
vector_store.add_embeddings(image_embedding_pairs, metadatas=[{"type": "image"} for _ in image_path_list])

#Retriever を作成
retriever = vector_store.as_retriever()

マルチモーダルLLM へ渡すプロンプト作成

LLM もGoogle のGemini 1.5 Pro を使う。いつかの仕様変更で、画像はBase64 にエンコードしてGemini に入力する必要あり。

import base64
from langchain.schema.messages import HumanMessage

#png の画像ファイルをBase64 エンコードされた文字列に変換
def image_to_base64(image_path: str) -> str:
    with open(image_path, "rb") as image_file:
        encoded_string = base64.b64encode(image_file.read()).decode("utf-8")
    return f"data:image/png;base64,${encoded_string}"
    

#プロンプト生成
def generate_prompt(data: dict) -> HumanMessage:
    #Retriever の出力は画像と文章が混ざっているので、分離する
    image_list, text_list = [], []
    for doc in data["context"]:
        if 'type' in doc.metadata and doc.metadata['type'] == 'image':
            image_list.append(doc)
        else:
            text_list.append(doc)

    prompt_template = f"""
        次のコンテキストと入力画像のみに基づいて、日本語で回答してください。

        質問:
        {data["question"]}

        context:
        {text_list}
    """
  
    return [
        HumanMessage(
            content=
                [{"type": "text", "text": prompt_template}] + 
                [{"type": "image_url", "image_url": {"url": image_to_base64(image.page_content)}} for image in image_list]
        )
    ]

chain 作成

from langchain.schema.runnable import RunnableLambda, RunnablePassthrough
from langchain.schema.output_parser import StrOutputParser
from langchain_google_vertexai import ChatVertexAI

llm = ChatVertexAI(model_name="gemini-1.5-pro")

chain = (
    {"context": retriever, "question": RunnablePassthrough()}
    | RunnableLambda(generate_prompt)
    | llm
    | StrOutputParser()
)

print(chain.invoke("質問文"))

Discussion