「Deep Learning is Singular, and That's Good」論文とpyro実装を読む
「Deep Learning is Singular, and That's Good」(論文)とpyro実装を読む
確率的プログラミング言語 Advent Calendar 2023の20日目の記事です。
「Deep Learning is Singular, and That's Good」という野心的なタイトルの論文があり、そこでは特異学習理論で重要になる量であるRLCT(real log canonical threshold,実対数しきい値)の3層ニューラルネットにおける値の計算をpyroで実装しているのでそれを紹介します。
書かれた人,Daniel Murfetさんのページを見ると代数幾何と機械学習の関係、特異学習理論のほかにも圏論や論理、チューリングマシンなどの計算機科学の基礎など幅広く研究されているようです。
Quantifying degeneracy in singular models via the learning coefficientというより明確にニューラルネットのような特異性(縮退)を持ったモデルにおけるRLCT(学習係数)を扱った論文も最近出されました。
情報量基準(BIC,WBIC)の復習
情報量基準は統計モデル選択の基準となる量であり尤度Lとモデルの自由度dに対して
- 赤池情報量基準(AIC)
-2\ln L+2d - ベイズ情報量基準(BIC)
-2\ln L+d\ln n
などがあり、特にベイズ情報量基準はモデルを統計物理的系としてみたときの自由エネルギーに対応します。
BICの導出ではFisher情報行列が正則であることを仮定としますが多くの統計や機械学習ででてくるモデルはこの仮定を満たさず、そのためそれを拡張した
という量が使われることになります。ここでλは実対数閾値(RLCT)と呼ばれる量です。
自由エネルギーは統計力学に置いては有限要素数、統計学に置いては有限なサンプル数の場合の補正項がBICに現れる次元d/2の項、WBICのλに相当します。[1]
WBICとRLCTの意味付けについては
以上の式の説明に加え数値計算に関してはAIC, WAIC, WBICを実感するがとても参考になります。
RLCT(実対数閾値)
渡辺澄夫先生の"Algebraic Geometry and Statistical Learning Theory"(gray book)によればRLCTは
ただし
と定義され、KLダイバージェンスのゼータ関数
(
縮小ランクモデル
多次元の回帰モデルのうち
のように係数パラメーターが行列BAと分解できるものを縮小ランク回帰
といいます。線形のAutoEncoderのような形と言えるかもしれません。
Stochastic complexities of reduced rank regression in Bayesian estimationでは
の場合のRLCTを求めています。線形、2層相当の場合でも非常に複雑で行列A,Bを複数のブロックに分けて各要素を再帰的にblow upする必要が有ります。
以前論文を読んだときのノートですがかえってわかりづらいかもしれません。
3層ニューラルネット
hoxo_mさんがニューラルネットなどgray bookに出てくるモデルのRLCTを解析的、数値的に計算されています。
ではDaniel Murfetさんの学生さんの修士論文が紹介されていて2層で活性化関数がtanhのニューラルネットなどが研究され、RLCTが解析的に算出されています。
pyroによるRLCTの数値計算
MCMC(NUTS HMC)でニューラルネットの最終層、その一つ前の予測誤差を数値的に算出する計算がpytorch, pyroで実装されているようです。
pyro_example.pyの内容が短くまとまっていてexpected_nll_posterior関数で対数尤度を計算して 温度の係数として線型回帰でRCLTを求めています。ただしexsampleでは2つの温度での自由エネルギーの値を結んだ線から算出するという粗いものになっておりmain.pyから辿れる ではもっと本格的な計算をしています。setup_w0で真のパラメーター分布、get_dataset_by_idでデータを得てlambda_asymptoticsで線型回帰でRLCTの推定を行っています。そのなかのapproxinf_nllで対数尤度をMCMCあるいは変分法で計算しています。
model関数でニューラルネットを定義し、main関数からget_data_symmetric関数でデータを生成し、それを用いてrun_inferenceでMCMCを実行して事後分布を求めています。その後複数の温度に対して同時にラプラス近似での計算も行いそこでは計算値がNaNになり破綻することも示しています。
大規模な行列のHessianの算出に関してはModular Block-diagonal Curvature Approximations for Feedforward Architectures コード にあるそうです。
loss flat minima,汎化との関係
ディープラーニングではloss関数の値が高次元のパラメーター空間の中でlocal minumumに落ち込まないで大域的な最適解に近い解に落ち着くことが知られております。これに関してはhttps://www.iwanami.co.jp/book/b570597.htmlでもディープラーニングの謎として書かれ拡散モデルについて思ったこと、統計モデリング等との関係でも他の数理モデルとの類似点について少し触れました。
はHessiannが縮退し、変数の自由度よりも小さくなる様子がBICの次元数の項とWBICのRLCTの項の違いにより捉えられていると言えるかもしれません。
しかし広い裾分布を持ったような特異なモデルではflat minimaが存在し、大域的に極小値がつながっているとは即座には言えないようにも思えます。
特異点とその解消に関して
実数値のKLダイバージェンスを持つニューラルネットでのRLCTの計算は難しい一方、
3次元複素曲面
その他参考
- Distilling Singular Learning Theory
- 実対数閾値の定義に使われているゼータ関数はなぜゼータ関数と呼ばれているのか
- Singularity and its resolution
- The Du Val singularities An, Dn, E6, E7, E8 Milnor fibre ∼ resolution!
- Estimating Real Log Canonical Thresholds
-
(https://xiangze.hatenablog.com/entry/2014/10/13/224825#fn-232bfd2c に対する答えになっているでしょうか。) ↩︎
Discussion