Zenn
Closed5

ColPali: PaliGemma-3BとColBERTストラテジーに基づくビジュアルレトリバー

kun432kun432

これで知った
https://twitter.com/jerryjliu0/status/1815904500491972663

HuggingFaceで公開されている
https://huggingface.co/vidore/colpali

ColPaliは、視覚言語モデル(VLM)に基づく新しいモデルアーキテクチャと学習戦略に基づいて、視覚的特徴から効率的に文書をインデックス化するモデルである。PaliGemma-3Bの拡張で、テキストと画像のColBERTスタイルのマルチベクトル表現を生成する。論文 ColPali で紹介された: Efficient Document Retrieval with Vision Language Models で紹介され、このリポジトリで初めて公開された。

論文
https://arxiv.org/abs/2407.01449

GitHubレポジトリ
https://github.com/illuin-tech/colpali

マルチモーダルモデルを使ったRAGはLlamaIndexのドキュメントにも色々あるけど、retrievalでColBERTを使うってのがポイントなのではなかろうか。

ちょっとRagatouille触ったところでColBERTに興味が出てきたのと、あとPaliGemma全く触ってないのもあって、少し触ってみたいところ。

kun432kun432

どうやら近々LlamaIndexのWebinarがある様子。LlamaIndex側にはドキュメントもnotebookもまだ存在しないみたい。

kun432kun432

少し前にByaldiというラッパーを使って試していたのだけど、

https://zenn.dev/kun432/scraps/5878459758bfdd

以下を試している際にColPaliのCookbookがあることを知った。

https://zenn.dev/kun432/scraps/2145a851102507

上記を試した後に、やはりネイティブなColPaliライブラリにも触れておきたいと思ったので、改めて基本的なところから試すことにする。

以下も参考にさせていただく。

https://zenn.dev/yumefuku/articles/pdf-search-colqwen2

kun432kun432

Quick Start

Colaboratoryで。ランタイムはT4。

colpali-engineをインストール。インストール後に多分ランタイム再起動が求められるので再起動。

!pip install colpali-engine
!pip freeze | grep -i colpali
出力
colpali_engine==0.3.4

検索に使用する画像はColPaliのCookbookにある画像を日本語に修正した物を使用する。

画像1

画像2

画像を処理するためのユーティリティを用意

from io import BytesIO
from PIL import Image
import requests
from typing import List


def load_image_from_url(url: str) -> Image.Image:
    """
    有効なURLからPILイメージをロードする
    """
    response = requests.get(url)
    return Image.open(BytesIO(response.content))


def scale_image(image: Image.Image, new_height: int = 1024) -> Image.Image:
    """
    アスペクト比を維持しながら画像を新しい高さにスケーリングする
    """
    width, height = image.size
    aspect_ratio = width / height
    new_width = int(new_height * aspect_ratio)

    scaled_image = image.resize((new_width, new_height))

    return scaled_image

ColPaliモデルをロードして、画像・テキストのEmbeddingを生成、類似性を計算する。なお、ColPaliモデルはPaliGemmaがベースになっているが、このモデルは日本語での学習が行われていないため、Qwen2-VLをベースにしたColQwenを使用した。

import torch

from colpali_engine.models import ColQwen2, ColQwen2Processor

model_name = "vidore/colqwen2-v1.0"

model = ColQwen2.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="cuda:0",  # or "mps" if on Apple Silicon
).eval()

processor = ColQwen2Processor.from_pretrained(model_name)

# 入力
images: List[Image.Image] = [
    load_image_from_url(
        # 2019年の燃料別平均発電量の時間帯推移に関する画像
        "https://raw.githubusercontent.com/kun432/colpali-cookbooks-jp-files/refs/heads/main/data/energy_electricity_generation_ja.jpg"
    ),
    load_image_from_url(
        # カザフスタンの油田の生産の歴史の画像
        "https://raw.githubusercontent.com/kun432/colpali-cookbooks-jp-files/refs/heads/main/data/shift_kazakhstan_ja.jpg"
    ),
]
images = [scale_image(image, new_height=512) for image in images]

queries = [
    "カザフスタンの石油生産量のうち、海底油田の占める割合は?",
    "2019年に最も多くの電力が生成された時間帯はどの時間ですか?"
]

# 入力を処理
batch_images = processor.process_images(images).to(model.device)
batch_queries = processor.process_queries(queries).to(model.device)

# フォワードパス
with torch.no_grad():
    image_embeddings = model(**batch_images)
    query_embeddings = model(**batch_queries)

scores = processor.score_multi_vector(query_embeddings, image_embeddings)
scores

結果

出力
tensor([[12.1875, 21.3750],
        [22.8750, 13.1875]])

クエリ順にそれぞれの画像との類似度スコアが返される。1つ目の「カザフスタンの石油生産量のうち、海底油田の占める割合は?」に対しては「カザフスタンの油田の生産の歴史の画像」、2つ目の「2019年に最も多くの電力が生成された時間帯はどの時間ですか?」に対しては「2019年の燃料別平均発電量の時間帯推移に関する画像」のほうがそれぞれスコアが高いことがわかる。

わかりやすいようにtop-1を出力してみる。

for idx, query in enumerate(queries):
    retrieved_image_index = scores[idx].argmax().item()
    retrieved_image = images[retrieved_image_index]
    print(f"クエリ: {query}")
    print("検索結果:")
    display(scale_image(retrieved_image, new_height=256))
    print("----")

kun432kun432

