Trying out DreamSim on Google Colab
What is DreamSim?
DreamSim is a new foundation model for visual similarity: given two images, it returns a single perceptual distance that reflects how similar humans judge them to be.
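At its core the API is simple: preprocess two images and call the model to get a distance (lower means more similar). A minimal sketch of that usage, with placeholder image paths, previewing the steps we run below:
import torch
from PIL import Image
from dreamsim import dreamsim

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = dreamsim(pretrained=True, device=device)

# "img1.png" / "img2.png" are placeholder paths used only for illustration
img1 = preprocess(Image.open("img1.png")).to(device)
img2 = preprocess(Image.open("img2.png")).to(device)
distance = model(img1, img2)  # lower = more perceptually similar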
Links
https://github.com/ssundaram21/dreamsim
Preparation
Open Google Colab and switch the runtime to GPU via "Runtime → Change runtime type" in the menu.
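If you want to confirm that a GPU is actually attached before installing anything, a quick optional check in the first cell is:
# Should list the attached GPU; fails if the runtime is CPU-only
!nvidia-smi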
Environment setup
Install DreamSim with pip:
!pip install dreamsim
We also download the baseline OpenCLIP checkpoint used for comparison later:
!mkdir models/
!wget -O models/open_clip_vitb32_pretrain.pth.tar https://github.com/ssundaram21/dreamsim/releases/download/v0.1.0/open_clip_vitb32_pretrain.pth.tar
Inference
(1) Preparing the demo data
!mkdir /content/images
!wget https://github.com/ssundaram21/dreamsim/releases/download/v0.1.0/sample_images.zip -O images/sample_images.zip
!wget https://github.com/ssundaram21/dreamsim/releases/download/v0.1.0/retrieval_images.zip -O images/retrieval_images.zip
!unzip images/sample_images.zip
!unzip images/retrieval_images.zip
(2) Loading the model
import torch
from dreamsim import dreamsim

# Load the pretrained DreamSim ensemble and its matching preprocessing transform
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = dreamsim(pretrained=True, device=device)
(3) Utility function
import matplotlib.pyplot as plt
def show_imgs(ims, captions=None):
    fig, ax = plt.subplots(nrows=1, ncols=len(ims), figsize=(10, 5))
    for i in range(len(ims)):
        ax[i].imshow(ims[i])
        ax[i].axis('off')
        if captions is not None:
            ax[i].set_title(captions[i], fontweight="bold")
(4) Similarity Search
from PIL import Image
import torch
import matplotlib.pyplot as plt
ref_pil = Image.open("sample_images/ref_1.png")
img_a_pil = Image.open("sample_images/img_a_1.png")
img_b_pil = Image.open("sample_images/img_b_1.png")
show_imgs(
    ims=[img_a_pil, ref_pil, img_b_pil],
    captions=["A", "Reference", "B"])
# Compute the perceptual distance between the reference and each image (lower = more similar)
ref = preprocess(ref_pil).to(device)
img_a = preprocess(img_a_pil).to(device)
img_b = preprocess(img_b_pil).to(device)
dist_a = model(ref, img_a)
dist_b = model(ref, img_b)
show_imgs(
    ims=[img_a_pil, ref_pil, img_b_pil],
    captions=[f"A, Score: {round(float(dist_a.cpu()), 3)}",
              "Reference",
              f"B, Score: {round(float(dist_b.cpu()), 3)}"])
Advanced Applications
(1) Trying other images
ref_path = "sample_images/ref_2.png" #@param {type:"string"}
img_a_path = "sample_images/img_a_2.png" #@param {type:"string"}
#@markdown Optional:
img_b_path = "sample_images/img_b_2.png" #@param {type:"string"}
ref_pil = Image.open(ref_path)
img_a_pil = Image.open(img_a_path)
ref = preprocess(ref_pil).to(device)
img_a = preprocess(img_a_pil).to(device)
dist_a = model(ref, img_a)
if len(img_b_path) > 0:
    img_b_pil = Image.open(img_b_path)
    img_b = preprocess(img_b_pil).to(device)
    dist_b = model(ref, img_b)
    ims = [img_a_pil, ref_pil, img_b_pil]
    captions = [f"A, Score: {round(float(dist_a.cpu()), 3)}", "Reference",
                f"B, Score: {round(float(dist_b.cpu()), 3)}"]
else:
    ims = [ref_pil, img_a_pil]
    captions = ["Reference", f"Score: {round(float(dist_a.cpu()), 3)}"]
show_imgs(
    ims=ims,
    captions=captions)
(2) Image Retrieval
import os
from PIL import Image
from tqdm import tqdm
import torch.nn.functional as F
import pandas as pd
import pickle
root = "retrieval_images/"
images = []
for path in os.listdir(root):
    try:
        images.append(Image.open(root + path))
    except Exception:
        # Skip anything that is not a readable image file
        pass
query, images = images[0], images[1:]
# Compare three metrics: DreamSim, plus DINO and OpenCLIP baselines
from dreamsim import PerceptualModel
dreamsim_model = model
dino_model = PerceptualModel(feat_type='cls', model_type='dino_vitb16', stride='16', baseline=True, device=device)
open_clip_model = PerceptualModel(feat_type='embedding', model_type='open_clip_vitb32', stride='32', baseline=True, device=device)
def get_embeddings(model, name, images):
    embeddings = []
    for img in tqdm(images):
        img = preprocess(img).to(device)
        embeddings.append(model.embed(img).detach().cpu())
    with open(f"images/{name}_embeds.pkl", "wb") as f:
        pickle.dump(embeddings, f)
get_embeddings(dreamsim_model, "dreamsim", images)
get_embeddings(dino_model, "dino", images)
get_embeddings(open_clip_model, "open_clip", images)
def nearest_neighbors(embeddings, query_index):
    query_embed = embeddings[query_index]
    dists = {}
    # Compute the (cosine) distance between the query embedding
    # and each search image embedding
    for i, im in enumerate(embeddings):
        if i == query_index:
            continue
        dists[i] = (1 - F.cosine_similarity(query_embed, embeddings[i],
                                            dim=-1)).item()
    # Return results sorted by distance
    df = pd.DataFrame({"ids": list(dists.keys()), "dists": list(dists.values())})
    df = df.sort_values(by="dists")
    return df
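Before plotting, you can sanity-check the helper on a single metric. The sketch below assumes the dreamsim_embeds.pkl file written by get_embeddings above and reuses the same query index as the next cell:
# Quick check: top-3 nearest neighbors under the DreamSim metric
with open("images/dreamsim_embeds.pkl", "rb") as f:
    ds_embeddings = pickle.load(f)
print(nearest_neighbors(ds_embeddings, query_index=9).head(3))  # "ids" index into `images`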
query_index = 9 #@param {type:"number"}
n = 3
display_width = 11
display_height = 4
## Load embeddings for each metric and compute nearest neighbors to the query_index-th image
nn_dfs = {}
for metric_name in ["dreamsim", "open_clip", "dino"]:
    with open(f"images/{metric_name}_embeds.pkl", "rb") as f:
        embeddings = pickle.load(f)
    nn_dfs[metric_name] = nearest_neighbors(embeddings, query_index)
## Plot results
f, ax = plt.subplots(4, n+2, figsize=(14,7), gridspec_kw={"height_ratios":[0.005,1,1,1]})
ax[0,0].axis('off')
for col in range(1, n+2):
    title = "Query" if col == 1 else f"n{col-1}"
    ax[0, col].set_title(title, fontweight="bold", fontsize=15)
    ax[0, col].axis('off')
for i, name in enumerate(["dreamsim", "open_clip", "dino"]):
    ax[i+1, 0].text(0.5, 0.5, name, fontsize=13)
    ax[i+1, 0].axis('off')
    ax[i+1, 1].imshow(images[query_index])
    ax[i+1, 1].axis("off")
    for j in range(n):
        im_idx = nn_dfs[name]['ids'].iloc[j]
        ax[i + 1, j + 2].imshow(images[im_idx])
        ax[i + 1, j + 2].axis('off')
plt.tight_layout()
Closing thoughts
In this post we tried DreamSim, a new visual similarity search model, on Google Colab. It seems to capture compositional aspects of an image particularly well: recognizing individual objects is not so different from CLIP or DINO, but DreamSim learns a visual similarity that also reflects the relationships between multiple objects.
I plan to keep posting hands-on articles on LLMs, diffusion models, image analysis, and 3D, so stay tuned.