📌

行列積演算子を用いたニューラルネットの重みパラメータ数の削減

に公開

はじめに

この記事は論文[1] の解説です。要旨を一言で言うと、深層ニューラルネットワークにおける線形変換を行列積演算子(MPO)で表現することで、パラメータ数を大幅に削減しつつ予測精度を維持または向上させうる、という手法になります。

論文内の手法の解説

本論文の手法は図 1 および式(3),(5) に要約されます。以下ではこれらの図や式の意味を説明します。

考える設定

入力の次元N_x, 出力の次元N_yの線形変換を考えます。すなわち、N_x 次元のベクトルに左から (N_y, N_x) 行列を作用させ、N_y次元のベクトルを取得する過程です。この (N_y, N_x) サイズの変換行列のパラメータ数を減らし、線形変換の計算時間を減らすことが目的です。

本論文の手法を用いる前の計算時間、パラメタ数

この線形変換を素朴に行うと、計算時間は \mathcal{O}(N_xN_y) ,パラメタ数は \mathcal{O}(N_xN_y) となります。

本論文の手法

論文中の式(4)のように、入力の次元および出力の次元を適当に分解します。

\begin{equation} \prod_{k=1}^{n}I_{k} = N_x, \quad \prod_{k=1}^{n}J_{k} = N_y \tag{1} \end{equation}

続いて、適当な定数Dを用いて、以下の形状のテンソル群を考えます。(以下ではこの定数Dをボンド次元と呼びます)

\begin{aligned} w^{(1)} &: (J_1, I_1, D) \\ w^{(l)} &: (J_l, I_l, D, D) \quad (\text{for } 2 \leq l \leq n-1) \\ w^{(n)} &: (J_n, I_n, D) \end{aligned}

これらのテンソル群を用いて、元のパラメータ行列Wを以下で表現します。

W_{j_1,\ldots,j_n,i_1,\ldots,i_n} = \sum_{d_1,\ldots,d_{n-1}=1}^D w^{(1)}_{j_1,i_1,d_1} \left( \prod_{l=2}^{n-1} w^{(l)}_{j_l,i_l,d_{l-1},d_l}\right) w^{(n)}_{j_n,i_n,d_{n-1}}

ここで、W_{j_1,\ldots,j_n,i_1,\ldots,i_n} は元の行列 W_{J,I} の成分を表します。IJ はそれぞれ I=(i_1,i_2,\ldots,i_n)J=(j_1,j_2,\ldots,j_n) と混合基数(mixed-radix)表現されたインデックスです。

I=\sum{k=1}^n i_k \prod_{m=k+1}^n I_m

行列やテンソル計算に慣れた読者には、N_x 次元のインデックスを (i_1,i_2,\ldots,i_n) という複数のインデックスに reshape した、と考えると分かりやすいかもしれません。

上記を形式的に、以下のように表します(論文中式(5))

W_{j_1,\ldots,j_n,i_1,\ldots,i_n} = \mathrm{Tr}\left(w^{(1)}[j_1,i_1]\,w^{(2)}[j_2,i_2]\cdots w^{(n)}[j_n,i_n]\right)

ここで、\mathrm{Tr} はトレースを表し、内部自由度(ボンド次元)に関する縮約を意味します。

このような形式を行列積演算子 (Matrix Product Operator:MPO) と呼びます。これはベクトルを別なベクトルに変換する演算子として行列Wを考えるとき、Wが小さな部分状態毎の変換(i_k \to j_k)を行う行列をたくさん連ねた形式で近似したもの、という意味で行列積演算子と呼んでいます。行列積演算子といいつつ、それぞれの部分状態の変換を行う演算子は隣の演算子との積を取るためにボンド次元の大きさで繋がっており、実態はテンソルになっています。

実用例

論文中の Ⅲ.A.1 節でも扱っている、MNIST データセットに対して、FC2 モデルで学習する場合を具体的に説明します。
FC2 は2層の全結合(full-connect)層のみからなるネットワークです。ここで、MNIST データセットの入力は(28,28)の2次元画像であり、隠れ層は 256 次元として、0~9 のどの文字であるかの one-hot 表現を出力としているため
(784,256),(256,10) の2つの重み行列でそれぞれ表される全結合層からなります。

