🦴

「Deep Learning is Singular, and That's Good」論文とpyro実装を読む

2023/12/20に公開

「Deep Learning is Singular, and That's Good」(論文)とpyro実装を読む

https://qiita.com/advent-calendar/2023/ppl
確率的プログラミング言語 Advent Calendar 2023の20日目の記事です。

Deep Learning is Singular, and That's Good」という野心的なタイトルの論文があり、そこでは特異学習理論で重要になる量であるRLCT(real log canonical threshold,実対数しきい値)の3層ニューラルネットにおける値の計算をpyroで実装しているのでそれを紹介します。

書かれた人,Daniel Murfetさんのページを見ると代数幾何と機械学習の関係、特異学習理論のほかにも圏論や論理、チューリングマシンなどの計算機科学の基礎など幅広く研究されているようです。
Quantifying degeneracy in singular models via the learning coefficientというより明確にニューラルネットのような特異性(縮退)を持ったモデルにおけるRLCT(学習係数)を扱った論文も最近出されました。

情報量基準(BIC,WBIC)の復習

情報量基準は統計モデル選択の基準となる量であり尤度Lとモデルの自由度dに対して

  • 赤池情報量基準(AIC) -2\ln L+2d
  • ベイズ情報量基準(BIC) -2\ln L+d\ln n

などがあり、特にベイズ情報量基準はモデルを統計物理的系としてみたときの自由エネルギーに対応します。
BICの導出ではFisher情報行列が正則であることを仮定としますが多くの統計や機械学習ででてくるモデルはこの仮定を満たさず、そのためそれを拡張した
WBIC=\ln L+\lambda \ln n
という量が使われることになります。ここでλは実対数閾値(RLCT)と呼ばれる量です。
自由エネルギーは統計力学に置いては有限要素数、統計学に置いては有限なサンプル数の場合の補正項がBICに現れる次元d/2の項、WBICのλに相当します。[1]

WBICとRLCTの意味付けについては
https://www.alignmentforum.org/s/czrXjvCLsqGepybHC/p/4eZtmwaqhAgdJQDEg
の記述が参考になります。
以上の式の説明に加え数値計算に関してはAIC, WAIC, WBICを実感するがとても参考になります。

RLCT(実対数閾値)

渡辺澄夫先生の"Algebraic Geometry and Statistical Learning Theory"(gray book)によればRLCTは
\lambda:=\min_\alpha \min_{1 \leq j \leq d} (\frac{h_j+1}{k_j})

ただしh_j,k_jはデータの真の分布q(x)とパラメーターwをもつ統計モデル(尤度)p(x|w)の間のKLダイバージェンスK(w)の特異点解消写像gに対する指数
K(g(u))= S u_1^{k_1} u_2^{k_2} \cdots u_d^{k_d}

g'(u)= b(u)u_1^{h_1} u_2^{h_2} \cdots u_d^{h_d}

と定義され、KLダイバージェンスのゼータ関数
\zeta(z)\int K(w)^z\phi(z)dw
(\phi(z)はパラメーターの事前分布)の最大の極の指数に対応します。以下のようにいくつかの統計、機械学習モデルに対して同書や関連論文で計算がされています。

縮小ランクモデル

多次元の回帰モデルのうち
y=BAx + \epsilon
のように係数パラメーターが行列BAと分解できるものを縮小ランク回帰
といいます。線形のAutoEncoderのような形と言えるかもしれません。
http://watanabe-www.math.dis.titech.ac.jp/users/swatanab/red_rank.html

Stochastic complexities of reduced rank regression in Bayesian estimationでは
の場合のRLCTを求めています。線形、2層相当の場合でも非常に複雑で行列A,Bを複数のブロックに分けて各要素を再帰的にblow upする必要が有ります。
https://speakerdeck.com/xiangze/stochastic-complexities-of-reduced-rank-regression-in-bayesian-estimationnozheng-ming-gai-lue
以前論文を読んだときのノートですがかえってわかりづらいかもしれません。

3層ニューラルネット

hoxo_mさんがニューラルネットなどgray bookに出てくるモデルのRLCTを解析的、数値的に計算されています。

