⚖️

Flux.jlとLux.jlを簡単に比べてみる

2025/01/13に公開

Julia Advent Calendar 2024 シリーズ2 の17日目です🎄
https://qiita.com/advent-calendar/2024/julia

はじめに

まず, Julia言語の機械学習パッケージであるFlux.jlからLux.jlへの移行手順を解説します. 次に最小二乗法を題材として同じ計算結果を再現するコードを作成し, コードの実行時間を比較します. 永井さんの記事ではこれらのパッケージのコンセプトの違いなどが解説されていますので併せてご覧ください.

パッケージ

下記のパッケージをインストールしておきます. 以降は, 同じ名前の関数が干渉するのでusingではなくimportで読み込みます.

パッケージのインストール
import Pkg
Pkg.add("Flux")
Pkg.add("Lux")
Pkg.add("Zygote")
Pkg.add("Optimisers")
Pkg.add("CairoMakie")
Pkg.add("BenchmarkTools")

Flux.jlからLux.jlへの移行手順

Lux.jlのドキュメントに解説があります. これに従って書き換えていきます.

Flux.jlからLux.jlへの移行手順
- import Flux, Random, Zygote
+ import  Lux, Random, Zygote

- flux_model = Flux.Chain(Flux.Dense(2 => 4, Flux.relu), Flux.Dense(4 => 2, Flux.relu))
+  lux_model =  Lux.Chain( Lux.Dense(2 => 4,  Lux.relu),  Lux.Dense(4 => 2,  Lux.relu))

- rng = Random.MersenneTwister(123)
- x = Random.randn(rng, Float32, 2, 4)
- value = flux_model(x)
- # 2×4 Matrix{Float32}:
- #  0.777463  0.297754  1.2554  0.360352
- #  0.134753  0.0       0.0     0.0
+ ps, st = Lux.setup(rng, lux_model)
+ value = first(lux_model(x, ps, st))
+ # 2×4 Matrix{Float32}:
+ #  0.359655  0.0         0.0       0.0
+ #  1.81674   0.00757596  0.306529  0.0

- Zygote.gradient(flux_model -> sum(flux_model(x)), flux_model)
+ Zygote.gradient(ps -> sum(first(lux_model(x, ps, st))), ps)

標準の機能でFlux.jlと完全に同じ初期値にする方法はわかりませんでしたが, 次のようにパラメータを名前付きタプルで定義すれば同じ初期値を使うことができます.

パラメータを揃える方法
ps, st = Lux.setup(rng, lux_model)
ps = (
    layer_1 = (weight = copy(flux_model[1].weight), bias = copy(flux_model[1].bias)),
    layer_2 = (weight = copy(flux_model[2].weight), bias = copy(flux_model[2].bias)),
)
value = first(lux_model(x, ps, st))
# 2×4 Matrix{Float32}:
#  0.777463  0.297754  1.2554  0.360352
#  0.134753  0.0       0.0     0.0

Flux.jlによる最小二乗法の実装例

最小二乗法を用いて, 離散化された放物線 (コード内のXY) を近似してみましょう. 1000エポックごとに損失 (平均二乗誤差) を出力します. これは0に近いほど学習データに近いことを意味します. 後ほどLux.jlで同じ計算結果を再現するために初期値を保存しておきます.

Flux.jlによる最小二乗法の実装例
import Flux

# ニューラルネットワークの定義, 初期化
MT = Flux.MersenneTwister(1234)
flux_model = Flux.Chain(
    Flux.Dense(1, 9, Flux.relu, init=Flux.randn32(MT)),
    Flux.Dense(9, 9, Flux.relu, init=Flux.randn32(MT)),
    Flux.Dense(9, 1, Flux.relu, init=Flux.randn32(MT)),
)

# 初期値の保存, 表示
initial_parameters = deepcopy(Flux.state(flux_model))
@show initial_parameters

# 学習データの定義
X = Float32.(-10:1.0:10)'
Y = X .^ 2

# 損失関数の定義
flux_loss_function(model, X) = Flux.Losses.mse(model(X), Y)

# オプティマイザの定義
flux_optimizer = Flux.setup(Flux.Adam(), flux_model)

# 学習, 途中経過の表示
println("\ni   \tloss")
println("-----\t----------")
@time for i in 0:15000
    if rem(i,1000) == 0
        loss = flux_loss_function(flux_model, X)
        println("$i\t$loss")
    end
    Flux.train!(flux_loss_function, flux_model, (X,), flux_optimizer)
end

# 学習済みモデル
f(x) = sum(flux_model([x]))

