👓

Lensだけで作るニューラルネットワーク

2021/12/15に公開

これは、FOLIO Advent calendar 2021 の15日目の記事です。


圏論を機械学習に応用する話題の一つとしてLensで微分可能プログラミングを実装する話を紹介したいと思います。とはいえ圏論など気にせずLensを使ったニューラルネットワークを実装していきます。学習モデル、誤差関数、学習係数などの基本的な構成要素が全てLens(ParaLens)として実装できる様子を楽しんでいただければと思っています。

Lensって何?

Lensはいわゆる getter と setter を組み合わせたデータ構造です。すなわち型sのデータ型から型aの値を取り出すgetter s -> a と、型sのデータ型を型aの値で更新して新しい型sのデータ型を作成するsetter (s, a) -> s から成っています。

type Lens s a = (s -> a, (s, a) -> s)

最も簡単なLensの例としてタプルの要素にアクセスするLensを実装してみましょう。

-- | タプルの1要素目にアクセスするLens
_1 :: Lens (a, b) a
_1 = (fst, \((_, b), c) -> (c, b))

-- | タプルの2要素目にアクセスするLens
_2 :: Lens (a, b) b
_2 = (snd, \((a, _), c) -> (a, c))

単純にgetter/setterの実装を組み合わせただけですね。
これらのLensを便利に使うために必要なコンビネータを用意します。

-- | Lensのgetterを使用するためのコンビネータ
(^.) :: s -> Lens s a -> a
s ^. (getter, _) = getter s

-- | Lensのsetterを使用するためのコンビネータ
(.~) :: Lens s a -> a -> s -> s
(_, setter) .~ a = \s -> setter (s, a)

実際にLensを使ってデータ型の操作をしてみましょう。

> import Data.Function ((&))

> -- getter の使用例
> (123, "abc") ^. _1
123
> (123, "abc") ^. _2
"abc"

> -- setter の使用例
> (123, "abc") & _1 .~ 456
(456,"abc")
> (123, "abc") & _2 .~ "def"
(123,"def")

期待通りに動いていますね 👏
さらにLensには合成可能であるという重要な性質があります。

-- | 2つのLensを合成するためのコンビネータ
(%) :: Lens a b -> Lens b c -> Lens a c
(getter1, setter1) % (getter2, setter2) = (getter2 . getter1, \(a, c) -> setter1 (a, (setter2 (getter1 a, c))))

実装は一見すると複雑ですが以下のような図で考えるととてもシンプルです。

まず Lens a b のgetterをg :: a -> b, setterをs :: (a, b) -> aとして関数を図示すると以下のように表せます。

データの流れ、すなわち矢印の向きが合うように2つのLensを並べると以下のようになります。

2つ並べたLensの構成要素を以下のように整理すると合成された1つのLensが出来上がることが分かります。

Lensの合成の実装(%)はこの図の通りに配線を行なっているだけなのです。この合成可能であるという性質はLensを圏論を用いて考える際にも非常に中心的な役割を果たします[1]

Lensの合成を使えばネストの深いタプルの一部の値だけを置き換える処理も簡単に書くことができます。

> ((123, 456), "abc") & (_1 % _2) .~ 789
((123,789),"abc")

特にイミュータブルなデータ構造を取り扱う時はこのような入り組んだデータ型の値を更新するプログラムは複雑になりやすいのでLensが重宝されることになるでしょう。

少し脱線にはなりますが、データ分析にもLensを応用することができます。

-- | データの平均値にアクセスするLens
average :: Lens [Double] Double
average = (getter, setter) where
  getter xs      = sum xs / (fromIntegral $ length xs)
  setter (xs, a) = let _average = getter xs
                    in fmap (+ (a - _average)) xs
> [1,2,3] ^. average
2.0
> [1,2,3] & average .~ 0
[-1.0,0.0,1.0]

データから統計値を取り出すgetterと統計値が与えられた値になるようにデータを修正するsetterを考えるとLensを作ることができるわけです。例えばaverageと同様に標準偏差にアクセスするLens std を実装すれば、データの平均を0分散を1にするような正規化もLensを使って簡単に行うことができます。

微分可能プログラミング

