📚

Goで自動微分を実装する

に公開

はじめに

機械学習の仕組みを理解するため、Goで自動微分ライブラリを実装しました。この記事では、基本的な使い方、アプリケーション例、実装について紹介します。

実装にあたっては、「ゼロから作るDeep Learning❸」を読みながら進めました。試し読みもできるため、自動微分そのものや計算グラフの理論についてはこちらを参考いただければと思います。

https://www.oreilly.co.jp/books/9784873119069/

https://zenn.dev/koki0702/books/a2f1689cb67723433e22

基本的な使い方

変数と関数

例えば、sin(1.0)を計算する場合、以下のように書きます。

package main

import (
	"fmt"
	"math"

	F "github.com/itsubaki/autograd/function"
	"github.com/itsubaki/autograd/variable"
)

func main() {
	x := variable.New(1.0)
	y := F.Sin(x)

	fmt.Println(y)             // variable(0.8414709848078965)
	fmt.Println(math.Sin(1.0)) // 0.8414709848078965
}

Go Playground

Variableに実装されているBackward()を呼ぶことで、勾配が得られます。

func main() {
	x := variable.New(1.0)
	y := F.Sin(x)
	y.Backward()

	fmt.Println(y)             // variable(0.8414709848078965)
	fmt.Println(x.Grad)        // variable(0.5403023058681398)
	fmt.Println(math.Cos(1.0)) // 0.5403023058681398
}

Go Playground

計算結果はx.Gradに格納されます。このライブラリで提供されている関数(F.Sinなど)は、計算実行時に自分自身と入力xおよび出力yをリンクし、計算グラフを作成しています。y.Backward()では、出力yからそのグラフをたどって勾配を計算し、入力xx.Gradにその値を格納します。

合成関数

複数の関数を組み合わせた関数もまた、勾配を計算することができます。

func main() {
	x := variable.New(0.5)
	y := F.Square(F.Exp(F.Square(x)))
	y.Backward()

	fmt.Println(x.Grad) // variable(3.297442541400256)
}

Go Playground

これにより、機械学習で必要となる関数やレイヤーの勾配計算のコードを、大幅に削減することができます。例えば、LSTMレイヤーのBackwardを素直に実装すると以下のようになりますが、このようなコードは不要になります。

https://github.com/itsubaki/neu/blob/25ab4ef33ed56017a1b75d4ee0bb0866d82e002e/layer/lstm.go#L49-L77

Pythonでは、演算子のオーバーロードによって計算グラフを作成する演算も通常のスカラのように扱うことできます。Goではそれに相当する機能がないため、F.AddF.Mulのような明示的な関数呼び出しとして記述する必要があります。そのため、Pythonと比べると記述コストは高くなります。

func main() {
	matyas := func(x, y *variable.Variable) *variable.Variable {
		// 0.26 * (x ** 2 + y ** 2) - 0.48 * x * y
		z0 := F.MulC(0.26, F.Add(F.Pow(2.0)(x), F.Pow(2.0)(y)))
		z1 := F.MulC(0.48, F.Mul(x, y))
		return F.Sub(z0, z1)
	}

	x := variable.New(1.0)
	y := variable.New(1.0)
	z := matyas(x, y)
	z.Backward()

	fmt.Println(x.Grad) // variable(0.040000000000000036)
	fmt.Println(y.Grad) // variable(0.040000000000000036)
}

Go Playground

高階微分

x.GradVariableです。そのため、高階微分(微分の微分、さらにその微分...)が可能です。

func main() {
	x := variable.New(1.0)
	y := F.Sin(x)
	y.Backward(variable.Opts{
		CreateGraph: true,
	})

	gx := x.Grad
	x.Cleargrad()
	gx.Backward()

	fmt.Println(x.Grad)         // variable(-0.8414709848078965)
	fmt.Println(-math.Sin(1.0)) // -0.8414709848078965
}

Go Playground

