⚖️

Juliaの自動微分を簡単に比べてみる

2024/12/22に公開

Julia Advent Calendar 2024 シリーズ2 の22日目です🎄
https://qiita.com/advent-calendar/2024/julia

はじめに

このノートではJulia言語の自動微分パッケージ

を簡単に比べます.

パッケージ

同じ名前の関数が干渉するのでusingではなくimportで読み込みます.

パッケージの読み込み
# import Pkg
# Pkg.add("ForwardDiff")
# Pkg.add("ReverseDiff")
# Pkg.add("Zygote")
# Pkg.add("Enzyme")
import ForwardDiff
import ReverseDifff
import Zygote
import Enzyme

例題

次の1変数関数 f(x) の微分係数 \frac{\mathrm{d}f}{\mathrm{d}x}(5)=10 と2変数関数 g(x_1, x_2) の勾配 \nabla g(5,5) = (10,30) を求めます.

f(x) = x^2
g(x_1, x_2) = {x_1}^2 + 3{x_2}^2
1変数関数の例
f(x) = x[1]^2
2変数関数の例
g(x) = x[1]^2 + 3*x[2]^2

ForwardDiff.jl

ForwardDiff.jlは読み込みが速いので個人的に気に入っています.

1変数関数の例
julia> ForwardDiff.derivative(f, 5)
10
2変数関数の例
julia> ForwardDiff.gradient(g, [5,5])
2-element Vector{Int64}:
 10
 30

ReverseDiff.jl

ReverseDiff.jlは使ったことがありませんが, Reverseモードの比較のために加えました. derivativeは存在しないようですので, 1変数でもgradientを使います.

1変数関数の例
julia> ReverseDiff.gradient(f, [5])
1-element Vector{Int64}:
 10
2変数関数の例
julia> ReverseDiff.gradient(g, [5,5])
2-element Vector{Int64}:
 10
 30

Zygote.jl

Zygote.jlJulia in Physics 2021 OnlineにてSatoshi Terasakiさんによる講演で紹介されました. 単純な関数だとForwardDiff.jlの方が速いように思われますが, Wkipediaの解説の通り, 機械学習では効率的になるのかもしれません.

1変数関数の例
julia> Zygote.gradient(f, [5])[1]
1-element Vector{Float64}:
 10.0
2変数関数の例
julia> Zygote.gradient(g, [5,5])[1]
2-element Vector{Float64}:
 10.0
 30.0

Enzyme.jl

Enzyme.jl筑波大学 計算科学研究センター 第146回計算科学コロキウムJulia in Physics 2024での講演のために来日して頂いたValentin Churavyさんが開発に携わっているパッケージです. 日本語の詳しい説明はJulia Advent Calendar 2024永井さんの記事を参照して下さい. 他の言語でも使用できるとのことです. まだ理解できていませんが, ForwardモードEnzyme.Duplicated(5.0, 1.0)のように x\dot{x} を渡す必要があります.

1変数関数の例, Forwardモード
julia> Enzyme.autodiff(Enzyme.Forward, f, Enzyme.Duplicated(5.0, 1.0))[1][1]
10.0
1変数関数の例, Reverseモード
julia> Enzyme.autodiff(Enzyme.Reverse, f, Enzyme.Active(5.0))[1][1]
10.0
1変数関数の例, gradient, Forwardモード
julia> Enzyme.gradient(Enzyme.Forward, f, 5.0)
10.0
1変数関数の例, gradient, Forwardモード
julia> Enzyme.gradient(Enzyme.Reverse, f, 5.0)
10.0
2変数関数の例, gradient, Forwardモード
julia> Enzyme.gradient(Enzyme.Forward, g, [5.0, 5.0])
(10.0, 30.0)
2変数関数の例, gradient, Reverseモード
julia> Enzyme.gradient(Enzyme.Reverse, g, [5.0, 5.0])
2-element Vector{Float64}:
 10.0
 30.0

速度

先に注意点を述べておきます. Zygote.jlは引数をスカラーとして渡す場合とベクトルとして渡す場合で速度が大きく変わりますのでご注意ください.

Zygote.jlの注意点
using BenchmarkTools
@btime Zygote.gradient(x -> x^2, 5.0)[1]
@btime Zygote.gradient(x -> x[1]^2, [5.0])[1]

# 出力
  1.200 ns (0 allocations: 0 bytes)
  17.854 ns (1 allocation: 64 bytes)

1変数関数の場合は ForwardDiff.jlZygote.jl が圧倒的に速くなります.

