🎄

Juliaの自動微分を簡単に比べてみる

2024/12/22に公開

Julia Advent Calendar 2024 シリーズ2 の22日目です🎄
https://qiita.com/advent-calendar/2024/julia

Juliaの自動微分を簡単に比べてみる

この記事ではJulia言語の自動微分パッケージ

を簡単に比べます.

パッケージ

同じ名前の関数が干渉するのでusingではなくimportで読み込みます.

パッケージの読み込み
# import Pkg
# Pkg.add("ForwardDiff")
# Pkg.add("Zygote")
# Pkg.add("Enzyme")
import ForwardDiff
import Zygote
import Enzyme

例題

次の1変数関数 f(x) の微分係数 \frac{\mathrm{d}f}{\mathrm{d}x}(5)=10 と2変数関数 g(x_1, x_2) の勾配 \nabla g(5,5) = (10,30) を求めます.

f(x) = x^2
g(x_1, x_2) = {x_1}^2 + 3{x_2}^2
1変数関数の例
f(x) = x^2
2変数関数の例
g(x) = x[1]^2 + 3*x[2]^2

ForwardDiff.jl

ForwardDiff.jlは読み込みが速いので個人的に気に入っています.

1変数関数の例
julia> ForwardDiff.derivative(f, 5)
10
2変数関数の例
julia> ForwardDiff.gradient(g, [5,5])
2-element Vector{Int64}:
 10
 30

Zygote.jl

Zygote.jlJulia in Physics 2021 OnlineにてSatoshi Terasakiさんによる講演で紹介されました. 単純な関数だとForwardDiff.jlの方が速いように思われますが, Wkipediaの解説の通り, 機械学習では効率的になるのかもしれません.

1変数関数の例
julia> Zygote.gradient(f, 5)[1]
10.0
2変数関数の例
julia> Zygote.gradient(g, [5,5])[1]
2-element Vector{Float64}:
 10.0
 30.0

Enzyme.jl

Enzyme.jl筑波大学 計算科学研究センター 第146回計算科学コロキウムJulia in Physics 2024での講演のために来日して頂いたValentin Churavyさんが開発に携わっているパッケージです. 日本語の詳しい説明はJulia Advent Calendar 2024永井さんの記事を参照して下さい. 他の言語でも使用できるとのことです.

1変数関数の例
julia> Enzyme.autodiff(Enzyme.Reverse, f, Enzyme.Active(5.0))[1][1]
10.0
1変数関数の例
julia> Enzyme.gradient(Enzyme.Reverse, f, 5.0)
10.0
2変数関数の例
julia> Enzyme.gradient(Enzyme.Reverse, g, [5.0, 5.0])
2-element Vector{Float64}:
 10.0
 30.0

速度

面白い結果が出ました. 1変数関数の場合は ForwardDiff.jlZygote.jl の方が速く, 2変数関数の場合は Enzyme.jl の方が速いようです.

1変数関数の例
using BenchmarkTools
@btime ForwardDiff.derivative(f, 5.0)
@btime Zygote.gradient(f, 5.0)[1]
@btime Enzyme.autodiff(Enzyme.Reverse, f, Enzyme.Active(5.0))[1][1]
@btime Enzyme.gradient(Enzyme.Reverse, f, 5.0)

# 出力
  1.200 ns (0 allocations: 0 bytes)
  1.200 ns (0 allocations: 0 bytes)
  5.400 ns (0 allocations: 0 bytes)
  9.419 ns (1 allocation: 16 bytes)
2変数関数の例
@btime ForwardDiff.gradient(g, [5.0, 5.0])
@btime Zygote.gradient(g, [5.0, 5.0])[1]
@btime Enzyme.gradient(Enzyme.Reverse, g, [5.0, 5.0])

# 出力
  580.337 ns (5 allocations: 368 bytes)
  501.117 ns (19 allocations: 1.14 KiB)
  42.281 ns (2 allocations: 160 bytes)

一応, 1変数関数で項を増やしたものと, 変数を増やして4変数にしたもので検証しておきます.

1変数関数の例
@btime ForwardDiff.derivative(x -> x^2 + x^3 + x^4 + sin(x), 5)
@btime Zygote.gradient(x -> x^2 + x^3 + x^4 + sin(x), 5)[1]
@btime Enzyme.autodiff(Enzyme.Reverse, x -> x^2 + x^3 + x^4 + sin(x), Enzyme.Active(5.0))[1][1]
@btime Enzyme.gradient(Enzyme.Reverse, x -> x^2 + x^3 + x^4 + sin(x), 5.0)

# 出力
  1.200 ns (0 allocations: 0 bytes)
  1.200 ns (0 allocations: 0 bytes)
  15.230 ns (0 allocations: 0 bytes)
  17.936 ns (1 allocation: 16 bytes)
4変数関数の例
@btime ForwardDiff.gradient(x -> x[1]^2 + x[2]^3 + x[3]^4 + sin(x[4]), [5.0, 5.0, 5.0, 5.0])
@btime Zygote.gradient(x -> x[1]^2 + x[2]^3 + x[3]^4 + sin(x[4]), [5.0, 5.0, 5.0, 5.0])[1]
@btime Enzyme.gradient(Enzyme.Reverse, x -> x[1]^2 + x[2]^3 + x[3]^4 + sin(x[4]), [5.0, 5.0, 5.0, 5.0])

# 出力
  607.955 ns (5 allocations: 704 bytes)
  1.400 μs (36 allocations: 2.31 KiB)
  54.010 ns (2 allocations: 192 bytes)

まとめ

1変数関数の場合は

多変数関数の場合は

がおすすめです. ただし, Enzyme.jlはLLVMベースのため他の言語でも使用できるなど, 速度以外のアドバンテージがあることに注意してください.

バージョン情報

Julia v1.10.5
ForwardDiff v0.10.36
Zygote v0.6.73
Enzyme v0.12.36

https://gist.github.com/ohno/24126924230b769336038f3ebe1905a8

Discussion