以下を参考にさせていただいてもう少し実用的な例を試す。同じくColaboratory T4で。

https://github.com/illuin-tech/colpali/blob/tree/scripts/infer/run_inference_with_python.py

https://zenn.dev/yumefuku/articles/pdf-search-colqwen2

神戸市が公開している観光に関する統計・調査資料のうち、「令和5年度 神戸市観光動向調査結果について」のPDFを使用する。

パッケージ諸々インストール

!apt update && apt install -y poppler-utils
!pip install colpali-engine pdf2image
!pip freeze | grep -i colpali
!pip freeze | grep -i pdf2image

PDFをPILのイメージとして読み込み。

!wget https://www.city.kobe.lg.jp/documents/15123/r5_doukou.pdf
from pdf2image import convert_from_path
from PIL import Image

def scale_image(image: Image.Image, new_height: int = 1024) -> Image.Image:
    """
    アスペクト比を維持しながら画像を新しい高さにスケーリングする(確認用)
    """
    width, height = image.size
    aspect_ratio = width / height
    new_width = int(new_height * aspect_ratio)
    scaled_image = image.resize((new_width, new_height))
    return scaled_image


images = convert_from_path("r5_doukou.pdf")

# 確認
for idx, i in enumerate(images, start=1):
    print(f"Page: {idx}")
    display(scale_image(i, new_height=256))

これをEmbeddingに変換する。

from typing import List, cast

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from colpali_engine.models import ColQwen2, ColQwen2Processor
from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor
from colpali_engine.utils.torch_utils import get_torch_device


device = get_torch_device("auto")
print(f"Device used: {device}")

model_name = "vidore/colqwen2-v0.1"

# モデルをロード
model = ColQwen2.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map=device,
).eval()

# プロセッサをロード
processor = cast(ColQwen2Processor, ColQwen2Processor.from_pretrained(model_name))
if not isinstance(processor, BaseVisualRetrieverProcessor):
    raise ValueError("Processor should be a BaseVisualRetrieverProcessor")

# ドキュメントをEmbeddingsに変換
dataloader = DataLoader(
    dataset=images,
    batch_size=2,
    shuffle=False,
    collate_fn=lambda x: processor.process_images(x),
)
embeddings_docs: List[torch.Tensor] = []
for batch_doc in tqdm(dataloader):
    with torch.no_grad():
        batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
        embeddings_doc = model(**batch_doc)
    embeddings_docs.extend(list(torch.unbind(embeddings_doc)))

# 作成したEmbeddingsを保存
torch.save(embeddings_docs, "./embedding_docs.pt")

ではクエリ。

query = "観光客の年齢構成を教えて"

# クエリをEmbeddingに変換
processed_query = processor.process_queries([query]).to(model.device)
with torch.no_grad():
    query_embedding = model(**processed_query)

scores = processor.score_multi_vector(query_embedding, embeddings_docs)[0]
scores

各ページ(画像)ごとにスコアが計算される。

出力
tensor([13.5000, 12.5000, 12.3125, 14.0000, 11.0000, 11.1875, 11.4375, 13.8750,
        12.3125, 13.3125, 12.8125, 12.5625, 11.5000, 12.8125, 13.1875, 10.7500,
        11.6875, 10.4375, 13.0625, 12.1250, 12.7500])

スコアの上位5件を取得し、それぞれの画像を表示する。

k=5

scores_indices = scores.argsort().tolist()[-k:][::-1]
scores_indices
出力
[3, 7, 0, 9, 14]
for index in scores_indices:
    print(f"Page: {index+1} Score: {scores[index]}")
    display(scale_image(images[index], new_height=800))

年齢構成について書かれたページが1位で取得できた。

保存したEmbeddingsを読み込んで検索する場合はこんな感じで。

from typing import List, cast
import torch
from colpali_engine.models import ColQwen2, ColQwen2Processor
from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor
from colpali_engine.utils.torch_utils import get_torch_device
from pdf2image import convert_from_path
from PIL import Image


def scale_image(image: Image.Image, new_height: int = 1024) -> Image.Image:
    """
    アスペクト比を維持しながら画像を新しい高さにスケーリングする
    """
    width, height = image.size
    aspect_ratio = width / height
    new_width = int(new_height * aspect_ratio)
    scaled_image = image.resize((new_width, new_height))
    return scaled_image


images = convert_from_path("r5_doukou.pdf")

device = get_torch_device("auto")
print(f"Device used: {device}")

model_name = "vidore/colqwen2-v0.1"

# モデルをロード
model = ColQwen2.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map=device,
).eval()

# プロセッサをロード
processor = cast(ColQwen2Processor, ColQwen2Processor.from_pretrained(model_name))
if not isinstance(processor, BaseVisualRetrieverProcessor):
    raise ValueError("Processor should be a BaseVisualRetrieverProcessor")

# 保存しておいたEmbeddingsをロード
embeddings_docs = torch.load("./index.pt", weights_only=True)

# クエリで検索
query = "観光客の年齢構成を教えて"
k=5

processed_query = processor.process_queries([query]).to(model.device)
with torch.no_grad():
    query_embedding = model(**processed_query)

scores = processor.score_multi_vector(query_embedding, embeddings_docs)[0]
scores_indices = scores.argsort().tolist()[-k:][::-1]

# 検索結果を表示
for index in scores_indices:
    print(f"Page: {index+1} Score: {scores[index]}")
    display(scale_image(images[index], new_height=800))
このスクラップは22日前にクローズされました
ログインするとコメントできます