画像間の類似度評価指標:FCNスコア(pixel-level-accuracy score)
はじめに
GAN(Generative Adversarial Network)を勉強していて、「生成した画像がGround Truthの画像と似ているかどうかってどう定量的に評価するんだ?」と疑問に思いまして、自身の理解を確かめるために実装してみます。
FCN-Score
pix2pixやCycleGANの論文で評価指標の一つとして使われた指標です。SemanticSegmentation向けに学習されたFCN(Fully Connected Networks)アーキテクチャを持つモデルに対して、Ground Truthの画像(
タスクに応じてFCNを再学習してから使うべきかもしれませんが、本質的にはSegmentationの正しさは重要ではなく、両者のSegmentation結果がどのくらい等しいかどうかが重要と考えます。また学習済みのFCNモデルはpytorchから利用可能です。これらを踏まえて以下実装していきます。
Segmentation向け学習済みFCNモデルの取得
pytorchから取得できます。
import torch, torchvision
from PIL import Image
import cv2
import matplotlib.pyplot as plt
import japanize_matplotlib
import albumentations as A
from sklearn.metrics import accuracy_score
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# 学習済みのFCNモデルを取得
model = torchvision.models.segmentation.fcn_resnet50(pretrained=True)
model.to(device)
画像の準備
何となくSEM画像を持ってきて、自分の環境にsem_sample.jpg
として保存しておき読み込みます。
y = np.array(Image.open("sem_sample.jpg").convert("RGB"))
古典的なGANで生成された画像は若干ぼやける性質があるので、opencv使ってわざとぼやけさせた画像を用意します。ksizeに与える値が大きいほどぼやけた画像になります。詳しい使い方は割愛します。
y_hat = cv2.blur(image, ksize=(15,15))
比較してみます。
plt.figure(figsize=(10,5))
plt.subplot(1,2,1)
plt.title("元画像($y$)")
plt.axis("off")
plt.imshow(y)
plt.subplot(1,2,2)
plt.title("ぼやけさせた画像($\hat{y}$)")
plt.axis("off")
plt.imshow(y_hat)
plt.show()
FCNスコアの計算
モデルに画像データを渡すための前処理を施します。
やっていることは、画像を256×256にリサイズ、標準化して、Height×Width×Channels
のnumpy.array
から、1×Channels×Height×Width
のtorch.Tensor
に変換しています。
transform = A.Compose([
A.Resize(256,256),
A.Normalize()
])
y_t = torch.tensor(
transform(image=y)["image"]
).transpose(0,2).transpose(1,2).unsqueeze(0)
y_hat_t = torch.tensor(
transform(image=y_hat)["image"]
).transpose(0,2).transpose(1,2).unsqueeze(0)
モデルに入力し、Segmentationの結果を得ます。
output_y = model(y_t.to(device))
output_y_hat = model(y_hat_t.to(device))
seg_y = output_y["out"].squeeze().detach().argmax(0).cpu().numpy()
seg_y_hat = output_y_hat["out"].squeeze().detach().argmax(0).cpu().numpy()
Segmentation結果を表示します。
plt.figure(figsize=(10,5))
plt.subplot(1,2,1)
plt.title("$y$のSegmentation画像")
plt.imshow(seg_y, vmin=0, vmax=21)
plt.colorbar()
plt.subplot(1,2,2)
plt.title("$\hat{y}$のSegmentation画像")
plt.imshow(seg_y_hat, vmin=0, vmax=21)
plt.colorbar()
ピクセル単位でのAccuracyスコアを計算します。
accuracy_score(seg_y.flatten(), seg_y_hat.flatten())
すると、0.933502197265625
とそれなりに高い値が得られます。これは、背景(Segmentationのラベルが0のピクセル)がSegmentation結果の大部分を占め、
accuracy_score(seg_y[seg_y>0].flatten(), seg_y_hat[seg_y>0].flatten())
すると、0.18867924528301888
という結果が得られました。
最後に、
ksize_cands = np.arange(1,51,2)
acc_list = []
plt.figure(figsize=(20,8))
for i,k in enumerate(ksize_cands):
# y_hatを作る
y_hat = cv2.blur(image, ksize=(k,k))
# FCNモデルにわたす画像への前処理
transform = A.Compose([
A.Resize(256,256),
A.Normalize()
])
y_t = torch.tensor(
transform(image=y)["image"]
).transpose(0,2).transpose(1,2).unsqueeze(0)
y_hat_t = torch.tensor(
transform(image=y_hat)["image"]
).transpose(0,2).transpose(1,2).unsqueeze(0)
# FCNモデルへの画像入力
output_y = model(y_t.to(device))
output_y_hat = model(y_hat_t.to(device))
# Segmentation結果を得る
seg_y = output_y["out"].squeeze().detach().argmax(0).cpu().numpy()
seg_y_hat = output_y_hat["out"].squeeze().detach().argmax(0).cpu().numpy()
# ピクセル単位でのaccuracyスコアを得る
acc = accuracy_score(seg_y[seg_y>0].flatten(), seg_y_hat[seg_y>0].flatten())
acc_list.append(acc)
if (i+1) % 5 == 0:
plt.subplot(2,5,(i+1) // 5)
plt.title("$\hat{y},$"+f" ksize=({k},{k})")
plt.imshow(y_hat)
plt.axis("off")
plt.subplot(2,5,(i+1) // 5 + 5)
plt.title("$seg(\hat{y}),$"+f" ksize=({k},{k})")
plt.imshow(seg_y_hat, vmin=0, vmax=21)
plt.colorbar()
plt.axis("off")
plt.tight_layout()
plt.show()
plt.rcParams["font.size"] = 14
plt.plot(ksize_cands, acc_list)
plt.xlabel("kernel size")
plt.ylabel("pixel accuracy of FCN score")
plt.grid()
plt.show()
ぼやければぼやけるほどピクセルレベルでのAccuracyスコアが減少していく傾向を確認できました。
本当は、インスタンス単位でのAccuracyやIoUでもFCNスコアは評価されるべきですが、今回はピクセルレベルでのAccuracyスコアとしてFCNスコアの計算方法までの紹介とします。画像と画像の類似度を評価できる手法としてそこまで難しくない実装で実現できるところが良いと思います。
以上
Discussion