ここでは、(784,256) の重み行列 M を MPO 表現する例を考えます。式(1) に従い、I_kJ_k をそれぞれ (I_1,I_2,I_3,I_4)=(4,7,7,4)(J_1,J_2,J_3,J_4)=(4,4,4,4) に分解します。

そして、元の行列 M(I,J) 成分は以下で表されます。

W_{I,J} = \sum_{d_1,d_2, d_3}^D w^{(1)}_{i_1,i_1,d_1} w^{(2)}_{j_2,i_2,d_1,d_2} w^{(3)}_{j_3,i_3,d_2,d_3} w^{(4)}_{j_4,i_4,d_3}

ここで I=(i_1,i_2,i_3,i_4) は混合基数表現であり、(I_1,I_2,I_3,I_4)=(4,7,7,4) の場合は

I = i_1\cdot(7\cdot7\cdot4) + i_2\cdot(7\cdot4) + i_3\cdot4 + i_4

と表されます。

圧縮率

MPO 表現の前後で各全結合層でのパラメータの数を計算すると、

MPO 表現前:N_{\text{ori}} = N_x\cdot N_y
MPO 表現後:N_{\text{mpo}} = I_1J_1D+\sum_{k=2}^{n-1}I_kJ_kD^2 +I_nJ_nD

となります。
以下の結果の節では、モデル全体のパラメータ数を MPO 表現前後で比較し、圧縮率\rhoを以下で定義します。

\rho \equiv \frac{\sum_l N_{\text{mpo}}^{(l)}}{\sum_l N_{\text{ori}}^{(l)}}

結果

論文で報告されている実験は主に MNIST と CIFAR-10 の 2 種類のデータセット上で行われ、代表的なネットワーク(FC2, LeNet-5, VGG, ResNet, DenseNet)に対して MPO 表現を適用した結果が示されています。以下に主要な結果を要約します。
結果の要約にあたり、パラメータをどれだけ圧縮できたかの指標として、圧縮率\rhoを以下で定義し、それを用いて結果を説明します。

MNIST データセット

  • FC2: 入力を (784)、隠れ層を 256、出力を 10 とする 2 層の全結合モデルに対して、重み行列 (784,256) を MPO に置き換えた実験が行われています。ボンド次元 D を増やすとテスト精度は向上し、D=16 で通常の FC2 と同等の精度(論文図の例では通常の FC2 が約 98.35% ±0.2%)に達します。非常に小さな D(例: D=2)でも 1024 パラメータ程度で良好な性能を示し、パラメータ数は数百倍削減され得ることが示されました。

  • LeNet-5: 畳み込み+全結合の標準的な LeNet-5 について、最後の畳み込み層と 2 層の全結合層(計 3 層)を MPO に置き換えた実験が報告されています。MPO による圧縮後の圧縮率は約 0.05(5%)で、テスト精度は元の LeNet-5(99.17% ±0.04)とほぼ同等(MPO-Net: 99.17% ±0.08)でした。

CIFAR-10 データセット

より複雑な CIFAR-10 では、主に全結合層やパラメータの多い畳み込み層を MPO に置き換えて性能を検証しています。

  • VGG(VGG-16 / VGG-19): 最後の数層(重い畳み込み層と全結合層)を MPO に置き換えた結果、圧縮率は非常に小さく(論文では約 0.0005 と報告)、にもかかわらず MPO-Net の精度は元の VGG よりむしろ若干高い結果(例: VGG-16 元: 93.13% ±0.39 → MPO: 93.76% ±0.16)を示しました。これはパラメータ削減による学習の安定化や過学習の抑制が寄与している可能性が示唆されています。

  • ResNet: ResNet の最後の全結合層(例: 64k × 10 の重み行列)を MPO に置き換えたケースが示されており、k=4 の設定で圧縮率は約 0.11(11%)でした。深さを変えた各種 ResNet で MPO-Net は元の ResNet と同等のテスト精度を保ちました(MPO による ResUnit の圧縮も可能)。

  • DenseNet: DenseNet の全結合層を MPO で圧縮した複数の構成が比較されています。論文の表から抜粋すると、

    • Depth=40 (n=16,m=12,k=12): DenseNet 93.56% ±0.26 → MPO-Net 93.59% ±0.13, MPO 構造例 M^{1,5,2,1}_{4,4,7,4}(4), ρ ≈ 0.129
    • Depth=40 (n=16,m=12,k=24): DenseNet 95.12% ±0.15 → MPO-Net 95.13% ±0.13, ρ ≈ 0.089
    • Depth=100 (n=24,m=32,k=12): DenseNet 95.36% ±0.15 → MPO-Net 95.58% ±0.07, ρ ≈ 0.070
    • Depth=100 (n=96,m=32,k=24): DenseNet 95.74% ±0.09 → MPO-Net 96.09% ±0.07, ρ ≈ 0.044

    これらの結果から、MPO による圧縮は DenseNet でも有効であり、場合によっては精度が改善することもあることが示されています。

