Lensだけで作るニューラルネットワーク
これは、FOLIO Advent calendar 2021 の15日目の記事です。
圏論を機械学習に応用する話題の一つとしてLensで微分可能プログラミングを実装する話を紹介したいと思います。とはいえ圏論など気にせずLensを使ったニューラルネットワークを実装していきます。学習モデル、誤差関数、学習係数などの基本的な構成要素が全てLens(ParaLens)として実装できる様子を楽しんでいただければと思っています。
Lensって何?
Lensはいわゆる getter と setter を組み合わせたデータ構造です。すなわち型s
のデータ型から型a
の値を取り出すgetter s -> a
と、型s
のデータ型を型a
の値で更新して新しい型s
のデータ型を作成するsetter (s, a) -> s
から成っています。
type Lens s a = (s -> a, (s, a) -> s)
最も簡単なLensの例としてタプルの要素にアクセスするLensを実装してみましょう。
-- | タプルの1要素目にアクセスするLens
_1 :: Lens (a, b) a
_1 = (fst, \((_, b), c) -> (c, b))
-- | タプルの2要素目にアクセスするLens
_2 :: Lens (a, b) b
_2 = (snd, \((a, _), c) -> (a, c))
単純にgetter/setterの実装を組み合わせただけですね。
これらのLensを便利に使うために必要なコンビネータを用意します。
-- | Lensのgetterを使用するためのコンビネータ
(^.) :: s -> Lens s a -> a
s ^. (getter, _) = getter s
-- | Lensのsetterを使用するためのコンビネータ
(.~) :: Lens s a -> a -> s -> s
(_, setter) .~ a = \s -> setter (s, a)
実際にLensを使ってデータ型の操作をしてみましょう。
> import Data.Function ((&))
> -- getter の使用例
> (123, "abc") ^. _1
123
> (123, "abc") ^. _2
"abc"
> -- setter の使用例
> (123, "abc") & _1 .~ 456
(456,"abc")
> (123, "abc") & _2 .~ "def"
(123,"def")
期待通りに動いていますね 👏
さらにLensには合成可能であるという重要な性質があります。
-- | 2つのLensを合成するためのコンビネータ
(%) :: Lens a b -> Lens b c -> Lens a c
(getter1, setter1) % (getter2, setter2) = (getter2 . getter1, \(a, c) -> setter1 (a, (setter2 (getter1 a, c))))
実装は一見すると複雑ですが以下のような図で考えるととてもシンプルです。
まず Lens a b
のgetterをg :: a -> b
, setterをs :: (a, b) -> a
として関数を図示すると以下のように表せます。
データの流れ、すなわち矢印の向きが合うように2つのLensを並べると以下のようになります。
2つ並べたLensの構成要素を以下のように整理すると合成された1つのLensが出来上がることが分かります。
Lensの合成の実装(%)
はこの図の通りに配線を行なっているだけなのです。この合成可能であるという性質はLensを圏論を用いて考える際にも非常に中心的な役割を果たします[1]。
Lensの合成を使えばネストの深いタプルの一部の値だけを置き換える処理も簡単に書くことができます。
> ((123, 456), "abc") & (_1 % _2) .~ 789
((123,789),"abc")
特にイミュータブルなデータ構造を取り扱う時はこのような入り組んだデータ型の値を更新するプログラムは複雑になりやすいのでLensが重宝されることになるでしょう。
少し脱線にはなりますが、データ分析にもLensを応用することができます。
-- | データの平均値にアクセスするLens
average :: Lens [Double] Double
average = (getter, setter) where
getter xs = sum xs / (fromIntegral $ length xs)
setter (xs, a) = let _average = getter xs
in fmap (+ (a - _average)) xs
> [1,2,3] ^. average
2.0
> [1,2,3] & average .~ 0
[-1.0,0.0,1.0]
データから統計値を取り出すgetterと統計値が与えられた値になるようにデータを修正するsetterを考えるとLensを作ることができるわけです。例えばaverage
と同様に標準偏差にアクセスするLens std
を実装すれば、データの平均を0分散を1にするような正規化もLensを使って簡単に行うことができます。
微分可能プログラミング
ニューラルネットワークの学習において誤差関数を含む学習モデルをパラメータによって微分することは大切な工程です。ニューラルネットワークのように微分可能な基本パーツをプログラムによって組み立て、大規模な微分可能なモデルを作る手法は微分可能プログラミングと呼ばれています。最近ではニューラルネットワークのような単なる関数だけではなく、微分可能な基本パーツとしてソート関数[2]や二次計画法も含む最適化計算[3]等も使えるようになってきており、深層学習のリブランディングを超えた新しいパラダイムとして発展している印象です。
さて微分可能プログラミングにおいて最も基本となるn次元ベクトルからm次元ベクトルへの関数の微分を考えてみましょう。
この関数の微分を考えると、定義域である
今、誤差逆伝播法を考えたいのでこの
という関数と考えることができます。
ところでこの
type Lens s a = (s -> a, (s, a) -> s)
s
をa
を
Para構成
ニューラルネットワークの実装に入る前にもう一つ必要な準備があります。
教師あり学習のモデルを関数として考える際には入出力だけでなくパラメータも重要な概念です。パラメータの次元を
この表現は
のようになります。
Lensにおいてもこの形を特殊な型として定義しておきましょう。
type ParaLens p x y = Lens (p, x) y
さらに
があったとしましょう。パラメータを除いて考えると
を作ることができるはずです。こういったモデルの合成としては例えば三層パーセプトロンが考えられます。
ParaLensの合成をLensと比較するとパラメータの部分が今までと違う挙動をしていることが分かります。ParaLensに対する特殊な合成を以下のように実装しておきましょう。
(%.) :: ParaLens p x y -> ParaLens q y z -> ParaLens (p, q) x z
(f1, rf1) %. (f2, rf2) = (f3, rf3) where
f3 ((p, q), x) = f2 (q, f1 (p, x))
rf3 (((p, q), x), z) = let (q', y') = rf2 ((q, f1 (p, x)), z)
(p', x') = rf1 ((p, x), y')
in ((p', q'), x')
合成されたモデルの微分はこれまでと同様
として与えられます。ParaLensの合成も同様の挙動になっていることが分かります。
ニューラルネットワークを実装する
それではこれまでに定義したParaLensを使って簡単なニューラルネットワークを作ってみましょう。以下では線形代数ライブラリとしてhmatrix、特にNumeric.LinearAlgebra.Staticを使っています。
ニューラルネットワーク、特に全結合層と呼ばれるレイヤーは線形変換
これを
linear %. bias %. relu
と実装できるようにParaLensを定義していきましょう。
まずは一番簡単な bias
の実装です。
bias :: KnownNat n => ParaLens (R n) (R n) (R n)
bias = (getter, setter) where
getter (b, x) = x + b
setter ((b, x), y) = (y, y)
R n
はy
がそれぞれそのまま出力される形になっています。実のところsetterの実装はgetterの自動微分を行えば簡単に手に入りますが、この記事では複雑なモデルは扱わないので手で実装することにします。
次に linear
の実装を見てみましょう。
linear :: (KnownNat m, KnownNat n) => ParaLens (L m n) (R n) (R m)
linear = (getter, setter) where
getter (w, x) = w #> x
setter ((w, x), y) = (y `outer` x, tr w #> y)
L m n
は(#>)
行列とベクトルの積を計算する関数、outer
はベクトル同士の外積、tr
は行列の転置を行う関数です。getter は単純に行列とベクトルの積を計算、setterはパラメータに関する微分は出力と入力の外積、入力に関する微分は係数行列の転置を出力に掛け算したものになります。
最後にReLUを実装しましょう。
relu :: KnownNat n => ParaLens () (R n) (R n)
relu = (getter, setter) where
getter ((), x) = dvmap (max 0) x
setter (((), x), y) = ((), y * (dvmap step x))
step x = if x > 0 then 1 else 0
dvmap
はベクトルの各成分に変換用の関数を適用する関数です。ReLUはパラメータを持たないのでパラメータはUnit型になっています。getter は dvmap
という関数を使ってベクトルの各要素に0と比較して大きい方を返す関数を適用しています。setterは入力の微分と出力の成分毎の掛け算を行います。
以上の実装を組合せると例えば以下のようなレイヤーが定義できます。
layer :: (KnownNat m, KnownNat n) => ParaLens _ (R n) (R m)
layer = linear %. bias %. relu
layer
の型の中で_
となっている部分は実際の型は ((L m n, R m), ())
ですが、手で書くのは大変なので PartialTypeSignatures
拡張を利用して型推論に丸投げしています。layer
はまたParaLensになっているのでさらに合成することが可能です。
layer @4 %. layer
@4
はTypeApplications
拡張を利用して中間層のユニット数を指定しています。このように型推論を活用することで各レイヤーの次元の指定を最小限の記述で行うことができます。
XOR回路を学習する
それでは実装したニューラルネットワークを使って実際に学習を行ってみましょう。しかし学習を行うにはまだ足りないものがあります。それは誤差関数と学習係数の実装です。
まず誤差関数ですがXOR回路は出力がカテゴリ変数なので交差エントロピー誤差を使うことにします。
crossEntropyLoss :: KnownNat n => ParaLens (R n) (R n) Double
crossEntropyLoss = (getter, setter) where
getter (y', y) = log (sumElements . unwrap $ dvmap exp y) - sumElements (unwrap $ y' * y)
setter ((y', y), z) =
let expY = dvmap exp y
sumExpY = sumElements (unwrap expY)
in (dvmap (*z) (-y), dvmap (*z) (dvmap (/sumExpY) expY - y'))
softmax :: forall n. KnownNat n => ParaLens () (R n) (R n)
softmax = (getter, setter) where
getter ((), x) =
let xMax = maxElement (unwrap x)
expX = dvmap exp (x - konst xMax)
denom = sumElements (unwrap expX)
in dvmap (/denom) expX
setter (((), x), y) =
let n = fromIntegral $ natVal (Proxy @n)
z = getter ((), x)
Just cols = create . fromColumns $ replicate n (unwrap z)
in ((), (cols * (eye - (tr cols))) #> y)
交差エントロピー誤差は定義通りに実装すると微分を計算する時に計算が不安定になってしまうため、ソフトマックス関数と組み合わせたものを定義しています。そのためcrossEntropyLoss
を使って学習したモデルは、softmax
と組み合わせて使うことになります。ここで定義した2つの関数もまたParaLens
であることに注目してください。誤差関数のパラメータに対応する部分には正解データが入ることを想定しています。
最後に学習係数を実装しましょう。
learningRate :: ParaLens () Double ()
learningRate = (const (), setter) where
setter (((), loss), ()) = ((), (-0.01) * loss)
学習係数は単純なユニット型以外何も出力しませんが、setterは与えられた誤差に係数を掛けてフィードバックするという折り返し地点のような役割を果たします。これもまたParaLensとして実装されています。
以上で定義した誤差関数と学習係数を使ってモデルと学習データが与えられた時にパラメータをアップデートする関数を定義しましょう。
updateParam :: (Parameter p, KnownNat n) => ParaLens p a (R n) -> (a, R n) -> p -> p
updateParam model a b p =
let l = model %. crossEntropyLoss %. learningRate
(((p', _), ()), _) = (((p, b), ()), a) & l .~ ()
in update p p'
train :: (Parameter p, KnownNat n) => ParaLens p a (R n) -> [(a, R n)] -> p -> p
train model dataset initParam = foldl (\p d -> updateParam model d p) initParam dataset
updateParam
は学習データを1つ使ってパラメータを更新する関数です。この関数はまず第一引数としてパラメータp
を持ち入力としてa
、出力としてR n
を持つモデルを取ります。出力の型だけ固定されているのはcrossEntropyLoss
を使うことを予め想定しているからです。次に第二引数として入出力の正解データのペアを1つ取ります。そして残りの関数の型としては単純にパラメータを更新する関数とみなすことができます。train
はupdateParam
を繰り返し適用することでリストとして与えられた複数の学習データを用いてパラメータの更新を行っています。
Parameter
型クラスは以下のように定義された型クラスです。
class Parameter a where
update :: a -> a -> a
instance Parameter () where
update () () = ()
instance KnownNat n => Parameter (R n) where
update !v !w = v + w
instance (KnownNat m, KnownNat n) => Parameter (L m n) where
update !a !b = a + b
instance (Parameter a, Parameter b) => Parameter (a, b) where
update (!a, !b) (!a', !b') = (update a a', update b b')
パラメータとその差分が与えられた時にどうやって更新するのかを制御し、パラメータがタプルとしてネストしていても機能するようになっています。
これまでに定義した実装を用いてXOR回路を学習してみましょう。
model :: ParaLens _ (R 2) (R 2)
model = linear @4 %. bias %. relu %. linear %. bias
eval :: ParaLens p a b -> p -> a -> b
eval model params input = (params, input) ^. model
学習モデルであるmodel
は入力の次元が2, 出力の次元が2, 中間層の次元が4のニューラルネットワークです。eval
は学習モデルを評価するための関数でgetterの適用の仕方を見やすいように書き換えただけの関数です。
学習用の正解データを用意しましょう。
dataset :: [(R 2, R 2)]
dataset = take 10000 $ cycle
[ (vec2 0 0, vec2 1 0)
, (vec2 0 1, vec2 0 1)
, (vec2 1 0, vec2 0 1)
, (vec2 1 1, vec2 1 0)
]
XOR回路なので2つの値が異なる時だけ出力の1次元目が1になるように作成しています。
これらを使ってモデルの学習を行い学習結果を評価してみましょう。
main :: IO ()
main = do
[bias1] <- toRows <$> randn @1
weight1 <- randn
[bias2] <- toRows <$> randn @1
weight2 <- randn
let initParam = ((((weight1, bias1), ()), weight2), bias2)
trained = train model dataset initParam
putStrLn "~~~ Result ~~~"
putStr "(0, 0): "
printf "%.3f\n" . getOneProb $ eval (model %. softmax) (trained, ()) (vec2 0 0)
putStr "(0, 1): "
printf "%.3f\n" . getOneProb $ eval (model %. softmax) (trained, ()) (vec2 0 1)
putStr "(1, 0): "
printf "%.3f\n" . getOneProb $ eval (model %. softmax) (trained, ()) (vec2 1 0)
putStr "(1, 1): "
printf "%.3f\n" . getOneProb $ eval (model %. softmax) (trained, ()) (vec2 1 1)
putStrLn "=============="
where
getOneProb = (<.> (vec2 0 1))
実行してみると期待通りに学習が行われていることが分かります。
$ stack run
~~~ Result ~~~
(0, 0): 0.012
(0, 1): 0.979
(1, 0): 0.979
(1, 1): 0.024
==============
Lens、正確にはParaLensだけを用いてニューラルネットワークを学習するプログラムがほとんど実装てきてしまうのは驚きではないでしょうか。実はAdamを始めとした各種オプティマイザーもParaLensとして実装することができるのですが今回は割愛しました。ここまでに実装したコードはgistにまとめていますので参考にしてみてください。
おわりに
今回紹介したLensを使ったニューラルネットワークの実装は
という論文の内容をHaskellに書き下したものです。
Lensという概念はProfunctor Opticsという名前で一般化され圏論を使った研究対象になっています[4]。上記の論文では関数をパラメータ化して考えるPara構成とLensの組み合わせから学習モデルの圏である
という論文で提案された教師あり学習における学習モデルから成る圏です。この圏は実はゲーム理論における合成可能なゲームから成る
圏論を応用した機械学習の研究は様々ありますが今年の6月に投稿された
というサーベイがまとまっていますので、興味がある人は是非読んでみてください。
機械学習の研究に圏論が使われるモチベーションの一つに複雑化する概念を整理・統一したいというものがあるでしょう。例えば教師あり学習一つを取っても、パラメータ
「学習する」というフレーズは、時には最適な関数を求めることであったり、時には最適な確率分布を求めることであったり、また時には最適なグラフ構造を求めることであったりと、同じ言葉でも様々な機械学習の手法の中でそれぞれ違った意味で広く用いられています。もし圏論というフレームワークを用いてこれらの手法を理解することができたとすれば「学習する」という概念も何らかの圏論的な概念を用いて統一的に記述されることとなるでしょう。その時には「学習する」という行為に対してもう一歩理解を進めることができたと言えるのではないでしょうか。
Lensは双方向のデータフローを合成可能な形でうまく表現できるため、今回紹介した微分可能プログラミングを始め確率的プログラミング等いわゆる帰納プログラミングの考え方ととても相性がいいのではないかと思っています。ソフトウェア2.0の世界でProfunctor Opticsが活躍する未来が来たら面白そうですね。
-
この記事におけるLensの合成の中で、左側のLensのgetterである
g1
が2回使われているのは計算が重複しており効率的ではありません。この問題はProfunctor Opticsなどのより抽象的なLensの定義に基づくことで解決することができます(こちらの事実は"Categorical Foundations of Gradient-Based Learning"の著者の一人であるBruno Gavranović氏にTwitterで教えていただきました)。 ↩︎ -
Cuturi, Marco, Olivier Teboul, and Jean-Philippe Vert. "Differentiable ranks and sorting using optimal transport." arXiv preprint arXiv:1905.11885 (2019). ↩︎
-
Amos, Brandon, and J. Zico Kolter. "Optnet: Differentiable optimization as a layer in neural networks." International Conference on Machine Learning. PMLR, 2017. ↩︎
-
Clarke, Bryce, et al. "Profunctor optics, a categorical update." arXiv preprint arXiv:2001.07488 (2020). ↩︎
-
Hedges, Jules. "From open learners to open games." arXiv preprint arXiv:1902.08666 (2019). ↩︎
Discussion