🐶

シンプルな画像検索システムを作る

2022/01/20に公開

はじめに

最近画像検索の手法の論文を読んだので、すごく簡易なバージョンを作ってみたいなーと思った。

特に具体的な応用を考えているわけでは無いが、とにかく手法は試してみたい。そういうわけで、手軽に手に入る自炊済み書籍のPDFの特定のページを検索できるようなモデルを作ってみたいと思った。


手軽に手に入る自炊済み書籍

いちおう、最終的にはiPhoneで撮影した物理書籍の特定のページの画像を元に、自炊済みの特定のページを検索できることが確認できた[1]。iPhoneアプリやAPIを作ったわけではなく、画像検索のコアの部分のみしか作っていないという点はご了承いただきたい。

データの準備

PDFの画像化

というわけで、早速PDFのページを分割し、画像化する。pdf2imageというパッケージが簡単に使えそうだったので、使用してみた。

120個のPDFファイルを対象とし、45,622枚のJPEG画像を得ることができた。

PDFページの画像化
import glob
import hashlib
import os

import tqdm
from pdf2image import convert_from_path

pdfs = glob.glob("data/pdfs/*.pdf")
os.makedirs("data/images", exist_ok=True)
for i, fpath in tqdm.tqdm(enumerate(pdfs)):
    fname = os.path.basename(fpath)
    hash = hashlib.sha256(fname.encode())
    hash = hash.hexdigest()[:5]  # PDFを見分けるためのハッシュ値

    # すでにPDFが処理済みならスキップ
    if len(list(glob.glob(f"data/images/{hash}_*.jpg"))) > 0:
        continue
    pages = convert_from_path(fpath, thread_count=8)
    for j, page in enumerate(pages):
        output_path = f"data/images/{hash}_{str(j).zfill(4)}.jpg"
        # すでにJPEGが吐かれているならスキップ
        if os.path.exists(output_path):
            continue
        else:
            page.save(output_path, "JPEG")

Datasetクラスの作成

PDFから抽出したJPEG画像を検索対象としたいが、一部は後述する特徴抽出器の学習時に使用するTrain/Valセットとしても使用する。
検索対象は120冊の書籍の全画像45,622枚、Trainは15冊の書籍から得られる5,655枚、Valは5冊の書籍から得られる2,309枚を対象とした。

データ拡張

今回の実験の中で一番の反省点はデータ拡張である。最終的な具体的な利用シーンをあまり考えていなかったということもあり、少々雑に決めてしまった。
なんとなく、射影変換、画像の明暗の変換、ぼかし、標準的なリサイズ・ランダムクロップなどは入れてみた。今回対象とする画像は文字が記載されている本のページなので、左右反転などは含めていない。

データ拡張を適用すると、同一の画像から以下のようなバリエーションの画像が得られる。

データセットとデータ拡張
class PdfPageDataset(Dataset):
    def __init__(self, root_dir: str, split: str, augmentation: bool = True) -> None:
        re_fname = re.compile(r"(\S{5})_(\d{4})\.jpg$")

        pages_by_book = defaultdict(list)
        for fpath in glob.glob(os.path.join(root_dir, "*.jpg")):
            book_id, _ = re_fname.search(fpath).groups()
            pages_by_book[book_id].append(fpath)

        books = sorted(pages_by_book.keys())
        if split == "train":
            books = books[:15]
        elif split == "val":
            books = books[15:20]
        elif split == "predict":
            books = books
        else:
            raise Exception(f"Split {split} is not defined.")

        if split == "train" or augmentation:
            self.transform = Compose(
                [
                    RandomPerspective(),
                    ColorJitter(brightness=.5, hue=.3),
                    RandomResizedCrop(320, scale=(0.2, 1)),
                    RandomApply([GaussianBlur(5), GaussianBlur(11)], p=0.5),
                    ToTensor(),
                    Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
                ]
            )
        else:
            self.transform = Compose(
                [Resize(480), ToTensor(), Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),]
            )

        self.data = sorted(
            itertools.chain.from_iterable([pages_by_book[book_id] for book_id in books])
        )

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.LongTensor]:
        fpath = self.data[index]
        img = Image.open(fpath).convert("RGB")
        img = self.transform(img)
        return img, index

画像検索の仕組み

続いて、画像検索のメイン部分を構築していく。画像検索の仕組みは単純で、画像から特徴量を抽出する特徴抽出器と、画像集合の中から近傍の特徴量を検索する検索器とからなる。
特徴抽出は安心と信頼のResNetを使い、検索はfaissのアルゴリズムの中から最も適当そうなものを選ぶことにした。

