Juliaの自動微分を簡単に比べてみる
Julia Advent Calendar 2024 シリーズ2 の22日目です🎄
Juliaの自動微分を簡単に比べてみる
この記事ではJulia言語の自動微分パッケージ
を簡単に比べます.
パッケージ
同じ名前の関数が干渉するのでusing
ではなくimport
で読み込みます.
# import Pkg
# Pkg.add("ForwardDiff")
# Pkg.add("Zygote")
# Pkg.add("Enzyme")
import ForwardDiff
import Zygote
import Enzyme
例題
次の1変数関数
f(x) = x^2
g(x) = x[1]^2 + 3*x[2]^2
ForwardDiff.jl
ForwardDiff.jlは読み込みが速いので個人的に気に入っています.
julia> ForwardDiff.derivative(f, 5)
10
julia> ForwardDiff.gradient(g, [5,5])
2-element Vector{Int64}:
10
30
Zygote.jl
Zygote.jlはJulia in Physics 2021 OnlineにてSatoshi Terasakiさんによる講演で紹介されました. 単純な関数だとForwardDiff.jlの方が速いように思われますが, Wkipediaの解説の通り, 機械学習では効率的になるのかもしれません.
julia> Zygote.gradient(f, 5)[1]
10.0
julia> Zygote.gradient(g, [5,5])[1]
2-element Vector{Float64}:
10.0
30.0
Enzyme.jl
Enzyme.jlは筑波大学 計算科学研究センター 第146回計算科学コロキウムとJulia in Physics 2024での講演のために来日して頂いたValentin Churavyさんが開発に携わっているパッケージです. 日本語の詳しい説明はJulia Advent Calendar 2024の永井さんの記事を参照して下さい. 他の言語でも使用できるとのことです.
julia> Enzyme.autodiff(Enzyme.Reverse, f, Enzyme.Active(5.0))[1][1]
10.0
julia> Enzyme.gradient(Enzyme.Reverse, f, 5.0)
10.0
julia> Enzyme.gradient(Enzyme.Reverse, g, [5.0, 5.0])
2-element Vector{Float64}:
10.0
30.0
速度
面白い結果が出ました. 1変数関数の場合は ForwardDiff.jl と Zygote.jl の方が速く, 2変数関数の場合は Enzyme.jl の方が速いようです.
using BenchmarkTools
@btime ForwardDiff.derivative(f, 5.0)
@btime Zygote.gradient(f, 5.0)[1]
@btime Enzyme.autodiff(Enzyme.Reverse, f, Enzyme.Active(5.0))[1][1]
@btime Enzyme.gradient(Enzyme.Reverse, f, 5.0)
# 出力
1.200 ns (0 allocations: 0 bytes)
1.200 ns (0 allocations: 0 bytes)
5.400 ns (0 allocations: 0 bytes)
9.419 ns (1 allocation: 16 bytes)
@btime ForwardDiff.gradient(g, [5.0, 5.0])
@btime Zygote.gradient(g, [5.0, 5.0])[1]
@btime Enzyme.gradient(Enzyme.Reverse, g, [5.0, 5.0])
# 出力
580.337 ns (5 allocations: 368 bytes)
501.117 ns (19 allocations: 1.14 KiB)
42.281 ns (2 allocations: 160 bytes)
一応, 1変数関数で項を増やしたものと, 変数を増やして4変数にしたもので検証しておきます.
@btime ForwardDiff.derivative(x -> x^2 + x^3 + x^4 + sin(x), 5)
@btime Zygote.gradient(x -> x^2 + x^3 + x^4 + sin(x), 5)[1]
@btime Enzyme.autodiff(Enzyme.Reverse, x -> x^2 + x^3 + x^4 + sin(x), Enzyme.Active(5.0))[1][1]
@btime Enzyme.gradient(Enzyme.Reverse, x -> x^2 + x^3 + x^4 + sin(x), 5.0)
# 出力
1.200 ns (0 allocations: 0 bytes)
1.200 ns (0 allocations: 0 bytes)
15.230 ns (0 allocations: 0 bytes)
17.936 ns (1 allocation: 16 bytes)
@btime ForwardDiff.gradient(x -> x[1]^2 + x[2]^3 + x[3]^4 + sin(x[4]), [5.0, 5.0, 5.0, 5.0])
@btime Zygote.gradient(x -> x[1]^2 + x[2]^3 + x[3]^4 + sin(x[4]), [5.0, 5.0, 5.0, 5.0])[1]
@btime Enzyme.gradient(Enzyme.Reverse, x -> x[1]^2 + x[2]^3 + x[3]^4 + sin(x[4]), [5.0, 5.0, 5.0, 5.0])
# 出力
607.955 ns (5 allocations: 704 bytes)
1.400 μs (36 allocations: 2.31 KiB)
54.010 ns (2 allocations: 192 bytes)
まとめ
1変数関数の場合は
多変数関数の場合は
がおすすめです. ただし, Enzyme.jlはLLVMベースのため他の言語でも使用できるなど, 速度以外のアドバンテージがあることに注意してください.
バージョン情報
Julia v1.10.5
ForwardDiff v0.10.36
Zygote v0.6.73
Enzyme v0.12.36
Discussion