https://www.alignmentforum.org/posts/xRWsfGfvDAjRWXcnG/dslt-0-distilling-singular-learning-theory#Literature
ではDaniel Murfetさんの学生さんの修士論文が紹介されていて2層で活性化関数がtanhのニューラルネットなどが研究され、RLCTが解析的に算出されています。

pyroによるRLCTの数値計算

MCMC(NUTS HMC)でニューラルネットの最終層、その一つ前の予測誤差を数値的に算出する計算がpytorch, pyroで実装されているようです。
https://github.com/suswei/RLCT

pyro_example.pyの内容が短くまとまっていて
https://github.com/suswei/RLCT/blob/e9e04ca5e64250dfbb94134ec5283286dcdc4358/pyro_example.py
model関数でニューラルネットを定義し、main関数からget_data_symmetric関数でデータを生成し、それを用いてrun_inferenceでMCMCを実行して事後分布を求めています。その後複数の温度に対してexpected_nll_posterior関数で対数尤度を計算して 温度の係数として線型回帰でRCLTを求めています。ただしexsampleでは2つの温度での自由エネルギーの値を結んだ線から算出するという粗いものになっておりmain.pyから辿れる
https://github.com/suswei/RLCT/blob/master/main.py
ではもっと本格的な計算をしています。setup_w0で真のパラメーター分布、get_dataset_by_idでデータを得てlambda_asymptoticsで線型回帰でRLCTの推定を行っています。そのなかのapproxinf_nllで対数尤度をMCMCあるいは変分法で計算しています。

同時にラプラス近似での計算も行いそこでは計算値がNaNになり破綻することも示しています。

大規模な行列のHessianの算出に関してはModular Block-diagonal Curvature Approximations for Feedforward Architectures コード にあるそうです。

loss flat minima,汎化との関係

ディープラーニングではloss関数の値が高次元のパラメーター空間の中でlocal minumumに落ち込まないで大域的な最適解に近い解に落ち着くことが知られております。これに関してはhttps://www.iwanami.co.jp/book/b570597.htmlでもディープラーニングの謎として書かれ拡散モデルについて思ったこと、統計モデリング等との関係でも他の数理モデルとの類似点について少し触れました。

はHessiannが縮退し、変数の自由度よりも小さくなる様子がBICの次元数の項とWBICのRLCTの項の違いにより捉えられていると言えるかもしれません。
しかし広い裾分布を持ったような特異なモデルではflat minimaが存在し、大域的に極小値がつながっているとは即座には言えないようにも思えます。

特異点とその解消に関して

実数値のKLダイバージェンスを持つニューラルネットでのRLCTの計算は難しい一方、
3次元複素曲面\mathbb{C}^3の場合の孤立特異点は分類はリー環と同じくDynkin図形によってなされ特にMcKayグラフなどとも呼ばれているようです。
https://ja.wikipedia.org/wiki/デュ・バル特異点

A_n: x^2+y^2+z^{n+1}=0
D_n: x^2+y^2z+z^{n+1}=0
E_6: x^2+y^3+z^4=0
E_7: x^2+y^3+z^3=0
E_8: x^2+y^3+z^5=0
x,y,z \in \mathbb{C}

A_n,D_n,E_6,E_7,E_8のそれぞれに対して特異点解消を考えることができ、(実でない)対数閾値を計算できその結果が公表されています。

https://arxiv.org/abs/2312.16187
結論としては次数nに対して\frac{n+1}{n}となるようです。複素数はその上で定義される方程式が必ず解を持つという完全体という性質を持ちそれが対数閾値の計算を簡単にしているところがあります。統計や学習理論では普通は実数の確率分布関数を考えるので対数閾値の値も変わってくることに注意が必要です(gray bookの3.6末尾にも説明があります)。代数幾何学では統計モデルに対応する代数多様体を(小平次元、算術種数、幾何種数など)様々な指標で分類していますがそれらがどのようにRLCTや汎化誤差に影響するのかはまだわかっていないようです。

その他参考

脚注
  1. (https://xiangze.hatenablog.com/entry/2014/10/13/224825#fn-232bfd2c に対する答えになっているでしょうか。) ↩︎

Discussion