SSDのデフォルトボックスの可視化

4 min読了の目安(約2800字TECH技術記事

SSDのラベル付けに対する損失を評価するのに使われるハードネガティブマイニングを理解するために、まずここではデフォルトボックスを可視化させてみたいと思います。実装はSSDのPython実装amdegroot/ssd.pytorchを使用します。

特にここではdata.configで定義されているSSD300の設定を使用したいと思います。この設定は小さな正方形、大きな正方形、1:21:3のアスペクト比の長方形が使用されています。ただし、小さな正方形のデフォルトボックスの大きさが30, 213, 264のときはアスペクト比1:3のデフォルトボックスは使われていません。

import os
import cv2
import matplotlib.pyplot as plt
from layers.functions.prior_box import PriorBox
from data.config import voc
# RGBカラーの配列を返します
def get_colors():
    cmap = plt.get_cmap('prism', 10)
    return [(int(cmap(i)[0] * 255), int(cmap(i)[1] * 255), int(cmap(i)[2]*255)) for i  in range(10)]
# デフォルトボックス付きの画像を返します
def img_with_dbox(priors, img_h, img_w, idx, i):
    img = cv2.imread('/path/to/image')
    img = img[:, :, (2, 1, 0)].copy()
    img = cv2.resize(img, (300, 300))
    img_h, img_w, _ = img.shape
    colors = get_colors()
    cx, cy, w, h = priors[idx].to('cpu').detach().numpy().copy()
    xmin = round((cx - w / 2) * img_w)
    ymin = round((cy - h / 2 ) * img_h)
    xmax = round((cx + w / 2) * img_w)
    ymax = round((cy + h / 2) * img_h)
    #print(f"({xmin}, {ymin}, {xmax}, {ymax})")
    #print(f"(w, h) = ({xmax - xmin}, {ymax - ymin})")
    pt1 = (xmin, ymin)
    pt2 = (xmax, ymax)    
    cv2.rectangle(img, pt1=pt1, pt2=pt2, color=colors[i], thickness=2)
    return img
# デフォルトボックスを書き出します
def show_dboxes(priors, offset, n):
    plt.figure(figsize=(n * 2, n))
    for j, i in enumerate(range(offset, offset + n)):
        plt.subplot(1, n, j + 1)
        plt.axis('off')
        plt.imshow(img_with_dbox(priors, 300, 300, i, j))
    plt.show()

以上の関数を定義した上で、次のように書くと一通りのデフォルトボックスのパターンを確認することができます。ここで使用している元画像はここ (Pixabay License: 商用利用無料 帰属表示は必要ありません)から640x426のサイズでダウンロードしたものになります。

show_dboxes(priors, (38 * 5 + 30)  * 4, 4)
show_dboxes(priors, 38 * 38  * 4 + (19 * 5 + 13) * 6, 6)
show_dboxes(priors, 38 * 38  * 4 + 19 * 19 * 6 + (10 * 3 + 5) * 6, 6)
show_dboxes(priors, 38 * 38  * 4 + 19 * 19 * 6 + 10 * 10  * 6 +  (5 * 2 + 2) * 6, 6)
show_dboxes(priors, 38 * 38  * 4 + 19 * 19 * 6 + 10 * 10 * 6 + 5 * 5 * 6 + (3 * 1 + 1) * 4, 4)
show_dboxes(priors, 38 * 38  * 4 + 19 * 19 * 6 + 10 * 10 * 6 + 5 * 5 * 6 + 3* 3 * 4, 4)