Juliaの自動微分を簡単に比べてみる
Julia Advent Calendar 2024 シリーズ2 の22日目です🎄
はじめに
このノートではJulia言語の自動微分パッケージ
を簡単に比べます.
パッケージ
同じ名前の関数が干渉するのでusing
ではなくimport
で読み込みます.
# import Pkg
# Pkg.add("ForwardDiff")
# Pkg.add("ReverseDiff")
# Pkg.add("Zygote")
# Pkg.add("Enzyme")
import ForwardDiff
import ReverseDifff
import Zygote
import Enzyme
例題
次の1変数関数
f(x) = x[1]^2
g(x) = x[1]^2 + 3*x[2]^2
ForwardDiff.jl
ForwardDiff.jlは読み込みが速いので個人的に気に入っています.
julia> ForwardDiff.derivative(f, 5)
10
julia> ForwardDiff.gradient(g, [5,5])
2-element Vector{Int64}:
10
30
ReverseDiff.jl
ReverseDiff.jlは使ったことがありませんが, Reverseモードの比較のために加えました. derivative
は存在しないようですので, 1変数でもgradient
を使います.
julia> ReverseDiff.gradient(f, [5])
1-element Vector{Int64}:
10
julia> ReverseDiff.gradient(g, [5,5])
2-element Vector{Int64}:
10
30
Zygote.jl
Zygote.jlはJulia in Physics 2021 OnlineにてSatoshi Terasakiさんによる講演で紹介されました. 単純な関数だとForwardDiff.jlの方が速いように思われますが, Wkipediaの解説の通り, 機械学習では効率的になるのかもしれません.
julia> Zygote.gradient(f, [5])[1]
1-element Vector{Float64}:
10.0
julia> Zygote.gradient(g, [5,5])[1]
2-element Vector{Float64}:
10.0
30.0
Enzyme.jl
Enzyme.jlは筑波大学 計算科学研究センター 第146回計算科学コロキウムとJulia in Physics 2024での講演のために来日して頂いたValentin Churavyさんが開発に携わっているパッケージです. 日本語の詳しい説明はJulia Advent Calendar 2024の永井さんの記事を参照して下さい. 他の言語でも使用できるとのことです. まだ理解できていませんが, ForwardモードはEnzyme.Duplicated(5.0, 1.0)
のように
julia> Enzyme.autodiff(Enzyme.Forward, f, Enzyme.Duplicated(5.0, 1.0))[1][1]
10.0
julia> Enzyme.autodiff(Enzyme.Reverse, f, Enzyme.Active(5.0))[1][1]
10.0
julia> Enzyme.gradient(Enzyme.Forward, f, 5.0)
10.0
julia> Enzyme.gradient(Enzyme.Reverse, f, 5.0)
10.0
julia> Enzyme.gradient(Enzyme.Forward, g, [5.0, 5.0])
(10.0, 30.0)
julia> Enzyme.gradient(Enzyme.Reverse, g, [5.0, 5.0])
2-element Vector{Float64}:
10.0
30.0
速度
先に注意点を述べておきます. Zygote.jlは引数をスカラーとして渡す場合とベクトルとして渡す場合で速度が大きく変わりますのでご注意ください.
using BenchmarkTools
@btime Zygote.gradient(x -> x^2, 5.0)[1]
@btime Zygote.gradient(x -> x[1]^2, [5.0])[1]
# 出力
1.200 ns (0 allocations: 0 bytes)
17.854 ns (1 allocation: 64 bytes)
1変数関数の場合は ForwardDiff.jl と Zygote.jl が圧倒的に速くなります.
@btime ForwardDiff.derivative(x -> x^2, 5.0)
@btime ForwardDiff.derivative(x -> x[1]^2, 5.0)
@btime ForwardDiff.gradient(x -> x[1]^2, [5.0])
@btime ReverseDiff.gradient(x -> x[1]^2, [5.0])
@btime Zygote.gradient(x -> x^2, 5.0)[1]
@btime Zygote.gradient(x -> x[1]^2, [5.0])[1]
@btime Enzyme.autodiff(Enzyme.Forward, x -> x[1]^2, Enzyme.Duplicated(5.0, 1.0))[1][1]
@btime Enzyme.autodiff(Enzyme.Reverse, x -> x[1]^2, Enzyme.Active(5.0))[1][1]
@btime Enzyme.gradient(Enzyme.Forward, x -> x[1]^2, 5.0)
@btime Enzyme.gradient(Enzyme.Reverse, x -> x[1]^2, 5.0)
# 出力
1.200 ns (0 allocations: 0 bytes)
1.200 ns (0 allocations: 0 bytes)
613.143 ns (5 allocations: 272 bytes)
469.388 ns (14 allocations: 688 bytes)
1.200 ns (0 allocations: 0 bytes)
17.990 ns (1 allocation: 64 bytes)
5.405 ns (0 allocations: 0 bytes)
5.400 ns (0 allocations: 0 bytes)
5.400 ns (0 allocations: 0 bytes)
7.808 ns (1 allocation: 16 bytes)
1変数関数のまま4項まで増やしたもので検証しておきますが, やはり最速は ForwardDiff.jl のderivative
かZygote.jl のスカラー引数のケースです.
@btime ForwardDiff.derivative(x -> x^2 + x^3 + x^4 + sin(x), 5.0)
@btime ForwardDiff.derivative(x -> x[1]^2 + x[1]^3 + x[1]^4 + sin(x[1]), 5.0)
@btime ForwardDiff.gradient(x -> x[1]^2 + x[1]^3 + x[1]^4 + sin(x[1]), [5.0])
@btime ReverseDiff.gradient(x -> x[1]^2 + x[1]^3 + x[1]^4 + sin(x[1]), [5.0])
@btime Zygote.gradient(x -> x^2 + x^3 + x^4 + sin(x), 5.0)[1]
@btime Zygote.gradient(x -> x[1]^2 + x[1]^3 + x[1]^4 + sin(x[1]), [5.0])[1]
@btime Enzyme.autodiff(Enzyme.Forward, x -> x[1]^2 + x[1]^3 + x[1]^4 + sin(x[1]), Enzyme.Duplicated(5.0, 1.0))[1][1]
@btime Enzyme.autodiff(Enzyme.Reverse, x -> x[1]^2 + x[1]^3 + x[1]^4 + sin(x[1]), Enzyme.Active(5.0))[1][1]
@btime Enzyme.gradient(Enzyme.Forward, x -> x[1]^2 + x[1]^3 + x[1]^4 + sin(x[1]), 5.0)
@btime Enzyme.gradient(Enzyme.Reverse, x -> x[1]^2 + x[1]^3 + x[1]^4 + sin(x[1]), 5.0)
# 出力
1.200 ns (0 allocations: 0 bytes)
1.200 ns (0 allocations: 0 bytes)
656.522 ns (5 allocations: 272 bytes)
1.050 μs (40 allocations: 1.69 KiB)
1.200 ns (0 allocations: 0 bytes)
1.400 μs (36 allocations: 2.25 KiB)
15.747 ns (0 allocations: 0 bytes)
15.230 ns (0 allocations: 0 bytes)
15.731 ns (0 allocations: 0 bytes)
18.236 ns (1 allocation: 16 bytes)
ニューラルネットワークを意識して合成関数でも検証しておきます. 結論は同じでした.
@btime ForwardDiff.derivative(x -> sin(0.1*sin(0.2*sin( 0.3*sin(2*x+1) + 0.3*sin(3*x+2) + 3) + 4) + 5), 5.0)
@btime ForwardDiff.derivative(x -> sin(0.1*sin(0.2*sin( 0.3*sin(2*x[1]+1) + 0.3*sin(3*x[1]+2) + 3) + 4) + 5), 5.0)
@btime ForwardDiff.gradient(x -> sin(0.1*sin(0.2*sin( 0.3*sin(2*x[1]+1) + 0.3*sin(3*x[1]+2) + 3) + 4) + 5), [5.0])
@btime ReverseDiff.gradient(x -> sin(0.1*sin(0.2*sin( 0.3*sin(2*x[1]+1) + 0.3*sin(3*x[1]+2) + 3) + 4) + 5), [5.0])
@btime Zygote.gradient(x -> sin(0.1*sin(0.2*sin( 0.3*sin(2*x+1) + 0.3*sin(3*x+2) + 3) + 4) + 5), 5.0)[1]
@btime Zygote.gradient(x -> sin(0.1*sin(0.2*sin( 0.3*sin(2*x[1]+1) + 0.3*sin(3*x[1]+2) + 3) + 4) + 5), [5.0])[1]
@btime Enzyme.autodiff(Enzyme.Forward, x -> sin(0.1*sin(0.2*sin( 0.3*sin(2*x[1]+1) + 0.3*sin(3*x[1]+2) + 3) + 4) + 5), Enzyme.Duplicated(5.0, 1.0))[1][1]
@btime Enzyme.autodiff(Enzyme.Reverse, x -> sin(0.1*sin(0.2*sin( 0.3*sin(2*x[1]+1) + 0.3*sin(3*x[1]+2) + 3) + 4) + 5), Enzyme.Active(5.0))[1][1]
@btime Enzyme.gradient(Enzyme.Forward, x -> sin(0.1*sin(0.2*sin( 0.3*sin(2*x[1]+1) + 0.3*sin(3*x[1]+2) + 3) + 4) + 5), 5.0)
@btime Enzyme.gradient(Enzyme.Reverse, x -> sin(0.1*sin(0.2*sin( 0.3*sin(2*x[1]+1) + 0.3*sin(3*x[1]+2) + 3) + 4) + 5), 5.0)
# 出力
1.200 ns (0 allocations: 0 bytes)
1.200 ns (0 allocations: 0 bytes)
627.168 ns (5 allocations: 272 bytes)
2.378 μs (79 allocations: 3.30 KiB)
1.200 ns (0 allocations: 0 bytes)
633.103 ns (19 allocations: 2.66 KiB)
61.162 ns (0 allocations: 0 bytes)
71.971 ns (0 allocations: 0 bytes)
61.100 ns (0 allocations: 0 bytes)
74.253 ns (1 allocation: 16 bytes)
2変数関数の場合は Enzyme.jl のReverseモードが圧倒的に速いようです. 同じReverseモードでもReverseDiff.gradient
とZygote.gradient
とは大きく異なります.
@btime ForwardDiff.gradient(g, [5.0, 5.0])
@btime ReverseDiff.gradient(g, [5.0, 5.0])
@btime Zygote.gradient(g, [5.0, 5.0])[1]
@btime Enzyme.gradient(Enzyme.Forward, g, [5.0, 5.0])
@btime Enzyme.gradient(Enzyme.Reverse, g, [5.0, 5.0])
# 出力
573.770 ns (5 allocations: 368 bytes)
801.176 ns (27 allocations: 1.25 KiB)
498.446 ns (19 allocations: 1.14 KiB)
581.215 ns (9 allocations: 432 bytes)
42.339 ns (2 allocations: 160 bytes)
変数を増やして4変数にしたもので検証しておきます. 結論は変わりません.
@btime ForwardDiff.gradient(x -> x[1]^2 + x[2]^3 + x[3]^4 + sin(x[4]), [5.0, 5.0, 5.0, 5.0])
@btime ReverseDiff.gradient(x -> x[1]^2 + x[2]^3 + x[3]^4 + sin(x[4]), [5.0, 5.0, 5.0, 5.0])
@btime Zygote.gradient(x -> x[1]^2 + x[2]^3 + x[3]^4 + sin(x[4]), [5.0, 5.0, 5.0, 5.0])[1]
@btime Enzyme.gradient(Enzyme.Forward, x -> x[1]^2 + x[2]^3 + x[3]^4 + sin(x[4]), [5.0, 5.0, 5.0, 5.0])
@btime Enzyme.gradient(Enzyme.Reverse, x -> x[1]^2 + x[2]^3 + x[3]^4 + sin(x[4]), [5.0, 5.0, 5.0, 5.0])
# 出力
661.585 ns (5 allocations: 704 bytes)
1.050 μs (40 allocations: 1.81 KiB)
1.410 μs (36 allocations: 2.31 KiB)
593.855 ns (11 allocations: 752 bytes)
54.158 ns (2 allocations: 192 bytes)
まとめ
1変数関数の場合は
-
ForwardDiff.jl の
gradient
ではなくderivative
-
Zygote.jl の
gradient
で引数をベクトルではなくスカラーとして渡す
多変数関数の場合は
-
Enzyme.jlの
gradient
をReverseモードで使う
がおすすめです. ただし, 1変数関数の場合でもEnzyme.jlはLLVMベースのため他の言語でも使用できるなど, 速度以外のアドバンテージがあることに注意してください.
バージョン情報
Julia v1.10.5
ForwardDiff v0.10.36
Zygote v0.6.73
Enzyme v0.12.36
Discussion