cos(x)の微分は-sin(x)です。正しく計算できていることがわかります。Backward()の引数にCreateGraph: trueを指定すると、Backward()の中で行われている演算、つまり勾配計算自体に対しても計算グラフを作成します。そのグラフをたどることで、勾配の勾配を計算することが可能になります。高階微分は、現状一部の手法でのみ使用されるため、デフォルトはfalseになっています。

アプリケーション

勾配降下法

微分を使ったアプリケーションのひとつに勾配降下法があります。自動微分によって勾配を求め、少しずつ値を変化させ最小値を探索します。

func main() {
	rosenbrock := func(x0, x1 *variable.Variable) *variable.Variable {
		// 100 * (x1 - x0 ** 2) ** 2 + (x0 - 1) ** 2
		y0 := F.Pow(2.0)(F.Sub(x1, F.Pow(2.0)(x0)))
		y1 := F.Pow(2.0)(F.AddC(-1.0, x0))
		return F.Add(F.MulC(100, y0), y1)
	}

	update := func(lr float64, x ...*variable.Variable) {
		for _, v := range x {
			v.Data = tensor.F2(v.Data, v.Grad.Data, func(a, b float64) float64 {
				return a - lr*b
			})
		}
	}

	x0 := variable.New(0.0)
	x1 := variable.New(2.0)

	learningRate := 0.001
	for i := range 10001 {
		if i%1000 == 0 {
			fmt.Println(x0, x1)
		}

		y := rosenbrock(x0, x1)
		x0.Cleargrad()
		x1.Cleargrad()
		y.Backward()

		update(learningRate, x0, x1)
	}
}

Go Playground

rosenbrock関数が最小となるのは(1, 1)です。この例では、初期値(0, 2)から初めて、(0.994, 0.989)まで近づきます。tensor.F2は2つのテンソルを引数にとり、各要素ごとに指定した関数を適用します。ここでは、勾配の0.1%を元の値から減算しています。

機械学習フレームワーク

パラメータ(中身はVariableそのもの)と関数をまとめてレイヤーを作成し、それらを組み合わせてモデルを作ります。さらに、データローダやオプティマイザを追加すれば、機械学習のフレームワークを実装することができます。ここは、自動微分から離れるため詳しい解説は省略します。

以下は、LSTMでSinカーブを学習する例です。

dataset := NewCurve(N, noise, math.Sin)
dataloader := &DataLoader{
	BatchSize: batchSize,
	N:         dataset.N,
	Data:      dataset.Data,
	Label:     dataset.Label,
}

m := model.NewLSTM(hiddenSize, 1)
o := optimizer.SGD{
	LearningRate: 0.01,
}

for range epochs {
	m.ResetState()

	loss, count := variable.New(0), 0
	for x, t := range dataloader.Seq2() {
		y := m.Forward(x)
		loss = F.Add(loss, F.MeanSquaredError(y, t))

		if count++; count%bpttLength == 0 || count == dataset.N {
			m.Cleargrads()
			loss.Backward()
			loss.UnchainBackward()
			o.Update(m)
		}
	}
}

実装

「ゼロから作るDeep Learning❸」のDeZeroでは、Pythonによる自動微分ライブラリが実装されています。Goでも基本的に同様に実装できます。ここでは、主にDeZeroとの差分を紹介します。

行列演算

DeZeroではnumpyを使っていますが、今回は行列演算部分もフルスクラッチで実装しています。初期はvectormatrixパッケージを作りましたが、画像処理のConvolutionなどで必要となる3次元以上の多次元配列を扱えないため、改めてtensorパッケージとして作り直しています。

Function

DeZeroでは基底クラスを作り、継承することで様々な関数を実装しています。Goには継承はないため「インタフェースの埋め込み」で実装しました。これにより計算グラフを作成する共通処理と、個別の計算処理を分離しています。Forwardで自分自身fと入力xおよび出力yをリンクしています。

type Forwarder interface {
	Forward(x ...*Variable) []*Variable
	Backward(gy ...*Variable) []*Variable
}

type Function struct {
	Input, Output []*Variable
	Forwarder
}

func (f *Function) Forward(x ...*Variable) []*Variable {
	y := f.Forwarder.Forward(x...)

	// NOTE: create graph
	f.setCreator(y)
	f.Input, f.Output = x, y
	return y
}