ニューラルネットワークの学習において誤差関数を含む学習モデルをパラメータによって微分することは大切な工程です。ニューラルネットワークのように微分可能な基本パーツをプログラムによって組み立て、大規模な微分可能なモデルを作る手法は微分可能プログラミングと呼ばれています。最近ではニューラルネットワークのような単なる関数だけではなく、微分可能な基本パーツとしてソート関数[2]や二次計画法も含む最適化計算[3]等も使えるようになってきており、深層学習のリブランディングを超えた新しいパラダイムとして発展している印象です。

さて微分可能プログラミングにおいて最も基本となるn次元ベクトルからm次元ベクトルへの関数の微分を考えてみましょう。

f : {\mathbb R}^n \rightarrow {\mathbb R}^m

この関数の微分を考えると、定義域である{\mathbb R}^nの各点にヤコビ行列と呼ばれるm\times n行列を対応させる関数{\mathbb R}^n \to M(m, n)となります(ただしM(m,n)はm行n列の行列全体からなる空間です)。
今、誤差逆伝播法を考えたいのでこのm\times n行列を{\mathbb R}^m \rightarrow {\mathbb R} ^ nの線形写像として考えると、結局fの微分R[f]

R[f] : {\mathbb R}^n \times {\mathbb R}^m \rightarrow {\mathbb R}^n

という関数と考えることができます。

ところでこのfR[f]を一緒に見て何か気づかないでしょうか?そうLensの型と同じ形をしていますね。もう一度Lensの型を見てみましょう。

type Lens s a = (s -> a, (s, a) -> s)

s{\mathbb R}^na{\mathbb R}^mと考えると、fR[f]の組み合わせでLensを作ることが出来そうです。

Para構成

ニューラルネットワークの実装に入る前にもう一つ必要な準備があります。

教師あり学習のモデルを関数として考える際には入出力だけでなくパラメータも重要な概念です。パラメータの次元をpとするとモデルは以下のような関数として表現されます。

f : {\mathbb R}^p \times {\mathbb R}^n \rightarrow {\mathbb R}^m

この表現は{\mathbb R}^{n+p}を直積に分けただけとも考えられるので、先程と同様にこの関数の微分R[f]を考えると

R[f] : {\mathbb R}^p \times {\mathbb R}^n \times {\mathbb R}^m \rightarrow {\mathbb R}^p \times {\mathbb R}^n

のようになります。

Lensにおいてもこの形を特殊な型として定義しておきましょう。

type ParaLens p x y = Lens (p, x) y

さらにq次元のパラメータを持つ別のモデル

g : {\mathbb R}^q \times {\mathbb R}^m \rightarrow {\mathbb R}^k

があったとしましょう。パラメータを除いて考えるとfの出力とgの入力の次元はmとして一致しているので、これらのモデルを合成して新たなモデル

g \circ f : {\mathbb R}^{p+q} \times {\mathbb R}^n \rightarrow {\mathbb R}^k

を作ることができるはずです。こういったモデルの合成としては例えば三層パーセプトロンが考えられます。

ParaLensの合成をLensと比較するとパラメータの部分が今までと違う挙動をしていることが分かります。ParaLensに対する特殊な合成を以下のように実装しておきましょう。