計算量、パラメータ数について

(この章の内容は論文中には明記されておらず、本記事の著者が見積もった値になります)

とある入力次元N_x, 出力次元N_yの全結合層の重み行列Wを MPO で表現し、N_x次元ベクトルvを変換する際の計算時間について考えます。ここで、ボンド次元DD\ll N_x,N_yを満たす定数とします。また、今後のためにN=\max(N_x,N_y)と定義します
この時、n個のテンソルからなる MPO 表現を行い、とあるn_0において、

N_x^{(L)} = \prod_{k=1}^{n_0} I_k,\\ N_x^{(R)} = \prod_{k=n_0+1}^n I_k,\\ N_y^{(L)} = \prod_{k=1}^{n_0} J_k,\\ N_y^{(R)} = \prod_{k=n_0+1}^n J_k

とあらわせるとします。ここで、N_x^{(L)}\simeq N_x^{(R)} = \mathcal{O}(N_x^\frac{1}{2}),N_y^{(L)}\simeq N_y^{(R)} = \mathcal{O}(N_y^\frac{1}{2}) を満たすとします。

このとき、Wvの計算は

  1. k=1\sim n_0のテンソルの縮約を計算し、(N_y^{(L)} \cdot D, N_x^{(L)}) に reshape する。この reshape した行列をW_Lとおく
  2. 上記のW_Lと、入力ベクトルv(N_x^{(L)},N_x^{(R)})行列に reshape したものの積を計算する。計算された行列を(N_y^{(L)} , D \cdot N_x^{(R)})に変形し、W_vとおく
  3. k=n_0\sim nのテンソルの縮約を計算し、(D\cdot N_x^{(R)}, N_y^{(R)}) に reshape する。この reshape した行列をW_Rとおく
  4. W_vW_Rの積を計算し、(N_y^{(L)} N_y^{(R)}) = (N_y) のベクトルに reshape する。

上記の計算において、計算量を見積もります。

1 において、端から順次テンソルの縮約をとる操作を考えます。このとき、最大の計算量は最後の縮約、すなわち(\prod_{k=1}^{n_0-1}I_kJ_k,D)テンソルと (I_{n_0},J_{n_0},D,D) テンソルの縮約であり、計算量は \prod_{k=1}^{n_0}I_kJ_k \cdot D^2 \simeq (N_xN_y)^{\frac{1}{2}} \cdot D^2 となります。

このとき、自明な定数倍程度の効果しか持たない I_k =1 の分解を禁止すると、テンソル数 n_0 <= \frac{1}{2} log_2(N) となるため、端から順に縮約をとる場合の縮約全体の計算量は\mathcal{O}(log(N)\cdot (N_xN_y)^{\frac{1}{2}} \cdot D^2) となります。
reshape の計算量はこれよりも小さいため、1 の操作の計算量は \mathcal{O}(log(N)\cdot (N_xN_y)^{\frac{1}{2}} \cdot D^2) です。

3 においても同様の計算が行われ、同じ計算量のオーダーとなります。

