👋

DeepLabをカスタムデータセットで学習するための初期重みを作る

2022/01/14に公開

tensorflow/modelsdeeplabでカスタムデータセットで学習するとき、データセットのクラス数が背景入れて21の場合以外だと、そのままだと学習できません。クラス数が21より小さければ、21のまま学習することはできますが、やはり精度は落ちます(あとで比較します)。com.unity.perceptionのチュートリアルで作ったデータセットを使って、初期重みの作り方を説明します。簡単のため、データセットからtfrecordを生成する手順については省略します。

準備

データセットの定義を追加します。名前はmy_first_perceptionにしました。クラス数は背景入れて11です。ignore_label255にしていますが、実際にはこの値は利用されません。オプショナルなのか分からなかったので、ダミー値として設定しています。vocデータセットだと白線で枠を指定しているためこのようなプロパティが必要になります。

data_generator.py
@@ -90,6 +90,24 @@ _PASCAL_VOC_SEG_INFORMATION = DatasetDescriptor(
     ignore_label=255,
 )
 
+_MY_FIRST_PERCEPTION_INFORMATION = DatasetDescriptor(
+    splits_to_sizes={
+        'train': 1000,
+        'val': 100,
+    },
+    num_classes=11,
+    ignore_label=255,
+)
+
 _ADE20K_INFORMATION = DatasetDescriptor(
     splits_to_sizes={
         'train': 20210,  # num of samples in images/training
@@ -103,6 +121,8 @@ _DATASETS_INFORMATION = {
     'cityscapes': _CITYSCAPES_INFORMATION,
     'pascal_voc_seg': _PASCAL_VOC_SEG_INFORMATION,
     'ade20k': _ADE20K_INFORMATION,
+    'my_first_perception': _MY_FIRST_PERCEPTION_INFORMATION,
 }
 
 # Default file pattern of TFRecord of TensorFlow Example.

local_test.shの中の学習部分を取り出して、train.bashを作り実行します(VOCの学習の準備はできているものとします。後で利用する、分類クラス数21のmodel.ckpt-xxx.metaを入手するため)。実行後、次のような変更を加えますがdeeplab/datasets/my_first_perception/tfrecordにカスタムデータセットのtfrecordがあるとします。このスクリプトを実行すると、学習前のランダムな重みを得ます。このランダムな重みのバックボーン部分をxception_65のvocデータセット用に作られた初期の重みと置き換えることで、分類クラス数11の場合の初期重みを作ります。

train.bash
@@ -41,8 +41,8 @@
 # Go back to original directory.
 
 # Set up the working directories.
-PASCAL_FOLDER="pascal_voc_seg"
+PASCAL_FOLDER="my_first_perception"
 INIT_FOLDER="${WORK_DIR}/${DATASET_DIR}/${PASCAL_FOLDER}/init_models"
 TRAIN_LOGDIR="${WORK_DIR}/${DATASET_DIR}/${PASCAL_FOLDER}/${EXP_FOLDER}/train"
 EVAL_LOGDIR="${WORK_DIR}/${DATASET_DIR}/${PASCAL_FOLDER}/${EXP_FOLDER}/eval"
@@ -52,11 +52,12 @@
 PASCAL_DATASET="${WORK_DIR}/${DATASET_DIR}/${PASCAL_FOLDER}/tfrecord"
 
 # Train 10 iterations.
-NUM_ITERATIONS=10
+NUM_ITERATIONS=1
 python "${WORK_DIR}"/train.py \
   --logtostderr \
-  --train_split="trainval" \
+  --train_split="train" \
   --model_variant="xception_65" \
   --atrous_rates=6 \
   --atrous_rates=12 \
   --atrous_rates=18 \
@@ -66,6 +67,5 @@
   --train_batch_size=4 \
   --training_number_of_steps="${NUM_ITERATIONS}" \
   --fine_tune_batch_norm=true \
-  --tf_initial_checkpoint="${INIT_FOLDER}/xception/model.ckpt" \
   --train_logdir="${TRAIN_LOGDIR}" \
   --dataset_dir="${PASCAL_DATASET}" \
   --dataset="${PASCAL_FOLDER}"

xception_65の初期重みを得ておきます。

$ curl -OL http://download.tensorflow.org/models/deeplabv3_pascal_train_aug_2018_01_04.tar.gz

カスタムデータセット用の初期重みを作る

ここからはjupyterで順に実行しているつもりで読んでください。使用するモジュールをインポートします。

import tensorflow as tf
import re

xception関係のVariable名をxception_vars、残りをdeeplab_varsに入れます。

g = tf.Graph()
xception_vars = []
deeplab_vars = []
with g.as_default():
    tf.train.import_meta_graph('/path/to/deeplab/datasets/my_first_perception/exp/train_on_train_set/train/model.ckpt-0.meta')
    print(len(tf.all_variables())) #=> 1173
    xception_vars = list(map(lambda v: v.name , filter(lambda v: v.name.startswith('xception_65') and 'Momentum' not in v.name, tf.all_variables())))
    deeplab_vars = list(map(lambda v: v.name , filter(lambda v: not v.name.startswith('xception_65') or ('Momentum' in v.name), tf.all_variables())))

Variableの数が合っているか確認します。

print(len(xception_vars)) #=> 660
print(len(deeplab_vars)) #=> 513
print(len(xception_vars) + len(deeplab_vars)) #=> 1173

xception、その他、保存用のSavers1, s2, sを作ります。

tf.reset_default_graph() # 要らないかも
p = re.compile(r'(.+):\d+')
g = tf.Graph()
with g.as_default():
    tf.train.import_meta_graph('/path/to/deeplab/datasets/my_first_perception/exp5/train_on_train_set/train/model.ckpt-0.meta')
    deeplab_var_list = {}
    for name in deeplab_vars:
        name_ = p.match(name)[1]
        deeplab_var_list[name_] = g.get_tensor_by_name(name)
    tf.train.import_meta_graph('/path/to/deeplab/datasets/pascal_voc_seg/exp/train_on_trainval_set/train/model.ckpt-0.meta')
    xception_var_list = {}
    for name in xception_vars:
        name_ = p.match(name)[1]
        xception_var_list[name_] = g.get_tensor_by_name(name)
    var_list = xception_var_list.copy()
    var_list.update(deeplab_var_list)
    s1 = tf.train.Saver(var_list=deeplab_var_list)
    s2 = tf.train.Saver(var_list=xception_var_list)
    s = tf.train.Saver(var_list=var_list)

s1, s2をリストアし、sで保存します。こうしてできたmodel.ckptは分類クラス数11の初期重みとして使えます。

with tf.Session(graph=g) as sess:
    s1.restore(sess, '/path/to/deeplab/datasets/my_first_perception/exp5/train_on_train_set/train/model.ckpt-0')
    s2.restore(sess, '/path/to/deeplab/datasets/pascal_voc_seg/exp6/train_on_trainval_set/train/model.ckpt-0')
    s.save(sess, '/path/to/deeplab/datasets/my_first_perception/init_models/xception_11/model.ckpt')

結果

最後に結果を比較してみます。大体どの画像でやっても同じ傾向なのですが、分類クラス数を最適化した方が分類の誤りが少ないのが分かります。おおらかな目で見れば、どちらも大体位置的に合っているのはPillowとかで合成すると確認出来ます。
分類クラス数21のままの学習


分類クラス数11での学習


バックボーンがMobileNetV2の場合

実はtensorflow/modelsのドキュメントでは、MobileNetV2の初期重みは公開されていなくて、MobileNetV2の学習済みモデルが紹介されているだけです。なので、それを元に初期の重みを作る必要があります。使う学習済みモデルは次のリンクからダウンロードします。

$ curl -OL https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.0_224.tgz

あとは同様にして、次のようにすれば作れます。

import tensorflow as tf
from nets.mobilenet import mobilenet_v2
import re
p = re.compile(r'(.+):\d+')
g = tf.Graph()
mobilenet_vars = []
deeplab_vars = []
with g.as_default():
    tf.train.import_meta_graph('/path/to/datasets/my_first_perception/exp0/train_on_trainval_set_mobilenetv2/train/model.ckpt-0.meta')
    print(len(tf.all_variables()))
    mobilenet_vars = list(map(lambda v: v.name , filter(lambda v: v.name.startswith('MobilenetV2') and 'Momentum' not in v.name, tf.all_variables())))
    deeplab_vars = list(map(lambda v: v.name , filter(lambda v: not v.name.startswith('MobilenetV2') or ('Momentum' in v.name), tf.all_variables())))
tf.reset_default_graph()
g = tf.Graph()
with g.as_default():
    tf.train.import_meta_graph('/path/to/datasets/my_first_perception/exp0/train_on_trainval_set_mobilenetv2/train/model.ckpt-0.meta')
    deeplab_var_list = {}
    for name in deeplab_vars:
        name_ = p.match(name)[1]
        deeplab_var_list[name_] = g.get_tensor_by_name(name)
    tf.train.import_meta_graph('/path/to/checkpoint/mobilenet_v2/mobilenet_v2_1.0_224.meta')
    mobilenet_var_list = {}
    for name in mobilenet_vars:
        name_ = p.match(name)[1]
        mobilenet_var_list[name_] = g.get_tensor_by_name(name)
    var_list = mobilenet_var_list.copy()
    var_list.update(deeplab_var_list)
    s1 = tf.train.Saver(var_list=deeplab_var_list)
    s2 = tf.train.Saver(var_list=mobilenet_var_list)
    s = tf.train.Saver(var_list=var_list)
with tf.Session(graph=g) as sess:
    s1.restore(sess, '/path/to/datasets/my_first_perception/exp0/train_on_trainval_set_mobilenetv2/train/model.ckpt-0')
    s2.restore(sess, "/path/to/checkpoint/mobilenet_v2/mobilenet_v2_1.0_224")
    s.save(sess, "/path/to/datasets/my_first_perception/init_models/mobilenetv2_11/model.ckpt")

バックボーンがMobileNetV2の場合の結果

学習できていることが確認できます。

Discussion