1変数関数の例
@btime ForwardDiff.derivative(x -> x^2, 5.0)
@btime ForwardDiff.derivative(x -> x[1]^2, 5.0)
@btime ForwardDiff.gradient(x -> x[1]^2, [5.0])
@btime ReverseDiff.gradient(x -> x[1]^2, [5.0])
@btime Zygote.gradient(x -> x^2, 5.0)[1]
@btime Zygote.gradient(x -> x[1]^2, [5.0])[1]
@btime Enzyme.autodiff(Enzyme.Forward, x -> x[1]^2, Enzyme.Duplicated(5.0, 1.0))[1][1]
@btime Enzyme.autodiff(Enzyme.Reverse, x -> x[1]^2, Enzyme.Active(5.0))[1][1]
@btime Enzyme.gradient(Enzyme.Forward, x -> x[1]^2, 5.0)
@btime Enzyme.gradient(Enzyme.Reverse, x -> x[1]^2, 5.0)

# 出力
  1.200 ns (0 allocations: 0 bytes)
  1.200 ns (0 allocations: 0 bytes)
  613.143 ns (5 allocations: 272 bytes)
  469.388 ns (14 allocations: 688 bytes)
  1.200 ns (0 allocations: 0 bytes)
  17.990 ns (1 allocation: 64 bytes)
  5.405 ns (0 allocations: 0 bytes)
  5.400 ns (0 allocations: 0 bytes)
  5.400 ns (0 allocations: 0 bytes)
  7.808 ns (1 allocation: 16 bytes)

1変数関数のまま4項まで増やしたもので検証しておきますが, やはり最速は ForwardDiff.jlderivativeZygote.jl のスカラー引数のケースです.

1変数関数の例, 4項まで増やした場合
@btime ForwardDiff.derivative(x -> x^2 + x^3 + x^4 + sin(x), 5.0)
@btime ForwardDiff.derivative(x -> x[1]^2 + x[1]^3 + x[1]^4 + sin(x[1]), 5.0)
@btime ForwardDiff.gradient(x -> x[1]^2 + x[1]^3 + x[1]^4 + sin(x[1]), [5.0])
@btime ReverseDiff.gradient(x -> x[1]^2 + x[1]^3 + x[1]^4 + sin(x[1]), [5.0])
@btime Zygote.gradient(x -> x^2 + x^3 + x^4 + sin(x), 5.0)[1]
@btime Zygote.gradient(x -> x[1]^2 + x[1]^3 + x[1]^4 + sin(x[1]), [5.0])[1]
@btime Enzyme.autodiff(Enzyme.Forward, x -> x[1]^2 + x[1]^3 + x[1]^4 + sin(x[1]), Enzyme.Duplicated(5.0, 1.0))[1][1]
@btime Enzyme.autodiff(Enzyme.Reverse, x -> x[1]^2 + x[1]^3 + x[1]^4 + sin(x[1]), Enzyme.Active(5.0))[1][1]
@btime Enzyme.gradient(Enzyme.Forward, x -> x[1]^2 + x[1]^3 + x[1]^4 + sin(x[1]), 5.0)
@btime Enzyme.gradient(Enzyme.Reverse, x -> x[1]^2 + x[1]^3 + x[1]^4 + sin(x[1]), 5.0)

# 出力
  1.200 ns (0 allocations: 0 bytes)
  1.200 ns (0 allocations: 0 bytes)
  656.522 ns (5 allocations: 272 bytes)
  1.050 μs (40 allocations: 1.69 KiB)
  1.200 ns (0 allocations: 0 bytes)
  1.400 μs (36 allocations: 2.25 KiB)
  15.747 ns (0 allocations: 0 bytes)
  15.230 ns (0 allocations: 0 bytes)
  15.731 ns (0 allocations: 0 bytes)
  18.236 ns (1 allocation: 16 bytes)

ニューラルネットワークを意識して合成関数でも検証しておきます. 結論は同じでした.

