🚄

RepConv: テスト時に訓練時と等価でシンプルなアーキテクチャに変換することで高速化かつ高精度を達成

2022/07/17に公開

はじめに

Re-parametrizationという手法を使って、精度を維持しつつ高速化を達成したRepVGGという手法について紹介する。

高精度と高速化を同時に達成

まず、本論文で提案するRepConvというアーキテクチャのパフォーマンス-推論時間のトレードオフを図1に示す。これからわかる通り、提案手法はResNetと比べて、同じ精度を達成するのに約半分の推論時間しか必要としないことがわかる。

図1. 各モデルのパフォーマンス-推論時間のトレードオフ

Multi-path v.s. Single-path

論文の手法の詳細を説明する前に、背景知識を補足しておく。

CNNはSingle-pathとMulti-pathのネットワークに分類される。
前者は前段の入力を畳み込み層を経由して後段に順次伝搬していく、最も基本的なアーキテクチャである(例: VGG)。また、後者はResNetのように、入力の特徴マップを記憶しておき、畳み込み層を経由した後の出力の特徴マップと加算、あるいはチャネル方向に結合するようなアーキテクチャである。

Single-pathアーキテクチャのメリット

Multi-pathと比較したSingle-pathのネットワークのメリットには、以下の3点があげられる。

  1. 処理速度が高速
  2. メモリ効率が良い
  3. アーキテクチャの選択がより柔軟

1点目について、Single-pathのネットワークはMulti-pathのネットワークに比べて、a) メモリアクセスコスト(MAC)、及び、b) 並列計算の性能に優れているため、一般的に高速的に動作する。

また、2点目について、Multi-pathのネットワークは、畳み込み計算が終わるまで入力の特徴マップをメモリ上に保持しておく必要があるため、2-branchのネットワークの場合はメモリの最大消費が1-branchのネットワークに比べて2倍になる(図2)。

図2: 2-branchモデルと1-branchモデルのブロックごとのメモリ消費

最後に、Multi-pathのネットワークには以下のようなアーキテクチャ上の制約が課される。

a) 畳み込み後の特徴マップのチャネル数やサイズが入力の特徴マップと同じでなければならない

したがって、不要なchannelのpruningなどが不可能である。Single-pathのネットワークにはこのような制約はない。

Single-pathアーキテクチャのデメリット

一方で、Single-pathのネットワークのデメリットは、精度である。
Multi-pathのネットワークは、複数のサブモデルのアンサンブルとみなすことができるため、精度の上でSingle-pathのネットワークはこれに及ばない。

精度と速度を同時に達成する: Re-parametrization

前述のように、Single-pathのネットワークはMulti-pathのネットワークと比べて精度が劣ることがデメリットだが、本論文ではRe-parametrizationという手法を使ってMulti-pathで訓練したネットワークを推論時に、これと等価なSingle-pathのネットワークに変換するアプローチをとる(図3)。

図3: 提案手法のアーキテクチャ

パラメータの変換方法

以下では、訓練時の重みを使って、推論時の重みを計算する方法を説明する。
基本的な構成要素として、図4のような3つのbranchからなるブロックを考える。

  1. 3x3 conv layer + BN
  2. 1x1 conv layer + BN
  3. identiry + BN

図4: パラメータ変換方法

Step1: Conv + BNをConvに変換する

BNレイヤは、学習が終わると単なる定数倍とバイアス項の足し算で表されるので、Convレイヤ+BNの計算は、以下の式で表される。



Step2: 1x1 conv, identity mappingを3x3 convに変換する

次に、1x1 convを3x3 convに変換する。
これは単に3x3フィルターの中心に1x1フィルターの値をコピーし、その他の重みを0とすることで得られる。
また、identity mappingは単に行列を重みに持つ1x1convと等価であるから、前述の方法と同じ方法で3x3 convに変換することができる。

Step3: 3つの3x3 convの足し合わせ

最後に、全ての3x3 convフィルターの重みを足し合わせることで、1つの等価な3x3 convレイヤーに変換できる。

なお、注意点として、以下の2つがある。

  1. 3x3 convと1x1 convのstrideは同じでなければならない
  2. 3x3 convのpaddingは1x1 convのpadding+1でなければならない
GitHubで編集を提案

Discussion