Modifying Julia's numerical integrators to use them with Neural ODE
summary
Continuing from the previous article, I keep working with Neural ODE in Julia.
The Neural ODE examples you find around the internet all use Tsit5 as the update rule (the solver).
I wanted to step the ODE with something I wrote myself instead of Tsit5(), so this article is a record of the various things I tried.
Conclusion first: it did not go the way I hoped.
I clearly do not yet understand how the caches and such are handled in the stepper code, nor how the computation graph for automatic differentiation is built; that is roughly where I am stuck.
If anyone understands what is going on, I would love to hear from you.
Pointers along the lines of "read this and it will make sense" are also very welcome.
Background
This time, since I want to change the update rule used by the Neural ODE in DiffEqFlux.jl,
I plan to implement a new one modeled on the methods implemented inside OrdinaryDiffEq.
Reference link
Adding new algorithms to OrdinaryDiffEq
Euler method
As a simple example, here is the source code of the Euler method (it lives in the OrdinaryDiffEqLowOrderRK library).
Only the initialization function (initialize!) and the one-step update function (perform_step!) are excerpted below.
In Julia, a trailing ! on a function name is a visual convention indicating that the call modifies at least one of its arguments (usually the first one), i.e. the function works in place. See the reference below.
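As a quick illustration of the convention (a trivial sketch using Base's sort / sort!, nothing solver-specific):
v = [3, 1, 2]
w = sort(v)   # out-of-place: returns a new sorted vector, v is left untouched
sort!(v)      # in-place: sorts v itself, mutating the argument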
Code
function initialize!(integrator, cache::EulerConstantCache)
    integrator.kshortsize = 2
    integrator.k = typeof(integrator.k)(undef, integrator.kshortsize)
    integrator.fsalfirst = integrator.f(integrator.uprev, integrator.p, integrator.t) # Pre-start fsal
    OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)
    # Avoid undefined entries if k is an array of arrays
    integrator.fsallast = zero(integrator.fsalfirst)
    integrator.k[1] = integrator.fsalfirst
    integrator.k[2] = integrator.fsallast
end

function perform_step!(integrator, cache::EulerConstantCache, repeat_step = false)
    @unpack t, dt, uprev, f, p = integrator
    @muladd u = @.. broadcast=false uprev+dt * integrator.fsalfirst
    k = f(u, p, t + dt) # For the interpolation, needs k at the updated point
    OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)
    integrator.fsallast = k
    integrator.k[1] = integrator.fsalfirst
    integrator.k[2] = integrator.fsallast
    integrator.u = u
end

get_fsalfirstlast(cache::EulerCache, u) = (cache.fsalfirst, cache.k)

function initialize!(integrator, cache::EulerCache)
    integrator.kshortsize = 2
    @unpack k, fsalfirst = cache
    integrator.fsalfirst = fsalfirst
    integrator.fsallast = k
    resize!(integrator.k, integrator.kshortsize)
    integrator.k[1] = integrator.fsalfirst
    integrator.k[2] = integrator.fsallast
    integrator.f(integrator.fsalfirst, integrator.uprev, integrator.p, integrator.t) # For the interpolation, needs k at the updated point
    OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)
end

function perform_step!(integrator, cache::EulerCache, repeat_step = false)
    @unpack t, dt, uprev, u, f, p = integrator
    @muladd @.. broadcast=false u=uprev + dt * integrator.fsalfirst
    f(integrator.fsallast, u, p, t + dt) # For the interpolation, needs k at the updated point
    OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)
end
There are two implementations, one dispatching on EulerCache and one on EulerConstantCache (the mutable and the constant case, respectively).
The methods implemented inside OrdinaryDiffEq all follow this pattern; beyond these functions there are just a few more files where the caches and so on are defined.
Because both variants exist, when solving a NeuralODE the ConstantCache version is used for the forward numerical integration, while the mutable-cache version is used for the numerical integration performed to backpropagate the gradients.
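Roughly speaking (a minimal sketch with plain OrdinaryDiffEq, outside the NeuralODE setting), which variant gets used is decided by whether the ODE right-hand side is out-of-place or in-place, via the Val{false} / Val{true} dispatch in alg_cache:
using OrdinaryDiffEq
f_oop(u, p, t) = [u[2], -u[1]]                                 # out-of-place RHS -> EulerConstantCache path
f_iip!(du, u, p, t) = (du[1] = u[2]; du[2] = -u[1]; nothing)   # in-place RHS     -> EulerCache path
solve(ODEProblem(f_oop, [1.0, 0.0], (0.0, 1.0)), Euler(); dt = 0.1)
solve(ODEProblem(f_iip!, [1.0, 0.0], (0.0, 1.0)), Euler(); dt = 0.1)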
Playing around with changes
Rename the Euler parts to some arbitrary name such as dddEuler.
This time, I replaced the update line
@muladd u = @.. broadcast=false uprev+dt * integrator.fsalfirst
with the following, more or less arbitrary, element-wise update:
u[1] = uprev[1] + dt * integrator.fsalfirst[2]
u[2] = uprev[2] - dt * integrator.fsalfirst[1]
Code
function perform_step!(integrator, cache::dddEulerConstantCache, repeat_step = false)
    @unpack t, dt, u, uprev, f, p = integrator
-   @muladd u = @.. broadcast=false uprev+dt * integrator.fsalfirst
+   u[1] = uprev[1] + dt * integrator.fsalfirst[2]
+   u[2] = uprev[2] - dt * integrator.fsalfirst[1]
    k = f(u, p, t + dt) # For the interpolation, needs k at the updated point
    OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)
    integrator.fsallast = k
    integrator.k[1] = integrator.fsalfirst
    integrator.k[2] = integrator.fsallast
    integrator.u = u
end

get_fsalfirstlast(cache::dddEulerCache, u) = (cache.fsalfirst, cache.k)

function initialize!(integrator, cache::dddEulerCache)
    integrator.kshortsize = 2
    @unpack k, fsalfirst = cache
    integrator.fsalfirst = fsalfirst
    integrator.fsallast = k
    resize!(integrator.k, integrator.kshortsize)
    integrator.k[1] = integrator.fsalfirst
    integrator.k[2] = integrator.fsallast
    integrator.f(integrator.fsalfirst, integrator.uprev, integrator.p, integrator.t) # For the interpolation, needs k at the updated point
    OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)
end

function perform_step!(integrator, cache::dddEulerCache, repeat_step = false)
    @unpack t, dt, uprev, u, f, p = integrator
-   @muladd @.. broadcast=false u=uprev + dt * integrator.fsalfirst
+   u[1] = uprev[1] + dt * integrator.fsalfirst[2]
+   u[2] = uprev[2] - dt * integrator.fsalfirst[1]
    f(integrator.fsallast, u, p, t + dt) # For the interpolation, needs k at the updated point
    OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)
end
With just this change, the error no longer propagates back and the loss stops decreasing. It looks like this (an example of the output from the full code at the end):
Current loss after 10 iterations: 14.3037370496528
Current loss after 20 iterations: 14.3037370496528
According to GPT, the problem is that indexed assignment cannot be traced by automatic differentiation.
I also tried things like .=, but no luck.
I additionally tried
new_u1 = uprev[1] + dt * integrator.fsalfirst[2]
new_u2 = uprev[2] - dt * integrator.fsalfirst[1]
@views u[1:2] = [new_u1, new_u2]
but the error still fails to propagate.
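For context, the mutation limitation itself can be reproduced with Zygote in isolation (a toy sketch unrelated to the solver internals): writing into an array element by element makes Zygote throw a "Mutating arrays is not supported" error, while the same arithmetic expressed by building a new array differentiates fine.
using Zygote
function step_mut(uprev, k, dt)
    u = similar(uprev)
    u[1] = uprev[1] + dt * k[2]   # in-place writes: Zygote cannot trace these
    u[2] = uprev[2] - dt * k[1]
    return sum(u)
end
step_pure(uprev, k, dt) = sum([uprev[1] + dt * k[2],
                               uprev[2] - dt * k[1]])   # builds a fresh array instead
Zygote.gradient(u -> step_pure(u, [1.0, 2.0], 0.1), [1.0, 1.0])   # ([1.0, 1.0],)
# Zygote.gradient(u -> step_mut(u, [1.0, 2.0], 0.1), [1.0, 1.0])  # errors: mutating arrays is not supported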
Incidentally, when I test a method added exactly as in the official documentation, it works fine, but that is not what I want... I want to update each element of the state separately.
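One direction that might be worth trying (just a sketch; I have not confirmed that it actually restores the gradients in the NeuralODE setting): in the ConstantCache perform_step!, assemble a brand-new u vector instead of writing into the existing one, since building a new array is the pattern that Zygote-style AD can trace.
function perform_step!(integrator, cache::dddEulerConstantCache, repeat_step = false)
    @unpack t, dt, uprev, f, p = integrator
    # build a new state vector rather than mutating u element by element
    u = [uprev[1] + dt * integrator.fsalfirst[2],
         uprev[2] - dt * integrator.fsalfirst[1]]
    k = f(u, p, t + dt) # For the interpolation, needs k at the updated point
    OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)
    integrator.fsallast = k
    integrator.k[1] = integrator.fsalfirst
    integrator.k[2] = integrator.fsallast
    integrator.u = u
end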
That is where things stand for now.
To finish, I will just paste the full code I used.
Comments would be very much appreciated.
Full code
Code
import OrdinaryDiffEqCore: alg_order, isfsal, beta2_default, beta1_default,
alg_stability_size,
ssp_coefficient, OrdinaryDiffEqAlgorithm,
OrdinaryDiffEqExponentialAlgorithm,
explicit_rk_docstring, generic_solver_docstring,
trivial_limiter!,
OrdinaryDiffEqAdaptiveAlgorithm,
unwrap_alg, @unpack, initialize!, perform_step!,
calculate_residuals,
calculate_residuals!, _ode_addsteps!, @OnDemandTableauExtract,
constvalue,
OrdinaryDiffEqMutableCache, uses_uprev,
OrdinaryDiffEqConstantCache, @fold,
@cache, CompiledFloats, alg_cache, CompositeAlgorithm,
copyat_or_push!,
AutoAlgSwitch, _ode_interpolant, _ode_interpolant!,
accept_step_controller, DerivativeOrderNotPossibleError,
du_cache, u_cache, get_fsalfirstlast
using SciMLBase
using RecursiveArrayTools
import SciMLBase: full_cache
import MuladdMacro: @muladd
import FastBroadcast: @..
import LinearAlgebra: norm
import RecursiveArrayTools: recursivefill!, recursive_unitless_bottom_eltype
# import Static: False
using DiffEqBase: @def, @tight_loop_macros
import OrdinaryDiffEqCore
# using ADTypes: AutoForwardDiff, AutoZygote
using ForwardDiff: ForwardDiff
using ArgCheck: @argcheck
using ADTypes: AutoForwardDiff, AutoZygote
using Zygote: Zygote
using Compat: @compat
using ConcreteStructs: @concrete
using ChainRulesCore: ChainRulesCore, @non_differentiable, @ignore_derivatives
# using Markdown: @doc_str
# using Random: AbstractRNG
using Static: Static
using Reexport
@reexport using DiffEqBase
using Lux, Boltz, Optimisers, IterTools, OrdinaryDiffEq
using Random, Plots, ComponentArrays
using NNlib, DiffEqFlux, Optimization, Optimisers, OptimizationOptimisers
# custom method
struct dddEuler <: OrdinaryDiffEqAlgorithm end
@cache struct dddEulerCache{uType, rateType} <: OrdinaryDiffEqMutableCache
    u::uType
    uprev::uType
    tmp::uType
    k::rateType
    fsalfirst::rateType
end

function alg_cache(alg::dddEuler, u, rate_prototype, ::Type{uEltypeNoUnits},
        ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
        dt, reltol, p, calck,
        ::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
    dddEulerCache(u, uprev, zero(u), zero(rate_prototype), zero(rate_prototype))
end

struct dddEulerConstantCache <: OrdinaryDiffEqConstantCache end

function alg_cache(alg::dddEuler, u, rate_prototype, ::Type{uEltypeNoUnits},
        ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
        dt, reltol, p, calck,
        ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
    dddEulerConstantCache()
end
function initialize!(integrator, cache::dddEulerConstantCache)
    integrator.kshortsize = 2
    integrator.k = typeof(integrator.k)(undef, integrator.kshortsize)
    integrator.fsalfirst = integrator.f(integrator.uprev, integrator.p, integrator.t) # Pre-start fsal
    OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)
    # Avoid undefined entries if k is an array of arrays
    integrator.fsallast = zero(integrator.fsalfirst)
    integrator.k[1] = integrator.fsalfirst
    integrator.k[2] = integrator.fsallast
end

function perform_step!(integrator, cache::dddEulerConstantCache, repeat_step = false)
    @unpack t, dt, u, uprev, f, p = integrator
    # @muladd u = @.. broadcast=false uprev+dt * integrator.fsalfirst
    u[1] = uprev[1] + dt * integrator.fsalfirst[2]
    u[2] = uprev[2] - dt * integrator.fsalfirst[1]
    k = f(u, p, t + dt) # For the interpolation, needs k at the updated point
    OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)
    integrator.fsallast = k
    integrator.k[1] = integrator.fsalfirst
    integrator.k[2] = integrator.fsallast
    integrator.u = u
end

get_fsalfirstlast(cache::dddEulerCache, u) = (cache.fsalfirst, cache.k)

function initialize!(integrator, cache::dddEulerCache)
    integrator.kshortsize = 2
    @unpack k, fsalfirst = cache
    integrator.fsalfirst = fsalfirst
    integrator.fsallast = k
    resize!(integrator.k, integrator.kshortsize)
    integrator.k[1] = integrator.fsalfirst
    integrator.k[2] = integrator.fsallast
    integrator.f(integrator.fsalfirst, integrator.uprev, integrator.p, integrator.t) # For the interpolation, needs k at the updated point
    OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)
end

function perform_step!(integrator, cache::dddEulerCache, repeat_step = false)
    @unpack t, dt, uprev, u, f, p = integrator
    # @muladd @.. broadcast=false u=uprev + dt * integrator.fsalfirst
    u[1] = uprev[1] + dt * integrator.fsalfirst[2]
    u[2] = uprev[2] - dt * integrator.fsalfirst[1]
    f(integrator.fsallast, u, p, t + dt) # For the interpolation, needs k at the updated point
    OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)
end
# using Lux, Optim, Plots, AdvancedHMC, Random, OrdinaryDiffEq, ComponentArrays, DiffEqFlux
using ComponentArrays, Lux, DiffEqFlux, OrdinaryDiffEq, Optimization, OptimizationOptimisers, Random, Plots
function lotka_volterra!(du, u, p, t)
    x, y = u
    α, β, δ, γ = p
    du[1] = dx = α*x - β*x*y
    du[2] = dy = -δ*y + γ*x*y
end
# Initial condition
u0 = [1.0, 1.0]
# Simulation interval and intermediary points
tspan = (0.0, 1.0)
tsteps = 0.0:0.1:1.0
# datasize = length(tsteps)
# LV equation parameter. p = [α, β, δ, γ]
paras = [1.5, 1.0, 3.0, 1.0]
# Setup the ODE problem, then solve
prob_ode = ODEProblem(lotka_volterra!, u0, tspan, paras)
mean_ode_data = ComponentArray{Float64}(solve(prob_ode, Tsit5(), saveat = tsteps))
ode_data = mean_ode_data .+ 0.1 .* randn(size(mean_ode_data)..., 1)
dudt2 = Chain(Dense(2, 20, relu),
Dense(20, 20, relu),
Dense(20, 20, relu),
Dense(20, 2))
rng = Random.default_rng()
Random.seed!(rng, 10)
p, st = Lux.setup(rng, dudt2)
prob_neuralode = NeuralODE(dudt2, tspan, dddEuler(), dt=0.10, saveat = tsteps)
function predict_neuralode(p)
    Array(prob_neuralode(u0, p, st)[1])
end

function loss_neuralode(p)
    pred = predict_neuralode(p)
    loss = sum(abs2, ode_data .- pred)
    return loss
end
# loss
pinit = ComponentArray(p)
loss_list = []
# Callback function to observe training
callback = function (p, l; doplot = false)
    push!(loss_list, l)
    if length(loss_list) % 10 == 0
        println("Current loss after $(length(loss_list)) iterations: $(loss_list[end])")
    end
    # plot current prediction against data
    if doplot
        pred = predict_neuralode(p)  # recompute the prediction for plotting
        plt = scatter(tsteps, ode_data[1, :]; label = "data")
        scatter!(plt, tsteps, ode_data[2, :]; label = "data2")
        scatter!(plt, tsteps, pred[1, :]; label = "prediction")
        scatter!(plt, tsteps, pred[2, :]; label = "prediction2")
        display(plot(plt))
    end
    return false
end
# Train using the Adam optimizer
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss_neuralode(x), adtype)
optprob = Optimization.OptimizationProblem(optf, pinit)
result_neuralode = Optimization.solve(
optprob, OptimizationOptimisers.Adam(0.007); callback = callback, maxiters = 100)
############################### PLOT LOSS ##################################
scatter(loss_list[20:end], yscale=:log10, xlabel = "iteration", ylabel = "Loss")