特徴抽出器

特徴抽出器は、ImageNetで学習済みのResNet18をベースとし、損失関数としてArcCosを採用して訓練してみた。これにより、あるページと他のページを見分けることができるような画像特徴量を抽出できるモデルが訓練できると期待した。

ArcCos損失
class ArcMargineLoss(nn.Module):
    def __init__(self, num_classes: int, dim: int, margin: float, scale: float) -> None:
        super().__init__()

        self.scale = scale
        self.margin = margin
        self.weight = nn.Parameter(torch.empty(num_classes, dim))  # [n, d]
        self.num_classes = num_classes
        nn.init.xavier_uniform_(self.weight)

    def forward(self, feature: torch.Tensor, label: torch.LongTensor) -> torch.Tensor:
        norm_weight = F.normalize(self.weight)  # [n, d]
        norm_feature = F.normalize(feature)  # [N, d]
        cos = torch.einsum("Nd,nd->Nn", norm_feature, norm_weight)  # [N, n]
        theta = torch.acos(cos)

        m_exp = torch.exp(self.scale * torch.cos(theta + self.margin))  # [N, n]
        exp = torch.exp(self.scale * cos)  # [N, n]
        one_hot = F.one_hot(label, self.num_classes)  # [N, n]

        loss = -torch.log(
            (one_hot * m_exp).sum(dim=1)
            / (one_hot * m_exp + (1 - one_hot) * exp).sum(dim=1)
        ).mean()

        return loss

DeepHash

画像特徴量は、高速に検索するためにバイナリ化したい。今回はDSHSDを適用する。仰々しい名前がついているが、特徴抽出器の全結合層の最後の活性化関数をTanhにして普通に訓練するだけで良い。訓練中は、出力の各要素は-1から1の間の連続値となるが、実際に検索に使用する際は、出力の各要素の値の正負によって二値化する。

特徴抽出器
class FeatureExtractor(nn.Sequential):
    def __init__(self, dim: int, binarize: bool = True) -> None:
        base_model = torchvision.models.resnet18(pretrained=True)
        feature = nn.Sequential(*list(base_model.children())[:-1])
        flatten = nn.Flatten()
        fc = nn.Linear(512, dim)
        if binarize:
            act = nn.Tanh()
        else:
            act = nn.Identity()

        super().__init__(feature, flatten, fc, act)

特徴抽出器の訓練

特徴抽出器の訓練は、結構大雑把にやっており、特にハイパラチューニングなどもしていない。

validation精度の測定のためには、その時点でのモデルを用いてvalidationデータセットの全ての画像(変形なし)から特徴抽出をあらかじめ得ておく必要がある。少々時間を食う処理になっているので、validationは5エポック毎に行うことにした。

全体的に、以下のようにPyTorch Lightningで普通に書いている。

Lightning Module
class LitImageSearchModel(pl.LightningModule):
    def __init__(
        self, num_classes: int, dim: int, binarize: bool, margin: float, scale: float
    ):
        super().__init__()
        self.save_hyperparameters()
        self.model = FeatureExtractor(dim, binarize)
        self.loss_func = ArcMargineLoss(num_classes, dim, margin, scale)

    def training_step(self, batch, batch_idx):
        x, y = batch
        feature = self.model(x)
        loss = self.loss_func(feature, y)
        self.log("train_loss", loss)
        return loss

    def on_validation_epoch_start(self) -> None:
        """Garellyを登録する
        """
        self.model.eval()
        val_garelly = []
        with torch.no_grad():
            for x, _ in self.trainer.datamodule.test_dataloader():
                x = x.to(self.device)
                val_garelly.append(self.model(x))
        val_garelly = torch.cat(val_garelly, dim=0)  # [n, d]
        self.val_garelly = F.normalize(val_garelly).detach()

    def validation_step(self, batch, batch_idx):
        x, y = batch
        out = self.model(x)  # [N, d]
        out = F.normalize(out)

        cos = torch.einsum("Nd,nd->Nn", out, self.val_garelly)  # [N, n]
        pred = torch.argmax(cos, dim=1)  # [N,]
        acc = (pred == y).sum().item() / y.shape[0]

        self.log("val_acc", acc)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(
            [
                {"params": self.model.parameters(), "lr": 1e-5},
                {"params": self.loss_func.parameters()},
            ],
            lr=1e-3,
            weight_decay=1e-4,
        )
        return optimizer