1変数関数の例, 4重の合成関数の場合
@btime ForwardDiff.derivative(x -> sin(0.1*sin(0.2*sin( 0.3*sin(2*x+1) + 0.3*sin(3*x+2) + 3) + 4) + 5), 5.0)
@btime ForwardDiff.derivative(x -> sin(0.1*sin(0.2*sin( 0.3*sin(2*x[1]+1) + 0.3*sin(3*x[1]+2) + 3) + 4) + 5), 5.0)
@btime ForwardDiff.gradient(x -> sin(0.1*sin(0.2*sin( 0.3*sin(2*x[1]+1) + 0.3*sin(3*x[1]+2) + 3) + 4) + 5), [5.0])
@btime ReverseDiff.gradient(x -> sin(0.1*sin(0.2*sin( 0.3*sin(2*x[1]+1) + 0.3*sin(3*x[1]+2) + 3) + 4) + 5), [5.0])
@btime Zygote.gradient(x -> sin(0.1*sin(0.2*sin( 0.3*sin(2*x+1) + 0.3*sin(3*x+2) + 3) + 4) + 5), 5.0)[1]
@btime Zygote.gradient(x -> sin(0.1*sin(0.2*sin( 0.3*sin(2*x[1]+1) + 0.3*sin(3*x[1]+2) + 3) + 4) + 5), [5.0])[1]
@btime Enzyme.autodiff(Enzyme.Forward, x -> sin(0.1*sin(0.2*sin( 0.3*sin(2*x[1]+1) + 0.3*sin(3*x[1]+2) + 3) + 4) + 5), Enzyme.Duplicated(5.0, 1.0))[1][1]
@btime Enzyme.autodiff(Enzyme.Reverse, x -> sin(0.1*sin(0.2*sin( 0.3*sin(2*x[1]+1) + 0.3*sin(3*x[1]+2) + 3) + 4) + 5), Enzyme.Active(5.0))[1][1]
@btime Enzyme.gradient(Enzyme.Forward, x -> sin(0.1*sin(0.2*sin( 0.3*sin(2*x[1]+1) + 0.3*sin(3*x[1]+2) + 3) + 4) + 5), 5.0)
@btime Enzyme.gradient(Enzyme.Reverse, x -> sin(0.1*sin(0.2*sin( 0.3*sin(2*x[1]+1) + 0.3*sin(3*x[1]+2) + 3) + 4) + 5), 5.0)

# 出力
  1.200 ns (0 allocations: 0 bytes)
  1.200 ns (0 allocations: 0 bytes)
  627.168 ns (5 allocations: 272 bytes)
  2.378 μs (79 allocations: 3.30 KiB)
  1.200 ns (0 allocations: 0 bytes)
  633.103 ns (19 allocations: 2.66 KiB)
  61.162 ns (0 allocations: 0 bytes)
  71.971 ns (0 allocations: 0 bytes)
  61.100 ns (0 allocations: 0 bytes)
  74.253 ns (1 allocation: 16 bytes)

2変数関数の場合は Enzyme.jl のReverseモードが圧倒的に速いようです. 同じReverseモードでもReverseDiff.gradientZygote.gradientとは大きく異なります.

2変数関数の例
@btime ForwardDiff.gradient(g, [5.0, 5.0])
@btime ReverseDiff.gradient(g, [5.0, 5.0])
@btime Zygote.gradient(g, [5.0, 5.0])[1]
@btime Enzyme.gradient(Enzyme.Forward, g, [5.0, 5.0])
@btime Enzyme.gradient(Enzyme.Reverse, g, [5.0, 5.0])

# 出力
  573.770 ns (5 allocations: 368 bytes)
  801.176 ns (27 allocations: 1.25 KiB)
  498.446 ns (19 allocations: 1.14 KiB)
  581.215 ns (9 allocations: 432 bytes)
  42.339 ns (2 allocations: 160 bytes)

変数を増やして4変数にしたもので検証しておきます. 結論は変わりません.

4変数関数の例
@btime ForwardDiff.gradient(x -> x[1]^2 + x[2]^3 + x[3]^4 + sin(x[4]), [5.0, 5.0, 5.0, 5.0])
@btime ReverseDiff.gradient(x -> x[1]^2 + x[2]^3 + x[3]^4 + sin(x[4]), [5.0, 5.0, 5.0, 5.0])
@btime Zygote.gradient(x -> x[1]^2 + x[2]^3 + x[3]^4 + sin(x[4]), [5.0, 5.0, 5.0, 5.0])[1]
@btime Enzyme.gradient(Enzyme.Forward, x -> x[1]^2 + x[2]^3 + x[3]^4 + sin(x[4]), [5.0, 5.0, 5.0, 5.0])
@btime Enzyme.gradient(Enzyme.Reverse, x -> x[1]^2 + x[2]^3 + x[3]^4 + sin(x[4]), [5.0, 5.0, 5.0, 5.0])

# 出力
  661.585 ns (5 allocations: 704 bytes)
  1.050 μs (40 allocations: 1.81 KiB)
  1.410 μs (36 allocations: 2.31 KiB)
  593.855 ns (11 allocations: 752 bytes)
  54.158 ns (2 allocations: 192 bytes)

まとめ

1変数関数の場合は

  • ForwardDiff.jlgradientではなくderivative
  • Zygote.jlgradientで引数をベクトルではなくスカラーとして渡す

多変数関数の場合は

  • Enzyme.jlgradientをReverseモードで使う

がおすすめです. ただし, 1変数関数の場合でもEnzyme.jlはLLVMベースのため他の言語でも使用できるなど, 速度以外のアドバンテージがあることに注意してください.

バージョン情報

Julia v1.10.5
ForwardDiff v0.10.36
Zygote v0.6.73
Enzyme v0.12.36

https://gist.github.com/ohno/24126924230b769336038f3ebe1905a8

Discussion