
I want to improve Julia's numerical integration and use it in a Neural ODE

Published 2025/02/17

summary

Continuing from the previous article, we will keep working with Julia's Neural ODE.

The concrete Neural ODE examples floating around the internet all use Tsit5 as the update rule:
https://docs.sciml.ai/DiffEqFlux/stable/examples/neural_ode/

But I want to update with something I built myself, not Tsit5()!! So this article is about the various things I tried.

To give away the conclusion: it did not go the way I hoped.
I still don't understand how things like the cache are handled in the code for numerical update rules, or the computational graph needed for automatic differentiation.

If anyone understands this, I'd really like to hear from you.
Pointers along the lines of "read this and you'll get it" are also very welcome.

Prerequisites

This time, since I want to change the update rule used in DiffEqFlux.jl's Neural ODE,
I plan to implement a new one by following how the methods inside OrdinaryDiffEq are implemented.

Reference link
Adding new algorithms to OrdinaryDiffEq
https://docs.sciml.ai/DiffEqDevDocs/stable/contributing/adding_algorithms/#Adding-new-algorithms-to-OrdinaryDiffEq
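
Roughly speaking, that page asks you to define an algorithm type, cache type(s) together with alg_cache, and the initialize! / perform_step! pair. The skeleton below is just my summary of it; MyAlg and MyAlgConstantCache are placeholder names, and the concrete version for this article appears in the full code at the end.

Code
import OrdinaryDiffEqCore

# Placeholder algorithm type (the object you pass to solve / NeuralODE).
struct MyAlg <: OrdinaryDiffEqCore.OrdinaryDiffEqAlgorithm end

# Placeholder cache for the out-of-place (non-mutating) path.
struct MyAlgConstantCache <: OrdinaryDiffEqCore.OrdinaryDiffEqConstantCache end

# Still needed (shown concretely later in this article):
#   alg_cache(::MyAlg, ...)            -> build the cache; a Val{true}/Val{false} argument
#                                         separates the mutable and constant variants
#   initialize!(integrator, cache)     -> set kshortsize, fsalfirst, fsallast, k
#   perform_step!(integrator, cache)   -> advance integrator.u by one step of size dt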

Euler method

As a simple example, here is the source code for the Euler method (it lives in a library called OrdinaryDiffEqLowOrderRK).
Only the initialization function (initialize!) and the single-step update function (perform_step!) are excerpted below.

In Julia, a trailing ! on a function name is a visual convention indicating that at least one of the arguments (usually the first) is modified by the call, i.e. the function works in place. See the following for reference:
https://docs.julialang.org/en/v1/manual/functions/
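
As a trivial illustration of that convention (my own example, unrelated to the solver code):

Code
# Out-of-place: returns a new array and leaves the argument untouched.
double(x) = 2 .* x

# In-place: the trailing ! signals that the first argument gets mutated.
function double!(x)
    x .*= 2
    return x
end

v = [1.0, 2.0]
double(v)    # returns [2.0, 4.0]; v is still [1.0, 2.0]
double!(v)   # v is now [2.0, 4.0]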

Code
fixed_timestep_perform_step.jl
function initialize!(integrator, cache::EulerConstantCache)
    integrator.kshortsize = 2
    integrator.k = typeof(integrator.k)(undef, integrator.kshortsize)
    integrator.fsalfirst = integrator.f(integrator.uprev, integrator.p, integrator.t) # Pre-start fsal
    OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)

    # Avoid undefined entries if k is an array of arrays
    integrator.fsallast = zero(integrator.fsalfirst)
    integrator.k[1] = integrator.fsalfirst
    integrator.k[2] = integrator.fsallast
end

function perform_step!(integrator, cache::EulerConstantCache, repeat_step = false)
    @unpack t, dt, uprev, f, p = integrator
    @muladd u = @.. broadcast=false uprev+dt * integrator.fsalfirst
    k = f(u, p, t + dt) # For the interpolation, needs k at the updated point
    OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)
    integrator.fsallast = k
    integrator.k[1] = integrator.fsalfirst
    integrator.k[2] = integrator.fsallast
    integrator.u = u
end

get_fsalfirstlast(cache::EulerCache, u) = (cache.fsalfirst, cache.k)

function initialize!(integrator, cache::EulerCache)
    integrator.kshortsize = 2
    @unpack k, fsalfirst = cache
    integrator.fsalfirst = fsalfirst
    integrator.fsallast = k
    resize!(integrator.k, integrator.kshortsize)
    integrator.k[1] = integrator.fsalfirst
    integrator.k[2] = integrator.fsallast
    integrator.f(integrator.fsalfirst, integrator.uprev, integrator.p, integrator.t) # For the interpolation, needs k at the updated point
    OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)
end

function perform_step!(integrator, cache::EulerCache, repeat_step = false)
    @unpack t, dt, uprev, u, f, p = integrator
    @muladd @.. broadcast=false u=uprev + dt * integrator.fsalfirst
    f(integrator.fsallast, u, p, t + dt) # For the interpolation, needs k at the updated point
    OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)
end

The method is implemented in two variants, dispatched on EulerCache and EulerConstantCache (the mutable and constant cases, respectively).
The methods implemented in OrdinaryDiffEq all look roughly like this; beyond that, there are just a few more files defining the cache and related pieces.

Thanks to this two-way implementation, when solving the Neural ODE the forward numerical integration used the ConstantCache version, while the numerical integration for gradient backpropagation used the mutable-cache version.
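
For a plain ODEProblem (outside the Neural ODE setting), which of the two paths gets used follows from whether the right-hand-side function is written in-place as f(du, u, p, t) or out-of-place as f(u, p, t); the Val{true} / Val{false} argument of alg_cache in the full code below is exactly that switch. A minimal sketch with the built-in Euler(), assuming standard OrdinaryDiffEq dispatch:

Code
using OrdinaryDiffEq

# Out-of-place RHS: returns a new vector -> constant-cache path (Val{false}).
f_oop(u, p, t) = [u[2], -u[1]]

# In-place RHS: mutates du -> mutable-cache path (Val{true}).
function f_iip(du, u, p, t)
    du[1] = u[2]
    du[2] = -u[1]
    return nothing
end

u0 = [1.0, 0.0]
tspan = (0.0, 1.0)

solve(ODEProblem(f_oop, u0, tspan), Euler(); dt = 0.1)  # uses EulerConstantCache
solve(ODEProblem(f_iip, u0, tspan), Euler(); dt = 0.1)  # uses EulerCache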

Playing around with changes

First, rename the Euler parts to some arbitrary name such as dddEuler.

This time, I arbitrarily changed the update part

@muladd u = @.. broadcast=false uprev+dt * integrator.fsalfirst

to the following:

u[1] = uprev[1] + dt * integrator.fsalfirst[2]
u[2] = uprev[2] - dt * integrator.fsalfirst[1]

Code
function perform_step!(integrator, cache::dddEulerConstantCache, repeat_step = false)
    @unpack t, dt, u, uprev, f, p = integrator
-    @muladd u = @.. broadcast=false uprev+dt * integrator.fsalfirst
+    u[1]=uprev[1] + dt * integrator.fsalfirst[2]
+    u[2]=uprev[2] - dt * integrator.fsalfirst[1]
    k = f(u, p, t + dt) # For the interpolation, needs k at the updated point
    OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)
    integrator.fsallast = k
    integrator.k[1] = integrator.fsalfirst
    integrator.k[2] = integrator.fsallast
    integrator.u = u
end

get_fsalfirstlast(cache::dddEulerCache, u) = (cache.fsalfirst, cache.k)

function initialize!(integrator, cache::dddEulerCache)
    integrator.kshortsize = 2
    @unpack k, fsalfirst = cache
    integrator.fsalfirst = fsalfirst
    integrator.fsallast = k
    resize!(integrator.k, integrator.kshortsize)
    integrator.k[1] = integrator.fsalfirst
    integrator.k[2] = integrator.fsallast
    integrator.f(integrator.fsalfirst, integrator.uprev, integrator.p, integrator.t) # For the interpolation, needs k at the updated point
    OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)
end

function perform_step!(integrator, cache::dddEulerCache, repeat_step = false)
    @unpack t, dt, uprev, u, f, p = integrator
-    @muladd @.. broadcast=false u=uprev + dt * integrator.fsalfirst
+    u[1]=uprev[1] + dt * integrator.fsalfirst[2]
+    u[2]=uprev[2] - dt * integrator.fsalfirst[1]
    f(integrator.fsallast, u, p, t + dt) # For the interpolation, needs k at the updated point
    OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)
end

With just this change, the error (gradient) no longer propagates. It looks like this (example output from the full code at the end):
Current loss after 10 iterations: 14.3037370496528
Current loss after 20 iterations: 14.3037370496528

According to GPT, the problem is that these are assignments, so automatic differentiation cannot track them.
I also tried .= and so on, but that did not work either.
I also tried something like

new_u1 = uprev[1] + dt * integrator.fsalfirst[2]
new_u2 = uprev[2] - dt * integrator.fsalfirst[1]
@views u[1:2] = [new_u1, new_u2]

but the error still stops propagating.
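
If I understand GPT's explanation correctly, this is just Zygote's general restriction: it cannot differentiate through setindex!, so any u[1] = ... style write inside the differentiated code kills the gradient. A minimal sketch of that restriction in isolation (my own toy functions, nothing to do with the integrator internals):

Code
using Zygote

# Mutating version: writes into u element by element.
function step_mut(uprev, k, dt)
    u = similar(uprev)
    u[1] = uprev[1] + dt * k[2]
    u[2] = uprev[2] - dt * k[1]
    return sum(u)
end

# Non-mutating version: builds a fresh array instead of writing into one.
function step_oop(uprev, k, dt)
    u = [uprev[1] + dt * k[2], uprev[2] - dt * k[1]]
    return sum(u)
end

uprev = [1.0, 2.0]; k = [0.5, -0.5]; dt = 0.1

Zygote.gradient(u -> step_oop(u, k, dt), uprev)   # works
# Zygote.gradient(u -> step_mut(u, k, dt), uprev) # errors: "Mutating arrays is not supported"

So if mutation really is the only problem, building u out-of-place like step_oop (or with vcat / StaticArrays) inside the ConstantCache perform_step! might be one thing to try, but I haven't confirmed that this fixes the training above.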

Incidentally, a method added exactly the way the official documentation describes works fine when I try it, but that isn't what I'm trying to do... I want to update each element separately.

That's where things stand for now.
To finish, I'll just paste the full code I used.

I'd be happy to receive any comments.

Full code

Code
import OrdinaryDiffEqCore: alg_order, isfsal, beta2_default, beta1_default,
                           alg_stability_size,
                           ssp_coefficient, OrdinaryDiffEqAlgorithm,
                           OrdinaryDiffEqExponentialAlgorithm,
                           explicit_rk_docstring, generic_solver_docstring,
                           trivial_limiter!,
                           OrdinaryDiffEqAdaptiveAlgorithm,
                           unwrap_alg, @unpack, initialize!, perform_step!,
                           calculate_residuals,
                           calculate_residuals!, _ode_addsteps!, @OnDemandTableauExtract,
                           constvalue,
                           OrdinaryDiffEqMutableCache, uses_uprev,
                           OrdinaryDiffEqConstantCache, @fold,
                           @cache, CompiledFloats, alg_cache, CompositeAlgorithm,
                           copyat_or_push!,
                           AutoAlgSwitch, _ode_interpolant, _ode_interpolant!,
                           accept_step_controller, DerivativeOrderNotPossibleError,
                           du_cache, u_cache, get_fsalfirstlast
using SciMLBase
using RecursiveArrayTools
import SciMLBase: full_cache
import MuladdMacro: @muladd
import FastBroadcast: @..
import LinearAlgebra: norm
import RecursiveArrayTools: recursivefill!, recursive_unitless_bottom_eltype
# import Static: False
using DiffEqBase: @def, @tight_loop_macros
import OrdinaryDiffEqCore

# using ADTypes: AutoForwardDiff, AutoZygote
using ForwardDiff: ForwardDiff

using ArgCheck: @argcheck
using ADTypes: AutoForwardDiff, AutoZygote
using Zygote: Zygote
using Compat: @compat
using ConcreteStructs: @concrete
using ChainRulesCore: ChainRulesCore, @non_differentiable, @ignore_derivatives
# using Markdown: @doc_str
# using Random: AbstractRNG
using Static: Static

using Reexport
@reexport using DiffEqBase

using Lux, Boltz, Optimisers, IterTools, OrdinaryDiffEq
using Random, Plots, ComponentArrays
using NNlib, DiffEqFlux, Optimization, Optimisers, OptimizationOptimisers

# custom method 
struct dddEuler <: OrdinaryDiffEqAlgorithm end

@cache struct dddEulerCache{uType, rateType} <: OrdinaryDiffEqMutableCache
    u::uType
    uprev::uType
    tmp::uType
    k::rateType
    fsalfirst::rateType
end

function alg_cache(alg::dddEuler, u, rate_prototype, ::Type{uEltypeNoUnits},
        ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
        dt, reltol, p, calck,
        ::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
    dddEulerCache(u, uprev, zero(u), zero(rate_prototype), zero(rate_prototype))
end

struct dddEulerConstantCache <: OrdinaryDiffEqConstantCache end

function alg_cache(alg::dddEuler, u, rate_prototype, ::Type{uEltypeNoUnits},
        ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
        dt, reltol, p, calck,
        ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
    dddEulerConstantCache()
end


function initialize!(integrator, cache::dddEulerConstantCache)
    integrator.kshortsize = 2
    integrator.k = typeof(integrator.k)(undef, integrator.kshortsize)
    integrator.fsalfirst = integrator.f(integrator.uprev, integrator.p, integrator.t) # Pre-start fsal
    OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)

    # Avoid undefined entries if k is an array of arrays
    integrator.fsallast = zero(integrator.fsalfirst)
    integrator.k[1] = integrator.fsalfirst
    integrator.k[2] = integrator.fsallast
end

function perform_step!(integrator, cache::dddEulerConstantCache, repeat_step = false)
    @unpack t, dt, u, uprev, f, p = integrator
    # @muladd u = @.. broadcast=false uprev+dt * integrator.fsalfirst
    u[1]=uprev[1] + dt * integrator.fsalfirst[2]
    u[2]=uprev[2] - dt * integrator.fsalfirst[1]
    k = f(u, p, t + dt) # For the interpolation, needs k at the updated point
    OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)
    integrator.fsallast = k
    integrator.k[1] = integrator.fsalfirst
    integrator.k[2] = integrator.fsallast
    integrator.u = u
end

get_fsalfirstlast(cache::dddEulerCache, u) = (cache.fsalfirst, cache.k)

function initialize!(integrator, cache::dddEulerCache)
    integrator.kshortsize = 2
    @unpack k, fsalfirst = cache
    integrator.fsalfirst = fsalfirst
    integrator.fsallast = k
    resize!(integrator.k, integrator.kshortsize)
    integrator.k[1] = integrator.fsalfirst
    integrator.k[2] = integrator.fsallast
    integrator.f(integrator.fsalfirst, integrator.uprev, integrator.p, integrator.t) # For the interpolation, needs k at the updated point
    OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)
end

function perform_step!(integrator, cache::dddEulerCache, repeat_step = false)
    @unpack t, dt, uprev, u, f, p = integrator
    # @muladd @.. broadcast=false u=uprev + dt * integrator.fsalfirst
    u[1]=uprev[1] + dt * integrator.fsalfirst[2]
    u[2]=uprev[2] - dt * integrator.fsalfirst[1]
    f(integrator.fsallast, u, p, t + dt) # For the interpolation, needs k at the updated point
    OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)
end


# using Lux, Optim, Plots, AdvancedHMC, Random, OrdinaryDiffEq, ComponentArrays, DiffEqFlux

using ComponentArrays, Lux, DiffEqFlux, OrdinaryDiffEq, Optimization, OptimizationOptimisers, Random, Plots


function lotka_volterra!(du, u, p, t)
  x, y = u
  α, β, δ, γ = p
  du[1] = dx = α*x - β*x*y
  du[2] = dy = -δ*y + γ*x*y
end

# Initial condition
u0 = [1.0, 1.0]

# Simulation interval and intermediary points
tspan = (0.0, 1.0)
tsteps = 0.0:0.1:1.0
# datasize = length(tsteps)

# LV equation parameter. p = [α, β, δ, γ]
paras = [1.5, 1.0, 3.0, 1.0]

# Setup the ODE problem, then solve
prob_ode = ODEProblem(lotka_volterra!, u0, tspan, paras)
mean_ode_data = ComponentArray{Float64}(solve(prob_ode, Tsit5(), saveat = tsteps))
ode_data = mean_ode_data .+ 0.1 .* randn(size(mean_ode_data)..., 1)

dudt2 = Chain(Dense(2, 20, relu),
                Dense(20, 20, relu),
                Dense(20, 20, relu),
                Dense(20, 2))

rng = Random.default_rng(10)
p, st = Lux.setup(rng, dudt2)
prob_neuralode = NeuralODE(dudt2, tspan, dddEuler(), dt=0.10, saveat = tsteps)

function predict_neuralode(p)
    Array(prob_neuralode(u0, p, st)[1])
end

function loss_neuralode(p)
    pred = predict_neuralode(p)
    loss = sum(abs2, ode_data .- pred)
    return loss
end
# loss

pinit = ComponentArray(p)

loss_list = []
# Callback function to observe training
callback = function (p, l; doplot = false)
    push!(loss_list, l)
    if length(loss_list) % 10 == 0
        println("Current loss after $(length(loss_list)) iterations: $(loss_list[end])")
    end
    # plot current prediction against data
    if doplot
        pred = predict_neuralode(p)  # recompute the current prediction for plotting (assumes the callback's first argument is the parameter vector)
        plt = scatter(tsteps, ode_data[1, :]; label = "data")
        scatter!(plt, tsteps, ode_data[2, :]; label = "data2")
        scatter!(plt, tsteps, pred[1, :]; label = "prediction")
        scatter!(plt, tsteps, pred[2, :]; label = "prediction2")
        display(plot(plt))
    end
    return false
end


# Train using the Adam optimizer
adtype = Optimization.AutoZygote()

optf = Optimization.OptimizationFunction((x, p) -> loss_neuralode(x), adtype)
optprob = Optimization.OptimizationProblem(optf, pinit)

result_neuralode = Optimization.solve(
    optprob, OptimizationOptimisers.Adam(0.007); callback = callback, maxiters = 100)

############################### PLOT LOSS ##################################
scatter(loss_list[20:end], yscale=:log10, xlabel = "iteration", ylabel = "Loss")