func (f *Function) First(x ...*Variable) *Variable {
	return f.Forward(x...)[0]
}

例えば、F.SinFowarderインタフェースを実装したSinTFunctionに埋め込んでいます。

func Sin(x ...*Variable) *Variable {
	return (&Function{
		Forwarder: &SinT{},
	}).First(x...)
}

type SinT struct {
	x *Variable
}

func (f *SinT) Forward(x ...*Variable) []*Variable {
	f.x = x[0]

	y := tensor.Sin(x[0].Data)
	return []*Variable{
		From(y...),
	}
}

func (f *SinT) Backward(gy ...*Variable) []*Variable {
	return []*Variable{
		Mul(Cos(f.x), gy[0]), // cos(x) * gy
	}
}

Forwardでは単純なtensorとして計算を行いますが、Backwardでは計算グラフのリンクを作成するF.MulF.Cos関数を使います。これにより、勾配計算に対してもグラフが作成され、高階微分が可能になります。

Backward

自動微分の核となる部分です。「出力yから、そのyを生成した関数fをたどり、勾配を計算してx.Gradに格納する」を繰り返していきます。

func (v *Variable) Backward(opts ...Opts) {
	seen := make(map[*Function]bool)
	fs := addFunc(make([]*Function, 0), v.Creator, seen)
	for len(fs) > 0 {
		// 関数をpop。
		f := fs[len(fs)-1]
		fs = fs[:len(fs)-1]

		// 関数の出力から勾配を計算する。
		gys := grads(f.Output)		
		gxs := f.Backward(gys...)

		xs, gxs := zip(f.Input, gxs)
		for i, x := range xs {
			// 計算結果をx.Gradに追加する。
			x.Grad = add(x.Grad, gxs[i])

			// xが別の関数から作られていたならば、その関数も対象に追加する。
			if x.Creator != nil {
				fs = addFunc(fs, x.Creator, seen)
			}
		}
	}
}

NoGradとTestMode

DeZeroでは、勾配計算やテストモードのOn/OffをPythonのwith文を使って実装しています。Goではdeferを使って実装しました。学習済みのモデルを使う時など、勾配の計算が不要な場合に計算コストとメモリ消費を削減するために使用されます。

type Span struct {
	End func()
}

func Nograd() *Span {
	Config.EnableBackprop = false
	return &Span{
		End: func() {
			Config.EnableBackprop = true
		},
	}
}
func() {
	defer variable.Nograd().End()

	// 計算グラフは作成されない。
}()

ハマリどころ

ある程度作ったところでテストをしていると、計算結果が期待値とあいませんでした。計算グラフを描画してみると、本来あるべきリンクが途切れているのを発見しました。問題は、Backwardのこの部分です。

x.Grad = add(x.Grad, gxs[i])

Pythonでは、演算子のオーバーロードを実装した時点でxgrad + gxが自動的に計算グラフを作成する関数になります。Goでは単純に足し算を行うtensor.Addから、計算グラフを作成するF.Add関数に置き換える必要がありました。これに気づかずにしばらくハマってしまいました。(関数F.XYZ自体にバグがあると思い込んでいました。)

func add(xgrad, gx *Variable) *Variable {
	if xgrad == nil {
		return gx
	}

	// NOTE: create graph
	return Add(xgrad, gx)
}
func Add(x ...*Variable) *Variable {
	return (&Function{
		Forwarder: &AddT{},
	}).First(x...)
}

まとめ

Goで実装した自動微分のライブラリについて紹介しました。

  • 自動微分により合成関数のBackwardが不要になり、コードを大幅に削減できる。
  • Goには演算子のオーバーロードがないため、F.Addのような関数呼び出しが必要。Pythonと比べると記述コストが高めになる(明確さとのトレードオフだとは思います)。

リポジトリはこちらになります。

https://github.com/itsubaki/autograd

また、このライブラリを使用しGPTをフルスクラッチで実装したリポジトリもあわせて紹介します。

https://github.com/zakirullin/gpt-go

Discussion