行列積演算子を用いたニューラルネットの重みパラメータ数の削減
はじめに
この記事は論文[1] の解説です。要旨を一言で言うと、深層ニューラルネットワークにおける線形変換を行列積演算子(MPO)で表現することで、パラメータ数を大幅に削減しつつ予測精度を維持または向上させうる、という手法になります。
論文内の手法の解説
本論文の手法は図 1 および式(3),(5) に要約されます。以下ではこれらの図や式の意味を説明します。
考える設定
入力の次元
本論文の手法を用いる前の計算時間、パラメタ数
この線形変換を素朴に行うと、計算時間は
本論文の手法
論文中の式(4)のように、入力の次元および出力の次元を適当に分解します。
続いて、適当な定数
これらのテンソル群を用いて、元のパラメータ行列
ここで、
行列やテンソル計算に慣れた読者には、
上記を形式的に、以下のように表します(論文中式(5))
ここで、
このような形式を行列積演算子 (Matrix Product Operator:MPO) と呼びます。これはベクトルを別なベクトルに変換する演算子として行列
実用例
論文中の Ⅲ.A.1 節でも扱っている、MNIST データセットに対して、FC2 モデルで学習する場合を具体的に説明します。
FC2 は2層の全結合(full-connect)層のみからなるネットワークです。ここで、MNIST データセットの入力は(28,28)の2次元画像であり、隠れ層は 256 次元として、0~9 のどの文字であるかの one-hot 表現を出力としているため
(784,256),(256,10) の2つの重み行列でそれぞれ表される全結合層からなります。
ここでは、(784,256) の重み行列
そして、元の行列
ここで
と表されます。
圧縮率
MPO 表現の前後で各全結合層でのパラメータの数を計算すると、
MPO 表現前:
MPO 表現後:
となります。
以下の結果の節では、モデル全体のパラメータ数を MPO 表現前後で比較し、圧縮率
結果
論文で報告されている実験は主に MNIST と CIFAR-10 の 2 種類のデータセット上で行われ、代表的なネットワーク(FC2, LeNet-5, VGG, ResNet, DenseNet)に対して MPO 表現を適用した結果が示されています。以下に主要な結果を要約します。
結果の要約にあたり、パラメータをどれだけ圧縮できたかの指標として、圧縮率
MNIST データセット
-
FC2: 入力を
(784)、隠れ層を256、出力を10とする 2 層の全結合モデルに対して、重み行列(784,256)を MPO に置き換えた実験が行われています。ボンド次元 D を増やすとテスト精度は向上し、D=16 で通常の FC2 と同等の精度(論文図の例では通常の FC2 が約 98.35% ±0.2%)に達します。非常に小さな D(例: D=2)でも 1024 パラメータ程度で良好な性能を示し、パラメータ数は数百倍削減され得ることが示されました。 -
LeNet-5: 畳み込み+全結合の標準的な LeNet-5 について、最後の畳み込み層と 2 層の全結合層(計 3 層)を MPO に置き換えた実験が報告されています。MPO による圧縮後の圧縮率は約 0.05(5%)で、テスト精度は元の LeNet-5(99.17% ±0.04)とほぼ同等(MPO-Net: 99.17% ±0.08)でした。
CIFAR-10 データセット
より複雑な CIFAR-10 では、主に全結合層やパラメータの多い畳み込み層を MPO に置き換えて性能を検証しています。
-
VGG(VGG-16 / VGG-19): 最後の数層(重い畳み込み層と全結合層)を MPO に置き換えた結果、圧縮率は非常に小さく(論文では約 0.0005 と報告)、にもかかわらず MPO-Net の精度は元の VGG よりむしろ若干高い結果(例: VGG-16 元: 93.13% ±0.39 → MPO: 93.76% ±0.16)を示しました。これはパラメータ削減による学習の安定化や過学習の抑制が寄与している可能性が示唆されています。
-
ResNet: ResNet の最後の全結合層(例:
64k × 10の重み行列)を MPO に置き換えたケースが示されており、k=4 の設定で圧縮率は約 0.11(11%)でした。深さを変えた各種 ResNet で MPO-Net は元の ResNet と同等のテスト精度を保ちました(MPO による ResUnit の圧縮も可能)。 -
DenseNet: DenseNet の全結合層を MPO で圧縮した複数の構成が比較されています。論文の表から抜粋すると、
- Depth=40 (n=16,m=12,k=12): DenseNet 93.56% ±0.26 → MPO-Net 93.59% ±0.13, MPO 構造例
, ρ ≈ 0.129M^{1,5,2,1}_{4,4,7,4}(4) - Depth=40 (n=16,m=12,k=24): DenseNet 95.12% ±0.15 → MPO-Net 95.13% ±0.13, ρ ≈ 0.089
- Depth=100 (n=24,m=32,k=12): DenseNet 95.36% ±0.15 → MPO-Net 95.58% ±0.07, ρ ≈ 0.070
- Depth=100 (n=96,m=32,k=24): DenseNet 95.74% ±0.09 → MPO-Net 96.09% ±0.07, ρ ≈ 0.044
これらの結果から、MPO による圧縮は DenseNet でも有効であり、場合によっては精度が改善することもあることが示されています。
- Depth=40 (n=16,m=12,k=12): DenseNet 93.56% ±0.26 → MPO-Net 93.59% ±0.13, MPO 構造例
計算量、パラメータ数について
(この章の内容は論文中には明記されておらず、本記事の著者が見積もった値になります)
とある入力次元
この時、
とあらわせるとします。ここで、
このとき、
-
のテンソルの縮約を計算し、k=1\sim n_0 に reshape する。この reshape した行列を(N_y^{(L)} \cdot D, N_x^{(L)}) とおくW_L - 上記の
と、入力ベクトルW_L をv 行列に reshape したものの積を計算する。計算された行列を(N_x^{(L)},N_x^{(R)}) に変形し、(N_y^{(L)} , D \cdot N_x^{(R)}) とおくW_v -
のテンソルの縮約を計算し、k=n_0\sim n に reshape する。この reshape した行列を(D\cdot N_x^{(R)}, N_y^{(R)}) とおくW_R -
とW_v の積を計算し、W_R のベクトルに reshape する。(N_y^{(L)} N_y^{(R)}) = (N_y)
上記の計算において、計算量を見積もります。
1 において、端から順次テンソルの縮約をとる操作を考えます。このとき、最大の計算量は最後の縮約、すなわち
このとき、自明な定数倍程度の効果しか持たない
reshape の計算量はこれよりも小さいため、1 の操作の計算量は
3 においても同様の計算が行われ、同じ計算量のオーダーとなります。
2 では、
同様に、4 の主要な計算も行列積であり、計算量は
そのため、全体の計算量は、
簡単のため
また、入力、出力の次元がそれぞれ
であるため、圧縮率は以下のようになります。
これが最小となるのは、入力、出力側の次元が
となります。ここで、
ただし、細かく分解するとパラメータ数は減る代わりに表現能力は低くなるトレードオフがあり、細かく分解しても計算量は減らないため、実用上ここまで小さくするのは推奨されず数個程度の分割が好ましいと考えられます。
考察
MPO 表現が効果的である理由について
本文中の Disccusion の章では以下のような記載があります。
量子多体系の短距離相互作用の研究で MPO が成功を収めたことに触発され、本研究では NN 内の線形変換行列を MPO で表現する手法を提案しています。これは、画像のピクセル間の相関や、画像に潜む情報構造が本質的に局所的であるという仮定に基づいています。
この記載自体は正しく、画像データは本質的に局所的な構造を持つと考えられます。
では逆に画像処理の NN にしか MPO-net による近似が有効ではないかというと、そんなことはなく、MPO-net により様々な NN の重みパラメータはよく近似可能なのではないかと考えています。というのも、そもそも実データにおいては殆どのケースで何かしらの構造を持つデータを扱うためです。例えば自然言語処理を考えたとき、入力となる文は何らかの文法という構造に従います。
各 NN の重みを MPO 表現した際にそれが良い近似になっているかの指標としては、対象の重み行列
今後の発展
論文中ではいくつか今後の発展に言及されています。MPO 表現した NN を量子物理学の文脈に載せて色々解析できるよね、といったもので、概略を述べると
- MPO 表現した NN を量子物理学の諸手法で解析することで、ネットワークの複雑さを定量化できるのでは?
- ネットワークの重みだけではなく入力データ自体も行列積状態(MPS)で圧縮方法できるのでは?
- エンタングルメントエントロピーの視点から最適化問題を解析することができ、将来的には学習アルゴリズム改善につながるのでは?
などが挙げられています。
これらの方針も非常に興味深いですが、個人的にはもっと NN 側の実践的なアプローチとして以下に興味があります。
-
学習済みモデルのパラメータ圧縮ではなく、事前学習に使えないか?
論文中の方針では、あくまで MPO-net をそのまま使用することを考えていました。しかし、重みパラメータを MPO に置き換えたネットワークを最初に学習させ、MPO 表現された重み行列を、ボンド次元について縮約をとることで重み行列を復元し、MPO 表現されていないモデルの事前学習済みの重みとして用いることができるのではないかと考えています。
この方針では- MPO-net により元々のモデルをそのまま学習するよりは軽量に事前学習ができる。
- 事前学習された重みのパラメータ数は元々のモデルのパラメータ数より少ない。
- 事前学習後、元のモデルの重みに変換してからファインチューニングすることで、元のモデルの表現能力を保持して実用できる。
などのメリットがあると考えられます。
-
ボンド次元
を徐々に増やしながら学習させることで、最適なD を見つけられるのでは?D
ボンド次元 は、MPO-net の表現能力とパラメータ数や計算時間とのトレードオフを司る重要なパラメータとなっています。そのため、十分な表現能力を持ちながら、できるだけ小さいD を見つけることが求められますが、一般にD の大きさを決めるための良い処方箋はなく、グリッドサーチ的に探索するしかない(※はず)です。(※何か処方箋があればぜひご教授ください)D そこで、以下の方針を提案します。
- まず、小さい
で学習させる。D=D_0 - 続いて、
で学習させた重みを初期値として、ボンド次元を少し増やしてD=D_0 として再度学習させる。この際に追加する重みの初期値はほぼD=D_1(>D_0) として、学習をしなければほぼ0 の MPO 重みに入力を作用させた値を出力として返す状態を初期値とする。D=D_0 - 1,2 のように少しずつボンド次元を増やして学習を進め、ボンド次元を増やしてもコストが下がらなくなったら学習を止める。
この方針では、小さなボンド次元での値を初期値として学習させることで、
を変更した際の再学習のコストが下げられるのでは?という点がポイントで、これによりある程度効率的に最適なD を発見できると考えています。D - まず、小さい
-
より効率的な近似について
MPO-net は非常に効率的な近似方法であり、例えば計算量が →\mathcal{O}(N^2) に削減することができます。しかし、MPO の構造上端から端までボンド次元分の結合があることがネックとなり、これ以上の計算時間の削減は難しいです。\mathcal{O}\left(N^\frac{3}{2} D\right)
より劇的な計算時間の削減が可能な手法として、Brick-Wall 構造のネットワークで近似することを提案している論文[2]があります。この論文についても需要があればいつか解説を行いたいと考えています。(※ただし、この論文[2]の手法は長距離相関を無視しすぎてしまっているように思われるので、適切に相関を入れる必要もあるのかなと個人的には感じています。)
実践
実際に MNIST の場合について MPO-net での計算を行い、ボンド次元を変えた際の精度を確認するノートブックを作成しました。
雑記
解説記事を書くのは初めてで、色々読みにくい箇所などあるかもしれませんが、ご了承ください…。
重要な点をかいつまんで書いてたつもりなのですが、思っていたよりも分量が多くなりました。技術的な部分はできるだけ詳細に記載したつもりですが、畳み込みのあるモデルでの実験結果や、特に appendix には本記事には載せきれなかった興味深い内容がありますので、興味のある方はぜひ原論文をあたってみてください。
参考文献
[1]: Compressing deep neural networks by matrix product operators(arXiv: https://arxiv.org/abs/1904.06194)
[2]: Compressing Neural Networks Using Tensor Networks with Exponentially Fewer
Variational Parameters (arXiv: https://arxiv.org/abs/2305.06058)
Discussion