SSDの位置に対する損失を可視化してみる

5 min読了の目安(約4900字TECH技術記事

SSDモデルの学習で使われる損失は画像分類に対するものと、位置に対するものの和になっています。ここでは位置にたいする損失をバウンディングボックスと補正後のデフォルトボックスとともに表示してます。実装はSSDのPython実装amdegroot/ssd.pytorchを使用します。

SSDのハードネガティブマイニングの結果を可視化してみる』のコードの再掲になりますが、教師データをSSDモデルが計算する位置情報とラベル情報と同様のフォーマットに変換したような量を計算します。loc_tは位置に関するものでconf_tは画像の分類情報になります。

image_size = 300
batch_size = 1
n_class =  len(VOC_CLASSES) + 1
dataset = VOCDetection(root = '/path/to/root', 
                          transform=BaseTransform(image_size, MEANS))
data_loader = data.DataLoader(dataset, batch_size = batch_size, num_workers=0, shuffle=False, collate_fn = detection_collate)
torch.set_default_tensor_type('torch.cuda.FloatTensor')
images, targets = next(iter(data_loader))
images = images.cuda()
targets = [ann.cuda() for ann in targets]
priors = PriorBox(voc).forward()
num_priors = priors.size(0)
loc_t = torch.Tensor(batch_size, num_priors,  priors.size(1))
conf_t = torch.LongTensor(batch_size, num_priors)
# loc_t: DBoxからBBoxまでの差分とconf_t: jaccard係数 0.5を閾値としたときのラベル付けを計算
for idx in range(batch_size):
    match(threshold = 0.5, 
          truths = targets[idx][:, :-1].data,  
          priors = priors.data, 
          variances = [0.1, 0.2], 
          labels = targets[idx][:, -1].data, 
          loc_t = loc_t, 
          conf_t = conf_t, 
          idx = idx)

次に、SSDネットワークを学習モードで生成して、画像情報から位置情報に関する出力を得ます。

net = build_ssd('train', image_size, n_class)
net.load_state_dict(torch.load('/path/to/weight')
net = net.cuda()
net.train()
loc_data, _ , _ = net(images)

ポジティブデフォルトボックスに分類される位置情報のみ抜き出します。loc_pがSSDが推測した結果で、loc_tが教師データになります。

pos = conf_t > 0
pos_idx = pos.unsqueeze(pos.dim()).expand_as(loc_data)
loc_p = loc_data[pos_idx].view(-1, 4)
loc_t = loc_t[pos_idx].view(-1, 4)

位置情報に関する損失をsmooth_l1_loss関数で評価します。この関数は平均絶対誤差(MAE)や平均二乗誤差(MSE)の仲間で、細かい差はありますが、連続値をとる変数の誤差を計算するのに使われる関数です。各デフォルトボックスに対する損失は、4変数の和になります。更にそれを平均したものloss_lが、学習で使われる位置情報分の損失になります。

l_losses = F.smooth_l1_loss(loc_p, loc_t, reduction='none')
l_losses = torch.sum(l_losses, 1)
l_losses = l_losses.to('cpu').detach().numpy().copy()
# 位置情報に関する損失
loss_l = np.average(l_losses) 
print(loss_l)

1つの画像に対する位置に関する損失は、ポジティブデフォルトボックスのみ寄与するので、次のようにして、各デフォルトボックスに対する損失とデフォルトボックスとバウンディングボックスを表示させられます。

# ポジティブデフォルトボックスのindex
indices = [i for i, v in enumerate(list(pos.to('cpu').detach().numpy().copy()[0])) if v]
plt.figure(figsize=(16, 12))
image = (images[0].to('cpu').detach().numpy().transpose(1, 2, 0) + (MEANS[2], MEANS[1], MEANS[0])).astype(np.uint8).copy()   
for i, idx in enumerate(indices):   
    img = cv2.resize(image.copy(), (image_size, image_size))
    cx_d, cy_d, w_d, h_d = priors[idx].to('cpu').detach().numpy().copy()
    xmin_d = int((cx_d - w_d / 2) * image_size)
    ymin_d = int((cy_d - h_d / 2) * image_size)
    xmax_d = int((cx_d + w_d / 2) * image_size)
    ymax_d = int((cy_d + h_d / 2) * image_size)
    pt1_d = (xmin_d, ymin_d)
    pt2_d = (xmax_d, ymax_d)
    #デフォルトボックスを描画する処理
    #見えやすいとは言えないので、ここではコメントアウト 
    #cv2.rectangle(img, pt1=pt1_d, pt2=pt2_d, color=(0, 255, 0), thickness=2)
    loc = loc_data.view(-1, 4)[idx].to('cpu').detach().numpy().copy()
    cx_p = cx_d * (1 + 0.1 * loc[0])
    cy_p = cy_d * (1 + 0.1 * loc[1])
    w_p = w_d * np.exp(0.2 * loc[2])
    h_p = h_d * np.exp(0.2 * loc[3])
    xmin_p = int((cx_p - w_p / 2) * image_size)
    ymin_p = int((cy_p - h_p / 2) * image_size)
    xmax_p = int((cx_p + w_p / 2) * image_size)
    ymax_p = int((cy_p + h_p / 2) * image_size)
    pt1_p = (xmin_p, ymin_p)
    pt2_p = (xmax_p, ymax_p)    
    cv2.rectangle(img, pt1=pt1_p, pt2=pt2_p, color=(255, 0, 0), thickness=2)
    for box in targets[0]:
        label = int(box[4])
        pt1 =  (int(box[0] * image_size), int(box[1] * image_size) )
        pt2 = (int(box[2] * image_size), int(box[3] * image_size))
        jaccard = calc_jaccard((pt1_p[0], pt1_p[1], pt2_p[0], pt2_p[1]), (pt1[0], pt1[1], pt2[0], pt2[1]))
        cv2.putText(img, f"J:{jaccard: .2f}", (200, 30), cv2.FONT_HERSHEY_PLAIN, 1.5, (0, 255, 0), 2)
        cv2.putText(img, f"L:{l_losses[i]: .2f}", (200, 60), cv2.FONT_HERSHEY_PLAIN, 1.5, (0, 255, 0), 2)
        cv2.rectangle(img, pt1=pt1, pt2=pt2, color=(0, 0, 255), thickness=2)
    plt.subplot(3 , 4, i + 1)
    plt.axis('off')
    plt.imshow(img)
plt.show()

学習開始時の重みを使った場合、ほぼデフォルトボックスとバウンディングボックスが表示されたような表示になります。画像中のJ: 0.54というのはjaccard係数でL: 4.64というのは損失になります。

学習済みモデルの場合は、デフォルトボックスから位置補正を入れたボックスがほぼバウンディングボックスに一致するようになります。

ここで使用している元画像はここ (Pixabay License: 商用利用無料 帰属表示は必要ありません)から640x426のサイズでダウンロードしたものになります。