😊

pytorchでefficientdetを使ってみる

2020/12/28に公開

efficientdetのpytorch実装のrwightman/efficientdet-pytorchを使って、検証データの推論結果を表示してみます。このプロジェクト内で学習方法、カスタムデータセットを扱う方法などについては載っているのですが、なぜかデモやexampleが載っていません。

まずこのプロジェクトトップにjupyter notebookのnoteを作成します。

from effdet import create_model, create_dataset, create_loader
from effdet.data import resolve_input_config
import torch
import matplotlib.pyplot as plt
import cv2
import os

このリポジトリ内にあるvalidate.pyではモデルが推論モードが使用されているので、そのプログラムのオプションのデフォルト値を参考にしながら、create_modelの引数を決めます。checkpoint_pathは自分で学習を走らせた場合は、次のセル内のような名前になると思います。

bench = create_model(
    'efficientdet_d0', # d0 ~ d7
    bench_task='predict',
    num_classes=20,
    pretrained=False,
    redundant_bias=None,
    soft_nms=None,
    checkpoint_path='./output/train/yyyymmdd-hhnnss-efficientdet_d0/checkpoint-n.pth.tar',
    checkpoint_ema='use_ema',
)
bench = bench.cuda() #cudaを使う
bench.eval() #推論モード

検証データのロードはVOCデータセットであれば次のように行います。_で受けているほうが、訓練データになります。

_, dataset = create_dataset('voc0712', './VOCdevkit')

データローダーのオプションもvalidate.pyのデフォルトオプションを頼りに決めます。今回は1つの写真を与えることを想定するので、batch_sizeは1にします。

model_config = bench.config
input_config = resolve_input_config({}, model_config)
loader = create_loader(
        dataset,
        input_size=input_config['input_size'],
        batch_size=1,
        use_prefetcher=True,
        interpolation='bilinear',
        fill_color=input_config['fill_color'],
        mean=input_config['mean'],
        std=input_config['std'],
        num_workers=1,
        pin_mem=False)

推論モードのときに出力されるバウンディングボックスの座標は入力画像のサイズに合わせて確率上位100位が返ってきます。cv2.circleでマークを付けたり、cv2.rectangleで囲んで画像を表示すれば、結果を確認することができます。

IMG_DIR = '/path/to/img_dir'
parser = dataset.parser
with torch.no_grad():
    for input, target in loader:
        img_name = parser.img_infos[int(target['img_idx'][0])]['file_name']
        output = bench(input, img_info=target)[0]
        img = cv2.imread(os.path.join(IMG_DIR, img_name), cv2.IMREAD_COLOR)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        for i in range(output.size(0)):
            if output[i, 4] < 0.3: # 0.3は閾値です。適当に変えてください
                break
            xmin, ymin, xmax, ymax, pred, label = output[i]
            #cv2.rectangle(img, pt1=(int(xmin), int(ymin)), pt2=(int(xmax), int(ymax)), color=(255, 0, 0), thickness=4)
            cx = int(xmin) + int((float(xmax) - float(xmin)) / 2)
            cy = int(ymin) + int((float(ymax) - float(ymin)) / 2)
            cv2.circle(img, (cx, cy), 5, (0, 255, 0), thickness=-1)
        plt.imshow(img)
        plt.show()

Discussion