👨‍👨‍👦‍👦

画像間の類似度評価指標:FCNスコア(pixel-level-accuracy score)

2023/05/29に公開

はじめに

GAN(Generative Adversarial Network)を勉強していて、「生成した画像がGround Truthの画像と似ているかどうかってどう定量的に評価するんだ?」と疑問に思いまして、自身の理解を確かめるために実装してみます。

FCN-Score

pix2pixやCycleGANの論文で評価指標の一つとして使われた指標です。SemanticSegmentation向けに学習されたFCN(Fully Connected Networks)アーキテクチャを持つモデルに対して、Ground Truthの画像(y)および生成された画像(\hat{y})でSegmentation結果を得ます。
タスクに応じてFCNを再学習してから使うべきかもしれませんが、本質的にはSegmentationの正しさは重要ではなく、両者のSegmentation結果がどのくらい等しいかどうかが重要と考えます。また学習済みのFCNモデルはpytorchから利用可能です。これらを踏まえて以下実装していきます。

Segmentation向け学習済みFCNモデルの取得

pytorchから取得できます。

import torch, torchvision
from PIL import Image
import cv2
import matplotlib.pyplot as plt
import japanize_matplotlib
import albumentations as A
from sklearn.metrics import accuracy_score

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# 学習済みのFCNモデルを取得
model = torchvision.models.segmentation.fcn_resnet50(pretrained=True)
model.to(device)

画像の準備

何となくSEM画像を持ってきて、自分の環境にsem_sample.jpgとして保存しておき読み込みます。

y = np.array(Image.open("sem_sample.jpg").convert("RGB"))

古典的なGANで生成された画像は若干ぼやける性質があるので、opencv使ってわざとぼやけさせた画像を用意します。ksizeに与える値が大きいほどぼやけた画像になります。詳しい使い方は割愛します。

y_hat = cv2.blur(image, ksize=(15,15))

比較してみます。

plt.figure(figsize=(10,5))
plt.subplot(1,2,1)
plt.title("元画像($y$)")
plt.axis("off")
plt.imshow(y)

plt.subplot(1,2,2)
plt.title("ぼやけさせた画像($\hat{y}$)")
plt.axis("off")
plt.imshow(y_hat)

plt.show()

FCNスコアの計算

y\hat{y}のFCN-Scoreを比較してみます。
モデルに画像データを渡すための前処理を施します。
やっていることは、画像を256×256にリサイズ、標準化して、Height×Width×Channelsnumpy.arrayから、1×Channels×Height×Widthtorch.Tensorに変換しています。

transform = A.Compose([
    A.Resize(256,256),
    A.Normalize()
])

y_t = torch.tensor(
    transform(image=y)["image"]
).transpose(0,2).transpose(1,2).unsqueeze(0)

y_hat_t = torch.tensor(
    transform(image=y_hat)["image"]
).transpose(0,2).transpose(1,2).unsqueeze(0)

モデルに入力し、Segmentationの結果を得ます。

output_y = model(y_t.to(device))
output_y_hat = model(y_hat_t.to(device))

seg_y = output_y["out"].squeeze().detach().argmax(0).cpu().numpy()
seg_y_hat = output_y_hat["out"].squeeze().detach().argmax(0).cpu().numpy()

Segmentation結果を表示します。

plt.figure(figsize=(10,5))
plt.subplot(1,2,1)
plt.title("$y$のSegmentation画像")
plt.imshow(seg_y, vmin=0, vmax=21)
plt.colorbar()

plt.subplot(1,2,2)
plt.title("$\hat{y}$のSegmentation画像")
plt.imshow(seg_y_hat, vmin=0, vmax=21)
plt.colorbar()

ピクセル単位でのAccuracyスコアを計算します。

accuracy_score(seg_y.flatten(), seg_y_hat.flatten())

すると、0.933502197265625とそれなりに高い値が得られます。これは、背景(Segmentationのラベルが0のピクセル)がSegmentation結果の大部分を占め、y\hat{y}でそれなりに重なってしまうからです。なので、yのSegmentationラベルが1以上のピクセルのみを対象にして再計算してみます。

accuracy_score(seg_y[seg_y>0].flatten(), seg_y_hat[seg_y>0].flatten())

すると、0.18867924528301888という結果が得られました。
最後に、\hat{y}を生成する際のksizeを変えて、画像のぼやけ具合とFCNスコアのピクセルレベルでのAccuracyスコアとの関係を把握してみます。

ksize_cands = np.arange(1,51,2)
acc_list = []
plt.figure(figsize=(20,8))
for i,k in enumerate(ksize_cands):
    # y_hatを作る
    y_hat = cv2.blur(image, ksize=(k,k))
    
    # FCNモデルにわたす画像への前処理
    transform = A.Compose([
        A.Resize(256,256),
        A.Normalize()
    ])

    y_t = torch.tensor(
        transform(image=y)["image"]
    ).transpose(0,2).transpose(1,2).unsqueeze(0)

    y_hat_t = torch.tensor(
        transform(image=y_hat)["image"]
    ).transpose(0,2).transpose(1,2).unsqueeze(0)
    
    # FCNモデルへの画像入力
    output_y = model(y_t.to(device))
    output_y_hat = model(y_hat_t.to(device))
    
    # Segmentation結果を得る
    seg_y = output_y["out"].squeeze().detach().argmax(0).cpu().numpy()
    seg_y_hat = output_y_hat["out"].squeeze().detach().argmax(0).cpu().numpy()
    
    # ピクセル単位でのaccuracyスコアを得る
    acc = accuracy_score(seg_y[seg_y>0].flatten(), seg_y_hat[seg_y>0].flatten())
    acc_list.append(acc)
    
    if (i+1) % 5 == 0:
        plt.subplot(2,5,(i+1) // 5)
        plt.title("$\hat{y},$"+f" ksize=({k},{k})")
        plt.imshow(y_hat)
        plt.axis("off")
        
        plt.subplot(2,5,(i+1) // 5 + 5)
        plt.title("$seg(\hat{y}),$"+f" ksize=({k},{k})")
        plt.imshow(seg_y_hat, vmin=0, vmax=21)
        plt.colorbar()
        plt.axis("off")
        
plt.tight_layout()
plt.show()

plt.rcParams["font.size"] = 14
plt.plot(ksize_cands, acc_list)
plt.xlabel("kernel size")
plt.ylabel("pixel accuracy of FCN score")
plt.grid()
plt.show()

ぼやければぼやけるほどピクセルレベルでのAccuracyスコアが減少していく傾向を確認できました。
本当は、インスタンス単位でのAccuracyやIoUでもFCNスコアは評価されるべきですが、今回はピクセルレベルでのAccuracyスコアとしてFCNスコアの計算方法までの紹介とします。画像と画像の類似度を評価できる手法としてそこまで難しくない実装で実現できるところが良いと思います。

以上

Discussion