ssd.pytorchのDataLoderが返すtensorを画像に戻して表示する

2 min読了の目安(約1400字TECH技術記事

SSDのPython実装amdegroot/ssd.pytorchのDataLoaderから得られるtorch.Tensor型の画像を変換してpyplotで画像として表示させる方法を示します。やっていることはSSDのDatasetが行っている変換の逆変換をしているだけですが、DataLoaderでロードした画像を可視化するのに役立つと思います。

import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
from data import VOCDetection, BaseTransform, MEANS
import torch.utils.data as data
def show_image(image, h, w):
    img = (image.to('cpu').detach().numpy().transpose(1, 2, 0) + (MEANS[2], MEANS[1], MEANS[0])).astype(np.uint8).copy()   
    img = cv2.resize(img, (w, h))
    plt.axis('off')
    plt.imshow(img)
    plt.show()

ここで定義したshow_imagetorch.Tensorを再びnumpy.ndarrayに変換します。次のコードを実行すると、300x300にリサイズした画像を再び元の画像サイズに戻したものと見た目レベルで同じ状態になることが確認できます。

dataset = VOCDetection(root = '/path/to/root', image_sets=[('2007', 'test')], 
                          transform=BaseTransform(300, MEANS))
data_loader = data.DataLoader(dataset, batch_size = 1, num_workers=0, shuffle=False)

img_path = os.path.join('/path/to/root', 'VOC2007', 'JPEGImages', '000001.jpg')
img = cv2.imread(img_path)
h, w, _ = img.shape
img = cv2.resize(img, (300, 300))
img = cv2.resize(img, (w, h))
img = img[:, :, (2, 1, 0)]
plt.axis('off')
plt.imshow(img) 
plt.show()#元画像を300x300に変換し元に戻した画像を表示

images, _ = next(iter(data_loader))
show_image(images[0], h, w) #DataLoarderがロードした画像を表示

DataLoaderからは元画像の大きさは分からないので、実践的には適当な幅と高さを手で与える必要があります。