(%.) :: ParaLens p x y -> ParaLens q y z -> ParaLens (p, q) x z
(f1, rf1) %. (f2, rf2) = (f3, rf3) where
  f3 ((p, q), x) = f2 (q, f1 (p, x))
  rf3 (((p, q), x), z) = let (q', y') = rf2 ((q, f1 (p, x)), z)
                             (p', x') = rf1 ((p, x), y')
                          in ((p', q'), x')

合成されたモデルの微分はこれまでと同様

R[g \circ f] : {\mathbb R}^{p+q} \times {\mathbb R}^n \times {\mathbb R}^k \rightarrow {\mathbb R}^{p+q} \times {\mathbb R}^n

として与えられます。ParaLensの合成も同様の挙動になっていることが分かります。

ニューラルネットワークを実装する

それではこれまでに定義したParaLensを使って簡単なニューラルネットワークを作ってみましょう。以下では線形代数ライブラリとしてhmatrix、特にNumeric.LinearAlgebra.Staticを使っています。

ニューラルネットワーク、特に全結合層と呼ばれるレイヤーは線形変換Wとバイアスb、そして活性化関数\sigma(ここではReLUを使います)の合成で成り立っています。

x_{n+1} = \sigma\left(Wx_n+b\right)

これを

linear %. bias %. relu

と実装できるようにParaLensを定義していきましょう。

まずは一番簡単な bias の実装です。

bias :: KnownNat n => ParaLens (R n) (R n) (R n)
bias = (getter, setter) where
  getter  (b, x)     = x + b
  setter ((b, x), y) = (y, y)

R nn次元ベクトルを表す型です。getter は単純にバイアスを加算するだけ、setter はパラメータ・入力それぞれの微分が単位行列になるので出力 y がそれぞれそのまま出力される形になっています。実のところsetterの実装はgetterの自動微分を行えば簡単に手に入りますが、この記事では複雑なモデルは扱わないので手で実装することにします。

次に linear の実装を見てみましょう。

linear :: (KnownNat m, KnownNat n) => ParaLens (L m n) (R n) (R m)
linear = (getter, setter) where
  getter  (w, x)     = w #> x
  setter ((w, x), y) = (y `outer` x, tr w #> y)

L m nm\times n行列を表す型で(#>)行列とベクトルの積を計算する関数、outerはベクトル同士の外積、trは行列の転置を行う関数です。getter は単純に行列とベクトルの積を計算、setterはパラメータに関する微分は出力と入力の外積、入力に関する微分は係数行列の転置を出力に掛け算したものになります。

最後にReLUを実装しましょう。

relu :: KnownNat n => ParaLens () (R n) (R n)
relu = (getter, setter) where
  getter  ((), x)     = dvmap (max 0) x
  setter (((), x), y) = ((), y * (dvmap step x))
  step x = if x > 0 then 1 else 0

dvmapはベクトルの各成分に変換用の関数を適用する関数です。ReLUはパラメータを持たないのでパラメータはUnit型になっています。getter は dvmap という関数を使ってベクトルの各要素に0と比較して大きい方を返す関数を適用しています。setterは入力の微分と出力の成分毎の掛け算を行います。

以上の実装を組合せると例えば以下のようなレイヤーが定義できます。

layer :: (KnownNat m, KnownNat n) => ParaLens _ (R n) (R m)
layer = linear %. bias %. relu

layer の型の中で_となっている部分は実際の型は ((L m n, R m), ()) ですが、手で書くのは大変なので PartialTypeSignatures 拡張を利用して型推論に丸投げしています。layerはまたParaLensになっているのでさらに合成することが可能です。

layer @4 %. layer

@4TypeApplications拡張を利用して中間層のユニット数を指定しています。このように型推論を活用することで各レイヤーの次元の指定を最小限の記述で行うことができます。

XOR回路を学習する

それでは実装したニューラルネットワークを使って実際に学習を行ってみましょう。しかし学習を行うにはまだ足りないものがあります。それは誤差関数と学習係数の実装です。

まず誤差関数ですがXOR回路は出力がカテゴリ変数なので交差エントロピー誤差を使うことにします。

crossEntropyLoss :: KnownNat n => ParaLens (R n) (R n) Double
crossEntropyLoss = (getter, setter) where
  getter  (y', y) = log (sumElements . unwrap $ dvmap exp y) - sumElements (unwrap $ y' * y)
  setter ((y', y), z) =
    let expY = dvmap exp y
        sumExpY = sumElements (unwrap expY)
     in (dvmap (*z) (-y), dvmap (*z) (dvmap (/sumExpY) expY - y'))

softmax :: forall n. KnownNat n => ParaLens () (R n) (R n)
softmax = (getter, setter) where
  getter ((), x) =
    let xMax = maxElement (unwrap x)
        expX = dvmap exp (x - konst xMax)
        denom = sumElements (unwrap expX)
     in dvmap (/denom) expX
  setter (((), x), y) =
    let n = fromIntegral $ natVal (Proxy @n)
        z = getter ((), x)
        Just cols = create . fromColumns $ replicate n (unwrap z)
     in ((), (cols * (eye - (tr cols))) #> y)

交差エントロピー誤差は定義通りに実装すると微分を計算する時に計算が不安定になってしまうため、ソフトマックス関数と組み合わせたものを定義しています。そのためcrossEntropyLossを使って学習したモデルは、softmaxと組み合わせて使うことになります。ここで定義した2つの関数もまたParaLensであることに注目してください。誤差関数のパラメータに対応する部分には正解データが入ることを想定しています。

最後に学習係数を実装しましょう。

learningRate :: ParaLens () Double ()
learningRate = (const (), setter) where
  setter (((), loss), ()) = ((), (-0.01) * loss)

学習係数は単純なユニット型以外何も出力しませんが、setterは与えられた誤差に係数を掛けてフィードバックするという折り返し地点のような役割を果たします。これもまたParaLensとして実装されています。

以上で定義した誤差関数と学習係数を使ってモデルと学習データが与えられた時にパラメータをアップデートする関数を定義しましょう。

updateParam :: (Parameter p, KnownNat n) => ParaLens p a (R n) -> (a, R n) -> p -> p
updateParam model a b p =
  let l = model %. crossEntropyLoss %. learningRate
      (((p', _), ()), _) = (((p, b), ()), a) & l .~ ()
   in update p p'

train :: (Parameter p, KnownNat n) => ParaLens p a (R n) -> [(a, R n)] -> p -> p
train model dataset initParam = foldl (\p d -> updateParam model d p) initParam dataset

updateParamは学習データを1つ使ってパラメータを更新する関数です。この関数はまず第一引数としてパラメータpを持ち入力としてa、出力としてR nを持つモデルを取ります。出力の型だけ固定されているのはcrossEntropyLossを使うことを予め想定しているからです。次に第二引数として入出力の正解データのペアを1つ取ります。そして残りの関数の型としては単純にパラメータを更新する関数とみなすことができます。trainupdateParamを繰り返し適用することでリストとして与えられた複数の学習データを用いてパラメータの更新を行っています。

Parameter型クラスは以下のように定義された型クラスです。

class Parameter a where
  update :: a -> a -> a

instance Parameter () where
  update () () = ()

instance KnownNat n => Parameter (R n) where
  update !v !w = v + w

instance (KnownNat m, KnownNat n) => Parameter (L m n) where
  update !a !b = a + b

instance (Parameter a, Parameter b) => Parameter (a, b) where
  update (!a, !b) (!a', !b') = (update a a', update b b')

パラメータとその差分が与えられた時にどうやって更新するのかを制御し、パラメータがタプルとしてネストしていても機能するようになっています。

これまでに定義した実装を用いてXOR回路を学習してみましょう。

model :: ParaLens _ (R 2) (R 2)
model = linear @4 %. bias %. relu %. linear %. bias

eval :: ParaLens p a b -> p -> a -> b
eval model params input = (params, input) ^. model

学習モデルであるmodelは入力の次元が2, 出力の次元が2, 中間層の次元が4のニューラルネットワークです。evalは学習モデルを評価するための関数でgetterの適用の仕方を見やすいように書き換えただけの関数です。

学習用の正解データを用意しましょう。

dataset :: [(R 2, R 2)]
dataset = take 10000 $ cycle
  [ (vec2 0 0, vec2 1 0)
  , (vec2 0 1, vec2 0 1)
  , (vec2 1 0, vec2 0 1)
  , (vec2 1 1, vec2 1 0)
  ]

XOR回路なので2つの値が異なる時だけ出力の1次元目が1になるように作成しています。

これらを使ってモデルの学習を行い学習結果を評価してみましょう。

main :: IO ()
main = do
  [bias1] <- toRows <$> randn @1
  weight1 <- randn
  [bias2] <- toRows <$> randn @1
  weight2 <- randn

  let initParam = ((((weight1, bias1), ()), weight2), bias2)
      trained = train model dataset initParam

  putStrLn "~~~ Result ~~~"
  putStr "(0, 0): "
  printf "%.3f\n" . getOneProb $ eval (model %. softmax) (trained, ()) (vec2 0 0)
  putStr "(0, 1): "
  printf "%.3f\n" . getOneProb $ eval (model %. softmax) (trained, ()) (vec2 0 1)
  putStr "(1, 0): "
  printf "%.3f\n" . getOneProb $ eval (model %. softmax) (trained, ()) (vec2 1 0)
  putStr "(1, 1): "
  printf "%.3f\n" . getOneProb $ eval (model %. softmax) (trained, ()) (vec2 1 1)
  putStrLn "=============="
  where
  getOneProb = (<.> (vec2 0 1))

実行してみると期待通りに学習が行われていることが分かります。

$ stack run
~~~ Result ~~~
(0, 0): 0.012
(0, 1): 0.979
(1, 0): 0.979
(1, 1): 0.024
==============

Lens、正確にはParaLensだけを用いてニューラルネットワークを学習するプログラムがほとんど実装てきてしまうのは驚きではないでしょうか。実はAdamを始めとした各種オプティマイザーもParaLensとして実装することができるのですが今回は割愛しました。ここまでに実装したコードはgistにまとめていますので参考にしてみてください。

おわりに

今回紹介したLensを使ったニューラルネットワークの実装は

という論文の内容をHaskellに書き下したものです。

Lensという概念はProfunctor Opticsという名前で一般化され圏論を使った研究対象になっています[4]。上記の論文では関数をパラメータ化して考えるPara構成とLensの組み合わせから学習モデルの圏であるLearnが構成できることを示しています。

Learn

という論文で提案された教師あり学習における学習モデルから成る圏です。この圏は実はゲーム理論における合成可能なゲームから成るGameという圏に埋め込めることが知られており、圏論によって機械学習とゲーム理論という異なる分野が関連するという興味深い話になっています。例えばこの対応によって学習により収束したパラメータは選択均衡(ナッシュ均衡の一般化)と対応することが分かります[5]

圏論を応用した機械学習の研究は様々ありますが今年の6月に投稿された

というサーベイがまとまっていますので、興味がある人は是非読んでみてください。

機械学習の研究に圏論が使われるモチベーションの一つに複雑化する概念を整理・統一したいというものがあるでしょう。例えば教師あり学習一つを取っても、パラメータpを持つ関数y = f_p(x)としての学習モデル、その中から最適なモデルを評価するための誤差関数、学習時の振る舞いを制御する学習係数やオプティマイザーなど様々な概念が現れます。これらの概念は一つ一つが詳しく研究されていますが、「学習」するという機能全体を考える時には全てを結合・合成して用いる必要があります。圏論を使えばこのような機械学習を構成する要素の有機的な合成それ自体を研究対象にすることができ、「学習」に対する深い洞察を得られることが期待できるかもしれません。

「学習する」というフレーズは、時には最適な関数を求めることであったり、時には最適な確率分布を求めることであったり、また時には最適なグラフ構造を求めることであったりと、同じ言葉でも様々な機械学習の手法の中でそれぞれ違った意味で広く用いられています。もし圏論というフレームワークを用いてこれらの手法を理解することができたとすれば「学習する」という概念も何らかの圏論的な概念を用いて統一的に記述されることとなるでしょう。その時には「学習する」という行為に対してもう一歩理解を進めることができたと言えるのではないでしょうか。

Lensは双方向のデータフローを合成可能な形でうまく表現できるため、今回紹介した微分可能プログラミングを始め確率的プログラミング等いわゆる帰納プログラミングの考え方ととても相性がいいのではないかと思っています。ソフトウェア2.0の世界でProfunctor Opticsが活躍する未来が来たら面白そうですね。

脚注
  1. この記事におけるLensの合成の中で、左側のLensのgetterであるg1が2回使われているのは計算が重複しており効率的ではありません。この問題はProfunctor Opticsなどのより抽象的なLensの定義に基づくことで解決することができます(こちらの事実は"Categorical Foundations of Gradient-Based Learning"の著者の一人であるBruno Gavranović氏にTwitterで教えていただきました)。 ↩︎

  2. Cuturi, Marco, Olivier Teboul, and Jean-Philippe Vert. "Differentiable ranks and sorting using optimal transport." arXiv preprint arXiv:1905.11885 (2019). ↩︎

  3. Amos, Brandon, and J. Zico Kolter. "Optnet: Differentiable optimization as a layer in neural networks." International Conference on Machine Learning. PMLR, 2017. ↩︎

  4. Clarke, Bryce, et al. "Profunctor optics, a categorical update." arXiv preprint arXiv:2001.07488 (2020). ↩︎

  5. Hedges, Jules. "From open learners to open games." arXiv preprint arXiv:1902.08666 (2019). ↩︎

Discussion