Trying out DreamSim on Google Colab

Published 2023/06/27

What is DreamSim?

DreamSim is a new foundation model for visual similarity: it scores how perceptually similar two images are.
https://github.com/ssundaram21/dreamsim
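The core API is a single call that returns a perceptual distance between two images (lower = more similar). A minimal sketch following the repository README, with placeholder image paths:

import torch
from dreamsim import dreamsim
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = dreamsim(pretrained=True, device=device)

# "img1.png" and "img2.png" are placeholders for any two images
img1 = preprocess(Image.open("img1.png")).to(device)
img2 = preprocess(Image.open("img2.png")).to(device)
distance = model(img1, img2)  # lower distance = more perceptually similar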

Links

Colab
GitHub

Preparation

Open Google Colab and, from the menu, select "Runtime → Change runtime type" and set the hardware accelerator to GPU.
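To confirm the GPU is actually available, a quick check (optional):

import torch
print(torch.cuda.is_available())  # True if the GPU runtime is active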

Environment setup

Install the package.

!pip install dreamsim

We also download a baseline model used later for comparison.

!mkdir models/
!wget -O models/open_clip_vitb32_pretrain.pth.tar https://github.com/ssundaram21/dreamsim/releases/download/v0.1.0/open_clip_vitb32_pretrain.pth.tar

Inference

(1) Preparing the demo data

!mkdir /content/images
!wget https://github.com/ssundaram21/dreamsim/releases/download/v0.1.0/sample_images.zip -O images/sample_images.zip
!wget https://github.com/ssundaram21/dreamsim/releases/download/v0.1.0/retrieval_images.zip -O images/retrieval_images.zip
!unzip images/sample_images.zip
!unzip images/retrieval_images.zip
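After unzipping, it is worth confirming that the sample_images and retrieval_images folders exist in the working directory (optional):

!ls sample_images | head
!ls retrieval_images | head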

(2) Loading the model

import torch
from dreamsim import dreamsim

# Load the pretrained DreamSim model and its preprocessing function
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = dreamsim(pretrained=True, device=device)
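As a quick check, preprocess takes a PIL image and returns a tensor ready to pass to the model (it should include a leading batch dimension):

from PIL import Image

sample = preprocess(Image.open("sample_images/ref_1.png"))
print(type(sample), sample.shape)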

(3) Utility function

import matplotlib.pyplot as plt

def show_imgs(ims, captions=None):
    # Display a row of images with optional captions
    fig, ax = plt.subplots(nrows=1, ncols=len(ims), figsize=(10, 5))
    for i in range(len(ims)):
        ax[i].imshow(ims[i])
        ax[i].axis('off')
        if captions is not None:
            ax[i].set_title(captions[i], fontweight="bold")

(4) Similarity Search

from PIL import Image

ref_pil = Image.open("sample_images/ref_1.png")
img_a_pil = Image.open("sample_images/img_a_1.png")
img_b_pil = Image.open("sample_images/img_b_1.png")

show_imgs(
    ims=[img_a_pil, ref_pil, img_b_pil],
    captions=["A", "Reference", "B"])
# calculate similarity score
ref = preprocess(ref_pil).to(device)
img_a = preprocess(img_a_pil).to(device)
img_b = preprocess(img_b_pil).to(device)

dist_a = model(ref, img_a)  # perceptual distance: lower = more similar to the reference
dist_b = model(ref, img_b)

show_imgs(
    ims=[img_a_pil, ref_pil, img_b_pil],
    captions=[f"A, Score: {round(float(dist_a.cpu()), 3)}",
              "Reference",
              f"B, Score: {round(float(dist_b.cpu()), 3)}"])

Advanced Applications

(1) Trying other images

ref_path = "sample_images/ref_2.png" #@param {type:"string"}
img_a_path = "sample_images/img_a_2.png" #@param {type:"string"}
#@markdown Optional:
img_b_path = "sample_images/img_b_2.png" #@param {type:"string"}

ref_pil = Image.open(ref_path)
img_a_pil = Image.open(img_a_path)
ref = preprocess(ref_pil).to(device)
img_a = preprocess(img_a_pil).to(device)
dist_a = model(ref, img_a)

if len(img_b_path) > 0:
  img_b_pil = Image.open(img_b_path)
  img_b = preprocess(img_b_pil).to(device)
  dist_b = model(ref, img_b)
  ims = [img_a_pil, ref_pil, img_b_pil]
  captions = [f"A, Score: {round(float(dist_a.cpu()), 3)}", "Reference",
              f"B, Score: {round(float(dist_b.cpu()), 3)}"]
else:
  ims = [ref_pil, img_a_pil]
  captions = ["Reference", f"Score: {round(float(dist_a.cpu()), 3)}"]

show_imgs(
    ims=ims,
    captions=captions)

(2) Image Retrieval

import os
from PIL import Image
from tqdm import tqdm
import torch.nn.functional as F
import pandas as pd
import pickle

root = "retrieval_images/"
images = []
for path in os.listdir(root):
  try:
    images.append(Image.open(root + path))
  except Exception:
    # skip files that are not readable images
    pass
query, images = images[0], images[1:]

# Compare DreamSim with two baseline models (DINO and OpenCLIP)
from dreamsim import PerceptualModel

dreamsim_model = model
dino_model = PerceptualModel(feat_type='cls', model_type='dino_vitb16', stride='16', baseline=True, device=device)
open_clip_model = PerceptualModel(feat_type='embedding', model_type='open_clip_vitb32', stride='32', baseline=True, device=device)

def get_embeddings(model, name, images):
  # Embed every image with the given model and cache the embeddings to disk
  embeddings = []
  for img in tqdm(images):
    img = preprocess(img).to(device)
    embeddings.append(model.embed(img).detach().cpu())
  with open(f"images/{name}_embeds.pkl", "wb") as f:
    pickle.dump(embeddings, f)

get_embeddings(dreamsim_model, "dreamsim", images)
get_embeddings(dino_model, "dino", images)
get_embeddings(open_clip_model, "open_clip", images)

def nearest_neighbors(embeddings, query_index):
    query_embed = embeddings[query_index]
    dists = {}

    # Compute the (cosine) distance between the query embedding
    # and each search image embedding
    for i, im in enumerate(embeddings):
      if i == query_index:
        continue
      dists[i] = (1 - F.cosine_similarity(query_embed, embeddings[i],
                                          dim=-1)).item()

    # Return results sorted by distance
    df = pd.DataFrame({"ids": list(dists.keys()), "dists": list(dists.values())})
    df = df.sort_values(by="dists")
    return df
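For reference, nearest_neighbors returns a DataFrame of image indices sorted by ascending distance. A quick usage example with the DreamSim embeddings cached above:

with open("images/dreamsim_embeds.pkl", "rb") as f:
    dreamsim_embeds = pickle.load(f)
print(nearest_neighbors(dreamsim_embeds, 0).head(3))  # the 3 images closest to image 0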
    
query_index = 9 #@param {type:"number"}
n = 3
display_width = 11
display_height = 4

## Load embeddings for each metric and compute nearest neighbors to the query_index-th image
nn_dfs = {}
for metric_name in ["dreamsim", "open_clip", "dino"]:
    with open(f"images/{metric_name}_embeds.pkl", "rb") as f:
      embeddings = pickle.load(f)
    nn_dfs[metric_name] = nearest_neighbors(embeddings, query_index)

## Plot results
f, ax = plt.subplots(4, n+2, figsize=(14,7), gridspec_kw={"height_ratios":[0.005,1,1,1]})
ax[0,0].axis('off')
for col in range(1, n+2):
    title = "Query" if col == 1 else f"n{col-1}"
    ax[0, col].set_title(title, fontweight="bold", fontsize=15)
    ax[0, col].axis('off')

for i, name in enumerate(["dreamsim", "open_clip", "dino"]):
    ax[i+1, 0].text(0.5, 0.5, name, fontsize=13)
    ax[i+1, 0].axis('off')

    ax[i+1, 1].imshow(images[query_index])
    ax[i+1, 1].axis("off")

    for j in range(n):
        im_idx = nn_dfs[name]['ids'].iloc[j]
        ax[i + 1, j + 2].imshow(images[im_idx])
        ax[i + 1, j + 2].axis('off')
plt.tight_layout()


Conclusion

In this post, I tried DreamSim, a new visual similarity model, on Google Colab. It seems to capture compositional aspects of an image particularly well: detecting individual objects is not so different from CLIP or DINO, but DreamSim appears to learn a notion of visual similarity that also accounts for the relationships between multiple objects.

I plan to keep posting hands-on articles about LLMs, diffusion models, image analysis, and 3D, so stay tuned.