# 描写
using CairoMakie
fig = Figure(size=(420,300), fontsize=11.5, backgroundcolor=:transparent)
axis = Axis(fig[1,1], xlabel=L"$x$", ylabel=L"$y$", xlabelsize=16.5, ylabelsize=16.5)
lines!(axis, -10..10, x -> f(x), label="model")
scatter!(axis, X', x -> x^2, color=:black, label="exact")
axislegend(axis, position=:rb, framevisible=false)
fig
出力
initial_parameters = (layers = ((weight = Float32[-0.81195045; 2.5622392; 0.10329163; -0.21205132; 2.5562384; 0.0104885455; -0.6997176; -0.83282274; -1.4579619;;], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], σ = ()), (weight = Float32[-0.8652357 1.7772343 0.23964548 0.25175205 -0.71538484 0.353435 -0.71898866 -0.03159868 1.5777229; 0.7690479 -2.7981696 -0.9601986 -0.8235361 0.3812632 0.25984508 0.03140855 -2.247552 -1.876675; 1.516688 0.12623774 -1.7742386 0.91808116 -2.5331562 -0.20585492 -0.970512 -0.2074294 -0.67983174; -1.7946341 0.06827301 1.1700256 1.0815432 -1.4005985 0.090843394 -0.83741915 -1.9605532 -0.24460198; -0.9856712 0.8465971 -0.7440299 0.14951882 1.0899436 0.81333387 1.3086019 0.60052055 -1.7192795; 0.47419593 0.21375325 0.54072803 -1.2877525 -1.0776865 -0.5000487 -0.46481803 -2.956821 -1.757924; -1.0986273 -0.12919612 0.23816317 -1.298389 0.2482856 -1.7409434 1.4281609 0.718066 0.47351286; 0.67747116 0.51623446 0.12625434 0.29146492 -0.10786364 1.1969564 1.8717147 -1.3031343 -0.31237987; -1.3218207 -0.48563242 0.20035145 -0.24979854 -0.3571935 0.96124464 1.8206987 0.45036435 0.6091286], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], σ = ()), (weight = Float32[1.0364083 -0.2783158 2.2698267 -2.0273461 0.20790057 -1.8240385 0.28225797 2.6128943 -1.0510497], bias = Float32[0.0], σ = ())),)

i   	loss
-----	----------
0	1035.4657
1000	63.08293
2000	20.24789
3000	6.4580956
4000	1.8206335
5000	0.8585495
6000	0.59937143
7000	0.38066334
8000	0.22740911
9000	0.16608828
10000	0.15242149
11000	0.14977677
12000	0.14702067
13000	0.14498036
14000	0.14385651
15000	0.14313066
  6.293734 seconds (16.63 M allocations: 1.071 GiB, 7.70% gc time, 80.45% compilation time)

Lux.jlによる最小二乗法の実装例

上記のFlux.jlのコードをLux.jlに移植します. かなり頑張りましたがFulx.jlと同じ初期値を再現する方法が見つからなかったので, 先ほど保存しておいた初期値を利用しています. 数値誤差の範囲でFlux.jlと同じ結果が得られました. 計算時間はLux.jlの方が短いようです.

Lux.jlによる最小二乗法の実装例
import Lux
import Random
import Optimisers

# ニューラルネットワークの定義, 初期化
lux_model = Lux.Chain(
    Lux.Dense(1 => 9, Lux.relu),
    Lux.Dense(9 => 9, Lux.relu),
    Lux.Dense(9 => 1, Lux.relu),
)
MT = Random.MersenneTwister(1234)
parameters, state = Lux.setup(MT, lux_model)

# Flux.jlの初期値の読み込み, 表示
parameters = (
    layer_1 = (weight = initial_parameters[1][1].weight, bias = initial_parameters[1][1].bias),
    layer_2 = (weight = initial_parameters[1][2].weight, bias = initial_parameters[1][2].bias),
    layer_3 = (weight = initial_parameters[1][3].weight, bias = initial_parameters[1][3].bias),
)
@show parameters

# 学習データの定義
X = Float32.(-10:1.0:10)' # ここに転置がかかっていることに注意
Y = X .^ 2

# 損失関数の定義
lux_loss_function = Lux.MSELoss()

# オプティマイザの定義
lux_optimizer = Optimisers.Adam()

# 学習, 途中経過の表示
println("\ni   \tloss")
println("-----\t----------")
tstate = Lux.Training.TrainState(lux_model, parameters, state, lux_optimizer)
@time for i in 0:15000
    grads, loss, _, tstate = Lux.Training.single_train_step!(Lux.AutoZygote(), lux_loss_function, (X,Y), tstate)
    if rem(i,1000) == 0
        println("$i\t$loss")
    end
end

# 学習済みモデル
model = tstate.model
parameters = tstate.parameters
states = tstate.states
f(x) = sum(first(model([x], parameters, states)))