GTX 1080Tiで24時間程度訓練させ、以下のような感じで損失やvalidation精度は推移した。validation精度はもう少し伸びそうだったが、適当に切り上げた。

最近傍探索

特徴抽出器が訓練できたので、検索の方に話を移す。
クエリ画像の特徴量は訓練済みの特徴抽出器の出力を用いて、その正負をもとに得られる。faissのバイナリ検索を使用するためには、特徴量はbitではなくbyteにしてあげる必要がある。
np.packbitsを用いると、256bitのbit特徴量は32byteのbyte特徴量に変換できる。

特徴量の抽出とバイナリ化
def extract_feature(image: Image, transform: Compose) -> np.ndarray:
    x = transform(image)  # 単にpytorchのTensorに変換しているだけ
    with torch.no_grad():
        feat = model(x.unsqueeze(0).to(device)).cpu() > 0  # 特徴抽出と正負による二値化
    return np.packbits(feat.numpy(), axis=1)  # byte特徴量へ変換

事前に、検索対象の45,622枚の画像に対しても同様に32byteの特徴量を抽出しておく。全体としてはたかだか1.4MBというわずかなメモリに収まり、特に凝った検索アルゴリズムを使用する必要はないので、全件検索を行うfaiss.IndexBinaryFlatを使用した。

検索対象の画像からIndexを構築
index = faiss.IndexBinaryFlat(256)  # indexの初期化
db = np.packbits(gallery.numpy(), axis=1)  # bitデータをbyte単位にまとめる
index.add(db)  # indexに格納

以上を準備しておくと、検索は以下のように書ける。

feat = extract_feature(query_path, dataset.transform)
dist, result = index.search(feat, k=5)

実行速度的には、特徴抽出は0.3から0.6秒程度で終わる。検索は0.0003秒くらいで完了し、ちょう速い。

実際の画像で確認

validation精度(2,309枚に対するTop1精度)は85%程度とまあまあ良ささそうに見える。しかし、この精度は自炊済みの比較的きれいな画像をクエリとした時の精度である。そのため、例えばスマホなどで撮影した写真をクエリとし、その画像がどの本の何ページなのかを特定するといった用途に使えるかは疑問が残る。

そこで、自炊済みでありかつ物理的な書籍としても手元にあるという本の適当なページをスマホで撮影し、その画像をクエリとしてみたときにページを特定できるか、ということをやってみた。
以下がクエリ画像の例である(念の為に紙面が読み取れないよう加工を施しています)。

物理的な書籍を撮影する場合、紙面上の陰影や湾曲、手や背景などの無関係な領域の写り込み、場合によっては書き込みなどがあり、案の定この画像から直接得られる特徴量を用いて検索を行っても、うまくいかない。本来であれば、もっと事前にユースケースを具体的に想像し、現実的なデータ拡張を検討することで対応できるようにした方が良いのだろうが、今回はクエリ画像に対して簡単な前処理を施すことで対応してみた。

今回の自炊済み書籍の画像は、表紙や挿絵を除き、そもそも二値またはグレイスケールの画像がほとんどである。そのため、クエリ画像はグレイスケールに変換し、opencvのfastNlMeansDenoisingを適用してノイズ除去することにした。

クエリ画像の前処理
image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
if image is None:
    raise FileNotFoundError(path)
image = cv2.resize(image, (480, int(480 * image.shape[0] / image.shape[1])))
image = cv2.fastNlMeansDenoising(image, h=10)
return Image.fromarray(image).convert("RGB")

こうすることで、上に挙げた3枚のクエリ画像については、まあまあ満足いく結果が得られた。
以下が検索を行った時のTop5で、数字は全検索対象の画像のidを表しており、太字になっているのが正解のidである。
2つ目(上図の真ん中)は背景領域が多いことが悪影響を及ぼしている可能性があるので、main body detectionのような前処理も適用してあげればTop1にできるかもしれない。

  • query1: [18699 15075 28352 34068 34036]
  • query2: [28309 34036 28391 18767 28416]
  • query3: [18835 18735 18742 18767 12469]

おわりに

シンプルな画像検索を作ってみた。想像していたよりもちゃんと動くものができた。一応、githubのリポジトリも置いときますね。

データ拡張やクエリ画像に対する前処理が課題として残っているので、仕事なんかで使う場合はこの辺りを詰める必要があるんだろうなと思った。また、faissまわりのチューニングなどは、実用上はもっとちゃんと調べてやる必要がありそうだ。

脚注
  1. 同じ本を2冊購入してしまったときに、片方を自炊し、もう片方を手元に残していた。 ↩︎

Discussion