2 では、(N_y^{(L)} \cdot D, N_x^{(L)})(N_x^{(L)},N_x^{(R)}) の行列積の計算量が
\mathcal{O}(N_y^{\frac{1}{2}}N_x\cdot D) であり、reshape はこれより計算量が小さいため、全体の計算量は \mathcal{O}(N_y^{\frac{1}{2}}N_x\cdot D) です。

同様に、4 の主要な計算も行列積であり、計算量は \mathcal{O}(N_x^{\frac{1}{2}}N_y\cdot D) です。
そのため、全体の計算量は、DD\ll N_x,N_y を満たす定数とするとき、\mathcal{O}((N N_x N_y)^{\frac{1}{2}}\cdot D) となります。
簡単のため N_x=N_y=N とすると、通常のベクトル・行列積は\mathcal{O}(N^2)なのに対し、MPO-net では \mathcal{O}(N^\frac{3}{2}\cdot D) まで減らすことができます。

また、入力、出力の次元がそれぞれ N_x^{\frac{1}{k}}, N_y^{\frac{1}{k}} であるようなk 個のテンソルに分解する場合の総パラメータ数は

N_{\mathrm{mpo}} = 2 D (N_xN_y)^{\frac{1}{k}} + (k-2)D^2 (N_xN_y)^{\frac{1}{k}}

であるため、圧縮率は以下のようになります。

\rho_{\mathrm{mpo}} \approx \frac{(k-2)D^2}{(N_xN_y)^{\frac{k-1}{k}}}

これが最小となるのは、入力、出力側の次元が1になるような自明な変換がなく、できるだけたくさんのテンソルに分解した場合になります。すなわち、テンソル全ての物理脚の次元を 2 とする MPO に展開する場合で、簡単のため N_x=N_y=N =2^n を考えると全ての入力、出力側の次元を 2 に分解したケースになり、

\rho_{\mathrm{mpo}} \approx \frac{4 \log_2{N}D^2}{N^2}

となります。ここで、{(N_xN_y)^{\frac{k-1}{k}}}k=log N において 4 になることを使用しています。

ただし、細かく分解するとパラメータ数は減る代わりに表現能力は低くなるトレードオフがあり、細かく分解しても計算量は減らないため、実用上ここまで小さくするのは推奨されず数個程度の分割が好ましいと考えられます。

考察

MPO 表現が効果的である理由について

本文中の Disccusion の章では以下のような記載があります。

量子多体系の短距離相互作用の研究で MPO が成功を収めたことに触発され、本研究では NN 内の線形変換行列を MPO で表現する手法を提案しています。これは、画像のピクセル間の相関や、画像に潜む情報構造が本質的に局所的であるという仮定に基づいています。

この記載自体は正しく、画像データは本質的に局所的な構造を持つと考えられます。
では逆に画像処理の NN にしか MPO-net による近似が有効ではないかというと、そんなことはなく、MPO-net により様々な NN の重みパラメータはよく近似可能なのではないかと考えています。というのも、そもそも実データにおいては殆どのケースで何かしらの構造を持つデータを扱うためです。例えば自然言語処理を考えたとき、入力となる文は何らかの文法という構造に従います。

各 NN の重みを MPO 表現した際にそれが良い近似になっているかの指標としては、対象の重み行列(N_x,N_y)に対して、適当にN_x=N_x^1N_x^2, N_y=N_y^1N_y^2などと分解し、(N_x^1N_y^1,N_x^2N_y^2)と reshape した行列の特異値の分布が急峻になっていること、(テンソルネットワーク的にはエンタングルメントエントロピーが小さいこと)を確認すれば良いかと考えられます。もし特異値の分布が急峻であれば、それを適当な定数D(=ボンド次元)個までの状態のみで打ち切るような低ランク近似を行っても元の状態をよく再現するため、MPO による近似が有効だと考えられます。

今後の発展

論文中ではいくつか今後の発展に言及されています。MPO 表現した NN を量子物理学の文脈に載せて色々解析できるよね、といったもので、概略を述べると

  • MPO 表現した NN を量子物理学の諸手法で解析することで、ネットワークの複雑さを定量化できるのでは?
  • ネットワークの重みだけではなく入力データ自体も行列積状態(MPS)で圧縮方法できるのでは?
  • エンタングルメントエントロピーの視点から最適化問題を解析することができ、将来的には学習アルゴリズム改善につながるのでは?
    などが挙げられています。

