SSDのポジティブDBoxを表示してみる

公開:2020/09/28
更新:2020/09/28
7 min読了の目安(約4600字TECH技術記事

SSDのラベル付けに対する損失を評価するのに使われるポジティブデフォルトボックスを可視化させてみたいと思います。実装はSSDのPython実装amdegroot/ssd.pytorchを使用します。

ポジティブデフォルトボックスとネガティブデフォルトボックスは1つ以上の画像と画像中のバウンディングボックスが与えられれば決める事ができて、SSDモデルとは直接関係無く決まります。まず画像をロードするため、次のようにデータローダーを生成します。

# 画像を300x300にリサイズするのに使用します
image_size = 300
# 今回は1画像のみ扱うのでバッチサイズは1にします
batch_size = 1
# transformで本来学習中はデータオーグメントさせますが、
# ここでは可視化のため300x300リサイズとRGBの平均値を引く操作のみ行う
# BaseTransformを使用します。
dataset = VOCDetection(root = '/path/to/root', 
                          transform=BaseTransform(image_size, MEANS))
data_loader = data.DataLoader(dataset, 
    batch_size = batch_size, 
    num_workers=0, 
    shuffle=False, 
    collate_fn = detection_collate)
# ssd.pytorchの関連するコードがcuda前提のため、これを設定しないとエラーになります
torch.set_default_tensor_type('torch.cuda.FloatTensor')

データローダーから画像(images)とバウンディングボックス、ラベル情報(targets)をロードして、デフォルトボックスと全てのバウンディングボックスの間のjaccard係数の最大値を計算し、jaccard係数が最大となったバウンディングボックスに対する位置情報をloc_tに、ラベル情報をconf_tに入れます。ただしjaccard係数の最大値が閾値未満(今回は0.5未満)の場合は背景としてラベル付けし直します。このデフォルトボックスをネガティブデフォルトボックスと呼び、それ以外のデフォルトボックスをポジティブデフォルトボックスと呼びます。

images, targets = next(iter(data_loader))
images = images.cuda()
targets = [ann.cuda() for ann in targets]
priors = PriorBox(voc).forward()
num_priors = priors.size(0)
loc_t = torch.Tensor(batch_size, num_priors,  priors.size(1))
conf_t = torch.LongTensor(batch_size, num_priors)
# loc_t: DBoxからBBoxまでの差分とconf_t: jaccard係数 0.5を閾値としたときのラベル付けを計算
# matchはloc_t、conf_tに値をセットします
match(threshold = 0.5, 
      truths = targets[0][:, :-1].data,  
      priors = priors.data, 
      variances = [0.1, 0.2], 
      labels = targets[0][:, -1].data, 
      loc_t = loc_t, 
      conf_t = conf_t, 
      idx = 0)

jaccard係数の計算は、傾きの無い長方形同士なので、次のように簡単に計算できます。

def calc_jaccard(a, b):
    w = min(a[2], b[2]) - max(a[0], b[0])
    h = min(a[3], b[3]) - max(a[1], b[1])
    # interはa, bが互いに素であれば、w, hのどちらか、またはどちらもが
    # ネガティブになるのでmaxでこのときinterが0になるように調整する
    inter = max(w, 0) * max(h, 0)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    return inter / (area_a + area_b - inter)

ポジティブデフォルトボックスと対応するバウンディングボックスを表示してみます。

image = (images[0].to('cpu').detach().numpy().transpose(1, 2, 0) + (MEANS[2], MEANS[1], MEANS[0])).astype(np.uint8).copy()   
image = cv2.resize(image, (image_size, image_size))
indices = [i for i, v in enumerate(list((conf_t > 1).to('cpu').detach().numpy().copy()[0])) if v]
colors = get_colors()
plt.figure(figsize=(16, 12))
for i, idx in enumerate(indices):
    img = image.copy()    
    cx_d, cy_d, w_d, h_d = priors[idx].to('cpu').detach().numpy().copy()
    xmin_d = int((cx_d - w_d / 2) * image_size)
    ymin_d = int((cy_d - h_d / 2) * image_size)
    xmax_d = int((cx_d + w_d / 2) * image_size)
    ymax_d = int((cy_d + h_d / 2) * image_size)
    pt1_d = (xmin_d, ymin_d)
    pt2_d = (xmax_d, ymax_d)    
    cv2.rectangle(img, pt1=pt1_d, pt2=pt2_d, color=colors[5], thickness=2)
    for box in targets[0]:
        pt1 =  (int(box[0] * image_size), int(box[1] * image_size) )
        pt2 = (int(box[2] * image_size), int(box[3] * image_size))
        jaccard = calc_jaccard((pt1_d[0], pt1_d[1], pt2_d[0], pt2_d[1]), (pt1[0], pt1[1], pt2[0], pt2[1]))
        cv2.putText(img, f"{jaccard: .2f}", (pt1[0] + 5 , pt1[1] - 5), cv2.FONT_HERSHEY_PLAIN, 1.5, (0, 255, 0), 2)
        cv2.rectangle(img, pt1=pt1, pt2=pt2, color=colors[10], thickness=2)
    plt.subplot(3 , 4, i + 1)
    plt.axis('off')
    plt.imshow(img)    
plt.show()

画像中の数値はjaccard係数になります。表示からも分かるように、8732個あるデフォルトボックスのうち、ポジティブデフォルトボックスは12個で、ほとんどがネガティブデフォルトボックスに分類されることが分かります。最後にポジティブデフォルトボックスを全て重ねて表示させてみます。

img_path = os.path.join(os.environ['HOME'], "data", "dogs_out", "dog_ssd_test-PascalVOC-export", "JPEGImages", "dog-3407906_640.jpg")
img = cv2.imread(img_path)
img_h_, img_w_, _ = img.shape
img = img[:, :, (2, 1, 0)].copy()
for i, idx in enumerate(indices):
    cx, cy, w, h = priors[idx].to('cpu').detach().numpy().copy()
    xmin = int((cx - w / 2) * img_w_)
    ymin = int((cy - h / 2) * img_h_)
    xmax = int((cx + w / 2) * img_w_)
    ymax = int((cy + h / 2) * img_h_)
    pt1 = (xmin, ymin)
    pt2 = (xmax, ymax)    
    cv2.rectangle(img, pt1=pt1, pt2=pt2, color=colors[i % 10], thickness=2)
plt.axis('off')
plt.imshow(img)
plt.show()

ここで使用している元画像はここ (Pixabay License: 商用利用無料 帰属表示は必要ありません)から640x426のサイズでダウンロードしたものになります。