# 描写
using CairoMakie
fig = Figure(size=(420,300), fontsize=11.5, backgroundcolor=:transparent)
axis = Axis(fig[1,1], xlabel=L"$x$", ylabel=L"$y$", xlabelsize=16.5, ylabelsize=16.5)
lines!(axis, -10..10, x -> f(x), label="model")
scatter!(axis, X', x -> x^2, color=:black, label="exact")
axislegend(axis, position=:rb, framevisible=false)
fig
出力
parameters = (layer_1 = (weight = Float32[-0.81195045; 2.5622392; 0.10329163; -0.21205132; 2.5562384; 0.0104885455; -0.6997176; -0.83282274; -1.4579619;;], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_2 = (weight = Float32[-0.8652357 1.7772343 0.23964548 0.25175205 -0.71538484 0.353435 -0.71898866 -0.03159868 1.5777229; 0.7690479 -2.7981696 -0.9601986 -0.8235361 0.3812632 0.25984508 0.03140855 -2.247552 -1.876675; 1.516688 0.12623774 -1.7742386 0.91808116 -2.5331562 -0.20585492 -0.970512 -0.2074294 -0.67983174; -1.7946341 0.06827301 1.1700256 1.0815432 -1.4005985 0.090843394 -0.83741915 -1.9605532 -0.24460198; -0.9856712 0.8465971 -0.7440299 0.14951882 1.0899436 0.81333387 1.3086019 0.60052055 -1.7192795; 0.47419593 0.21375325 0.54072803 -1.2877525 -1.0776865 -0.5000487 -0.46481803 -2.956821 -1.757924; -1.0986273 -0.12919612 0.23816317 -1.298389 0.2482856 -1.7409434 1.4281609 0.718066 0.47351286; 0.67747116 0.51623446 0.12625434 0.29146492 -0.10786364 1.1969564 1.8717147 -1.3031343 -0.31237987; -1.3218207 -0.48563242 0.20035145 -0.24979854 -0.3571935 0.96124464 1.8206987 0.45036435 0.6091286], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_3 = (weight = Float32[1.0364083 -0.2783158 2.2698267 -2.0273461 0.20790057 -1.8240385 0.28225797 2.6128943 -1.0510497], bias = Float32[0.0]))

i   	loss
-----	----------
0	1035.4657
1000	63.08293
2000	20.24789
3000	6.4580956
4000	1.8206335
5000	0.8585495
6000	0.5993715
7000	0.38066337
8000	0.22740911
9000	0.16608828
10000	0.15242149
11000	0.14977676
12000	0.14702067
13000	0.14498036
14000	0.14385653
15000	0.14313067
  3.357067 seconds (8.27 M allocations: 484.290 MiB, 4.43% gc time, 85.54% compilation time)

計算時間の違いはどこから?

Lux.jlが速い理由を考察します. 損失関数の計算時間, 損失関数のパラメータに対する勾配の計算時間, 1エポックあたりの計算時間をそれぞれ測定しました. 損失関数の計算時間はほとんど同じですが, 勾配の計算時間はLux.jlの方が短いです. 学習エポックは勾配を用いて損失関数を最小化するステップなので, 必然的にLux.jlの方が速くなります. Lux.jlの方が勾配の計算が速い理由はよくわかりませんが, Lux.jlのレイヤーが純粋関数として定義されているため微分しやすいのではないかと思われます.

計算速度の比較
import BenchmarkTools

# 損失関数
BenchmarkTools.@btime flux_loss_function(flux_model, X)
BenchmarkTools.@btime lux_loss_function(first(lux_model(X,parameters,state)), Y)
#  736.508 ns (9 allocations: 2.02 KiB)
#  727.132 ns (14 allocations: 2.16 KiB)

# パラメータ勾配
BenchmarkTools.@btime Zygote.gradient(flux_model -> flux_loss_function(flux_model, X), flux_model)
BenchmarkTools.@btime Zygote.gradient(parameters -> lux_loss_function(first(lux_model(X,parameters,state)), Y), parameters)
#  22.500 μs (129 allocations: 24.44 KiB)
#  8.000 μs (91 allocations: 25.53 KiB)

# 1エポックあたりの計算時間
BenchmarkTools.@btime Flux.train!(flux_loss_function, flux_model, (X,), flux_optimizer)
BenchmarkTools.@btime Lux.Training.single_train_step!(Lux.AutoZygote(), lux_loss_function, (X,Y), tstate)
#  29.900 μs (305 allocations: 33.64 KiB)
#  11.400 μs (150 allocations: 12.36 KiB)

まとめ

Flux.jlからLux.jlへの移行手順を解説しました. 標準の機能で初期値を揃える方法はわかりませんでしたが, 手動で初期値を揃えて計算したところ, 数値誤差の範囲で一致する計算結果が得られました. 計算はLux.jlの方が高速でした. これはLux.jlの方が勾配の計算が速いためです.

バージョン情報

バージョン情報
Julia v1.11.2
Flux v0.16.0
Lux v1.4.4
Zygote v0.6.75
Optimisers v0.4.4
CairoMakie v0.12.18
BenchmarkTools v1.5.0

https://gist.github.com/ohno/880acda728d5747bc93f00203270a605

Discussion