これらの方針も非常に興味深いですが、個人的にはもっと NN 側の実践的なアプローチとして以下に興味があります。

  1. 学習済みモデルのパラメータ圧縮ではなく、事前学習に使えないか?
    論文中の方針では、あくまで MPO-net をそのまま使用することを考えていました。しかし、重みパラメータを MPO に置き換えたネットワークを最初に学習させ、MPO 表現された重み行列を、ボンド次元について縮約をとることで重み行列を復元し、MPO 表現されていないモデルの事前学習済みの重みとして用いることができるのではないかと考えています。
    この方針では

    • MPO-net により元々のモデルをそのまま学習するよりは軽量に事前学習ができる。
    • 事前学習された重みのパラメータ数は元々のモデルのパラメータ数より少ない。
    • 事前学習後、元のモデルの重みに変換してからファインチューニングすることで、元のモデルの表現能力を保持して実用できる。

    などのメリットがあると考えられます。

  2. ボンド次元 D を徐々に増やしながら学習させることで、最適な D を見つけられるのでは?
    ボンド次元 D は、MPO-net の表現能力とパラメータ数や計算時間とのトレードオフを司る重要なパラメータとなっています。そのため、十分な表現能力を持ちながら、できるだけ小さい D を見つけることが求められますが、一般に D の大きさを決めるための良い処方箋はなく、グリッドサーチ的に探索するしかない(※はず)です。(※何か処方箋があればぜひご教授ください)

    そこで、以下の方針を提案します。

    1. まず、小さい D=D_0 で学習させる。
    2. 続いて、D=D_0 で学習させた重みを初期値として、ボンド次元を少し増やしてD=D_1(>D_0) として再度学習させる。この際に追加する重みの初期値はほぼ 0 として、学習をしなければほぼ D=D_0 の MPO 重みに入力を作用させた値を出力として返す状態を初期値とする。
    3. 1,2 のように少しずつボンド次元を増やして学習を進め、ボンド次元を増やしてもコストが下がらなくなったら学習を止める。

    この方針では、小さなボンド次元での値を初期値として学習させることで、D を変更した際の再学習のコストが下げられるのでは?という点がポイントで、これによりある程度効率的に最適な D を発見できると考えています。

  3. より効率的な近似について
    MPO-net は非常に効率的な近似方法であり、例えば計算量が \mathcal{O}(N^2)\mathcal{O}\left(N^\frac{3}{2} D\right) に削減することができます。しかし、MPO の構造上端から端までボンド次元分の結合があることがネックとなり、これ以上の計算時間の削減は難しいです。
    より劇的な計算時間の削減が可能な手法として、Brick-Wall 構造のネットワークで近似することを提案している論文[2]があります。この論文についても需要があればいつか解説を行いたいと考えています。(※ただし、この論文[2]の手法は長距離相関を無視しすぎてしまっているように思われるので、適切に相関を入れる必要もあるのかなと個人的には感じています。)

実践

実際に MNIST の場合について MPO-net での計算を行い、ボンド次元を変えた際の精度を確認するノートブックを作成しました。
https://colab.research.google.com/drive/1KwTEEIE1XSe6RAT_vHXzTGTYEwn0FEJp?usp=sharing

雑記

解説記事を書くのは初めてで、色々読みにくい箇所などあるかもしれませんが、ご了承ください…。
重要な点をかいつまんで書いてたつもりなのですが、思っていたよりも分量が多くなりました。技術的な部分はできるだけ詳細に記載したつもりですが、畳み込みのあるモデルでの実験結果や、特に appendix には本記事には載せきれなかった興味深い内容がありますので、興味のある方はぜひ原論文をあたってみてください。

参考文献

[1]: Compressing deep neural networks by matrix product operators(arXiv: https://arxiv.org/abs/1904.06194)
[2]: Compressing Neural Networks Using Tensor Networks with Exponentially Fewer
Variational Parameters (arXiv: https://arxiv.org/abs/2305.06058)

Discussion