CenterTrackで動体をトラッキングしてみる
モデル概要
目的
同時検出・追跡のアイデアを組み合わせることで、複数物体のトラッキングを単純化すること
コアとなるアイディア
- 各オブジェクトは、バウンディングボックスの中心にある1つの点で表され、この中心点を時間軸に沿って追跡する。
- オブジェクトの中心点の特定にはCenterNetによる検出器が用いられている。
- 現在のフレームにおけるオブジェクトの中心から、前フレームのオブジェクトの中心までの差分ベクトルを出力するように検出器を学習
- フレーム間の距離のみに基づいたマッチング方法を用いている
ネットワーク
- CenterNetを用いており、物体検出のタスクにおいて、あらゆる検出モデルよりも速度と精度の面で優れている
- CenterNet内の構造は、(a)Hourglass-104, (b)ResNet-18 & ResNet-101, (c,d)DLA-34を用いている(dはcを改良して用いている)。
損失関数
- 焦点損失・中心点の回帰・バウンディングボックスのサイズ、これら損失項の重み付けされた総和
- 焦点損失:背景と前景とで極端な不均衡がないかの調整
- 中心点の回帰:オブジェクト内において、周囲に9つの点を持つピーク点の回帰
- バウンディングボックスのサイズ:中心点をもとに適切な幅と高さを回帰する
改良点
- オブジェクトを点としてトラッキングすることで、以下の二つの要素を簡略化
- トラッキング条件付きの検出の簡略化
- 時間を超えたオブジェクトの関連付けの単純化(スパースオプティカルフローのような単純な変位予測)
問題や課題
- 認識可能な距離にある物体追跡の精度及び速度を向上させるため、長距離の物体については精度の保証をしていない
動かしてみる
データセット準備
-
今回はKITTIデータセットを使う
http://www.cvlibs.net/datasets/kitti/ -
上のURLに入り、"raw data"をクリック、クリック後にスクロールダウンして"Campas"の"2011_09_28_drive_0039 (1.4GB)"の[synced+rectified data]をダンプする
モデル準備
-
cloneする
git clone https://github.com/xingyizhou/CenterTrack
-
今回は物体を分類しながらトラッキングを行うために、以下から"coco_tracking"のところまでスクロールダウンし、"model"部分をクリックしてモデルをダウンロードする
https://github.com/xingyizhou/CenterTrack/blob/master/readme/MODEL_ZOO.md -
新しく"models"フォルダをrootに作成し、そこへ"coco_tracking.pth"を置く
-
一部修正箇所があるため、/src/lib/opts.pyの397行目を以下に変更
opt = self.parse(args)
-
必要なパッケージのインストール
pip install torch==1.4 torchvision==0.5 pytest==3.8 folium==0.2.1 opencv-python pip install -r requirements.txt
-
DCNが必要なため、クローンしてからmake.shを実行する
cd CenterTrack-master/src/lib/model/networks/ git clone https://github.com/CharlesShang/DCNv2.git cd DCNv2/ ./make.sh
-
rootに戻り、video/下まで行き、KITTIデータセットの2011_09_28_drive_0039_sync/image_03/dataを置きます。
-
データセットを置いたら、同じくvideo/下にimg_to_mp4.pyを作成し、以下のコードをコピペして、このpyファイルをvideo/で実行してください。
CenterTrack-master/video/img_to_mp4.pyimport cv2 import glob def main(): files = glob.glob("data/*.png") h, w, _ = cv2.imread(files[0]).shape out = cv2.VideoWriter('sample.mp4',cv2.VideoWriter_fourcc(*'MP4V'), 5.0, (w, h)) for file in sorted(files): out.write(cv2.imread(file)) out.release() if __name__ == '__main__': main()
python img_to_mp4.py
-
下にあるような動画ができればデータセットの準備は終わりです。
テスト
- テストする前に、demo.pyの50~53行目までを以下に変更。mp4ファイルを読み込む際に正しく高さと幅を与えないとエラーとなるため、プログラム内で取得するように変更します。CenterTrack-master/src/demo.py
cam = cv2.VideoCapture('../videos/{}'.format(out_name)) width = int(cam.get(cv2.CAP_PROP_FRAME_WIDTH)) height = int(cam.get(cv2.CAP_PROP_FRAME_HEIGHT)) fourcc = cv2.VideoWriter_fourcc(*'h264') out = cv2.VideoWriter('../results/{}'.format(opt.exp_id + '_' + out_name), fourcc, opt.save_framerate, (width, height))
- テスト結果のmp4を作成できるように、rootに"results"を作成しておきます。
mkdir CenterTrack-master/results
- ここまで準備ができたら、src/以下まで移動し、以下コマンドを実行してroot/video下にあるmp4ファイルに対してトラッキングさせてみます。
python demo.py tracking --load_model ../models/coco_tracking.pth --demo ../video/sample.mp4 --save_video --gpus -1
結果
- 人や物毎にIDが振られ、別々にトラッキングされていることが確認できます。
- 駐輪場の自転車に差分ベクトルが出ているのが気になりますね。。
最後に
今回はpretrainされたモデルを用いた試みですが、自前のデータセットで追加学習してみるのも面白いかもしれません。
参考文献
- CenterNet: Objects as Points
- CenterTrack: Tracking Objects as Points
- Hourglass: Stacked Hourglass Networks for Human Pose Estimation
- ResNet: Deep Residual Learning for Image Recognition
- DLA: Deep Layer Aggregation
Discussion