🍣

Javascriptで8次のルンゲクッタ

2022/12/12に公開

常微分方程式の数値解法といえばルンゲクッタ。
ルンゲクッタといえば古典的と呼ばれる4段4次。
検索すれば4段4次のサンプルコードは山程転がっている。

ところがそれより高次となるとなかなか見つからない。
見つかってもFORTRANかPythonばかりなので、JavaScript(TypeScript)で車輪の再発明した際の備忘録。

埋め込み型ルンゲクッタ

求める精度に対して最適な刻み幅を計算していい感じに調節することができる。
MATLABのode45とかが内部でやってるのがこれ。
詳しい説明はこちらが参考になる。

刻み幅の最適化アルゴリズムは常微分方程式の数値解法 I 基礎編にあるものを採用する[1]

ドルマン=プリンス法(Dormand-Prince Method)

埋め込み型ルンゲクッタの一種で、特に8次の公式が有名。
dop853やdopri853で検索するとその精度の高さがわかるだろう。

http://www.unige.ch/~hairer/software.html
Ernst.HairerによるFORTRANのコードはジュネーブ大学が公開している。
ButcherTableの係数はこのコードを参考にさせてもらった。

本記事のコード自体はお好きにコピペをどうぞだが、上記のコードは2条項BSDライセンスである[2]

実装

/* eslint-disable @typescript-eslint/no-loss-of-precision */
const zip = (...args: number[][]): number[][] =>
    Array.from(Array(Math.min(...args.map((v) => v.length))), (_, i) => args.map((v) => v[i]));

const adds = (...args: number[][]): number[] => {
    return zip(...args).map((v) => {
        return sum(v);
    });
};

const multiple = (a: number, b: number[]): number[] => {
    return b.map((v) => {
        return a * v
    });
};

const estimateError = (x: number[], x_: number[], delta: number[], atol: number[], rtol: number[]): number => {
    let err = 0;
    for (let i = 0; i < delta.length; i++) {
        const sc = atol[i] + Math.max(Math.abs(x[i]), Math.abs(x_[i])) * rtol[i];
        err += (delta[i] / sc) ** 2;
    }
    err = Math.sqrt(err / delta.length);
    return err;
};

const _ode = (f: (x: number[], t: number, h: number, adaptive?: { atol: number[], rtol: number[] }) => [number[], number | null], x: number[], t: number, h: number, adaptive?: {
    atol: number[],
    rtol: number[],
    maxiter?: number,
    fac?: number,
    facmax?: number,
    facmin?: number,
}): [number[], number, number] => {
    /*
        Reference
        Ernst Hairer, Gerhard Wanner, Syvert P. Nørsett. 
        Solving Ordinary Differential Equations I: Nonstiff Problems. 
        Springer.
    */
    if (!adaptive) {
        const [x_] = f(x, t, h);
        return [x_, t + h, h];
    }

    const maxiter = adaptive.maxiter ?? 100;
    let facmax = 5;
    if (adaptive.facmax && adaptive.facmax > 0) facmax = adaptive.facmax;
    let facmin = 0.1;
    if (adaptive.facmin && adaptive.facmin > 0) facmin = adaptive.facmin;
    let fac = 0.9;
    if (adaptive.fac && adaptive.fac > 0) fac = adaptive.fac;

    let h_ = h;
    let i = 0;

    while (i++ < maxiter) {
        const [x_, err] = f(x, t, h_, { atol: adaptive.atol, rtol: adaptive.rtol });
        if (!err) break;
        const t_ = t + h_;
        h_ *= Math.min(facmax, Math.max(facmin, fac * ((1 / err) ** 0.2)));
        if (err < 1) {
            return [x_, t_, h_];
        }
    }

    throw new Error();
};

const rkf45 = (f: (x: number[], t: number) => number[], x: number[], t: number, h: number, adaptive?: boolean | {
    atol?: number | number[];
    rtol?: number | number[];
    maxiter?: number,
    fac?: number,
    facmax?: number,
    facmin?: number,
}): [number[], number, number] => {
    /*
        Runge-Kutta-Fehlberg
    */
    const calc = (x: number[], t: number, h: number, adaptive?: { atol: number[], rtol: number[] }): [number[], number | null] => {
        const k1 = f(x, t);
        const k2 = f(adds(x, multiple(1 / 4 * h, k1)), t + 1 / 4 * h);
        const k3 = f(adds(x, multiple(3 / 32 * h, k1), multiple(9 / 32 * h, k2)), t + 3 / 8 * h);
        const k4 = f(adds(x, multiple(1932 / 2197 * h, k1), multiple(-7200 / 2197 * h, k2), multiple(7296 / 2197 * h, k3)), t + 12 / 13 * h);
        const k5 = f(adds(x, multiple(439 / 216 * h, k1), multiple(-8 * h, k2), multiple(3680 / 513 * h, k3), multiple(-845 / 4104 * h, k4)), t + h);
        const k6 = f(adds(x, multiple(-8 / 27 * h, k1), multiple(2 * h, k2), multiple(-3544 / 2565 * h, k3), multiple(1859 / 4104 * h, k4), multiple(-11 / 40 * h, k5)), t + 0.5 * h);
        const x_ = adds(x, multiple(25 / 216 * h, k1), multiple(1408 / 2565 * h, k3), multiple(2197 / 4104 * h, k4), multiple(-1 / 5 * h, k5));
        if (!adaptive) return [x_, null]

        const delta = adds(multiple(71 / 57600 * h, k1), multiple(-128 / 4275 * h, k3), multiple(-2197 / 75240 * h, k4), multiple(1 / 50 * h, k5), multiple(2 / 55 * h, k6));

        return [x_, estimateError(x, x_, delta, adaptive.atol, adaptive.rtol)]
    };

    const options = adaptive !== false ? (() => {
        const atol: number[] = [];
        const rtol: number[] = [];
        const userOptions = typeof adaptive === "boolean" ? undefined : adaptive
        for (let i = 0; i < x.length; i++) {
            const a = userOptions ? Array.isArray(userOptions.atol) ? userOptions.atol[i] : userOptions.atol : undefined;
            const r = userOptions ? Array.isArray(userOptions.rtol) ? userOptions.rtol[i] : userOptions.rtol : undefined;
            atol.push(((a === 0 && r === 0) || typeof a === "undefined" || a < 0) ? 1e-6 : a);
            rtol.push((typeof r === "undefined" || r < 0) ? 1e-3 : r);
        }

        return {
            atol, rtol, maxiter: userOptions?.maxiter, fac: userOptions?.fac, facmax: userOptions?.facmax, facmin: userOptions?.facmin,
        }
    })() : undefined;
    return _ode(calc, x, t, h, options);
};

const dopri5 = (f: (x: number[], t: number) => number[], x: number[], t: number, h: number, adaptive?: boolean | {
    atol?: number | number[];
    rtol?: number | number[];
    maxiter?: number,
    fac?: number,
    facmax?: number,
    facmin?: number,
}): [number[], number, number] => {
    /*
        Dormand-Prince method order 5
    */
    const calc = (x: number[], t: number, h: number, adaptive?: { atol: number[], rtol: number[] }): [number[], number | null] => {
        const k1 = f(x, t);
        const k2 = f(adds(x, multiple(1 / 5 * h, k1)), t + 1 / 5 * h);
        const k3 = f(adds(x, multiple(3 / 40 * h, k1), multiple(9 / 40 * h, k2)), t + 3 / 10 * h);
        const k4 = f(adds(x, multiple(44 / 45 * h, k1), multiple(-56 / 15 * h, k2), multiple(32 / 9 * h, k3)), t + 4 / 5 * h);
        const k5 = f(adds(x, multiple(19372 / 6561 * h, k1), multiple(-25360 / 2187 * h, k2), multiple(64448 / 6561 * h, k3), multiple(-212 / 729 * h, k4)), t + 8 / 9 * h);
        const k6 = f(adds(x, multiple(9017 / 3168 * h, k1), multiple(-355 / 33 * h, k2), multiple(46732 / 5247 * h, k3), multiple(49 / 176 * h, k4), multiple(-5103 / 18656 * h, k5)), t + h);
        const k7 = f(adds(x, multiple(35 / 384 * h, k1), multiple(500 / 1113 * h, k3), multiple(125 / 192 * h, k4), multiple(-2187 / 6784 * h, k5), multiple(11 / 84 * h, k6)), t + h);

        const x_ = adds(x, multiple(35 / 384 * h, k1), multiple(500 / 1113 * h, k3), multiple(125 / 192 * h, k4), multiple(-2187 / 6784 * h, k5), multiple(11 / 84 * h, k6));
        if (!adaptive) return [x_, null];

        const delta = adds(multiple(71 / 57600 * h, k1), multiple(-71 / 16695 * h, k3), multiple(71 / 1920 * h, k4), multiple(-17253 / 339200 * h, k5), multiple(22 / 525 * h, k6), multiple(-1 / 40 * h, k7));

        return [x_, estimateError(x, x_, delta, adaptive.atol, adaptive.rtol)]
    };

    const options = adaptive !== false ? (() => {
        const atol: number[] = [];
        const rtol: number[] = [];
        const userOptions = typeof adaptive === "boolean" ? undefined : adaptive
        for (let i = 0; i < x.length; i++) {
            const a = userOptions ? Array.isArray(userOptions.atol) ? userOptions.atol[i] : userOptions.atol : undefined;
            const r = userOptions ? Array.isArray(userOptions.rtol) ? userOptions.rtol[i] : userOptions.rtol : undefined;
            atol.push(((a === 0 && r === 0) || typeof a === "undefined" || a < 0) ? 1e-12 : a);
            rtol.push((typeof r === "undefined" || r < 0) ? 1e-6 : r);
        }

        return {
            atol, rtol, maxiter: userOptions?.maxiter, fac: userOptions?.fac, facmax: userOptions?.facmax, facmin: userOptions?.facmin,
        }
    })() : undefined;
    return _ode(calc, x, t, h, options);
};

const dopri853 = (f: (x: number[], t: number,) => number[], x: number[], t: number, h: number, adaptive?: boolean | {
    atol?: number | number[];
    rtol?: number | number[];
    maxiter?: number,
    fac?: number,
    facmax?: number,
    facmin?: number,
}): [number[], number, number] => {
    /*
        Dormand-Prince method order 8

        Reference
        dop853.f
        http://www.unige.ch/~hairer/software.html
    */
    const calc = (x: number[], t: number, h: number, adaptive?: { atol: number[], rtol: number[] }): [number[], number | null] => {
        const k1 = f(x, t);
        const k2 = f(adds(x, multiple(5.26001519587677318785587544488e-2 * h, k1)), t + 0.526001519587677318785587544488e-1 * h);
        const k3 = f(adds(x, multiple(1.97250569845378994544595329183e-2 * h, k1), multiple(5.91751709536136983633785987549e-2 * h, k2)), t + 0.789002279381515978178381316732e-1 * h);
        const k4 = f(adds(x, multiple(2.95875854768068491816892993775e-2 * h, k1), multiple(8.87627564304205475450678981324e-2 * h, k3)), t + 0.118350341907227396726757197510 * h);
        const k5 = f(adds(x, multiple(2.41365134159266685502369798665e-1 * h, k1), multiple(-8.84549479328286085344864962717e-1 * h, k3), multiple(9.24834003261792003115737966543e-1 * h, k4)), t + 0.281649658092772603273242802490 * h);
        const k6 = f(adds(x, multiple(3.7037037037037037037037037037e-2 * h, k1), multiple(1.70828608729473871279604482173e-1 * h, k4), multiple(1.25467687566822425016691814123e-1 * h, k5)), t + 0.333333333333333333333333333333 * h);
        const k7 = f(adds(x, multiple(3.7109375e-2 * h, k1), multiple(1.70252211019544039314978060272e-1 * h, k4), multiple(6.02165389804559606850219397283e-2 * h, k5), multiple(-1.7578125e-2 * h, k6)), t + 0.25 * h);
        const k8 = f(adds(x, multiple(3.70920001185047927108779319836e-2 * h, k1), multiple(1.70383925712239993810214054705e-1 * h, k4), multiple(1.07262030446373284651809199168e-1 * h, k5), multiple(-1.53194377486244017527936158236e-2 * h, k6), multiple(8.27378916381402288758473766002e-3 * h, k7)), t + 0.307692307692307692307692307692 * h);
        const k9 = f(adds(x, multiple(6.24110958716075717114429577812e-1 * h, k1), multiple(-3.36089262944694129406857109825 * h, k4), multiple(-8.68219346841726006818189891453e-1 * h, k5), multiple(2.75920996994467083049415600797e1 * h, k6), multiple(2.01540675504778934086186788979e1 * h, k7), multiple(-4.34898841810699588477366255144e1 * h, k8)), t + 0.651282051282051282051282051282 * h);
        const k10 = f(adds(x, multiple(4.77662536438264365890433908527e-1 * h, k1), multiple(-2.48811461997166764192642586468 * h, k4), multiple(-5.90290826836842996371446475743e-1 * h, k5), multiple(2.12300514481811942347288949897e1 * h, k6), multiple(1.52792336328824235832596922938e1 * h, k7), multiple(-3.32882109689848629194453265587e1 * h, k8), multiple(-2.03312017085086261358222928593e-2 * h, k9)), t + 0.6 * h);
        const k11 = f(adds(x, multiple(-9.3714243008598732571704021658e-1 * h, k1), multiple(5.18637242884406370830023853209 * h, k4), multiple(1.09143734899672957818500254654 * h, k5), multiple(-8.14978701074692612513997267357 * h, k6), multiple(-1.85200656599969598641566180701e1 * h, k7), multiple(2.27394870993505042818970056734e1 * h, k8), multiple(2.49360555267965238987089396762 * h, k9), multiple(-3.0467644718982195003823669022 * h, k10)), t + 0.857142857142857142857142857142 * h);
        const k12 = f(adds(x, multiple(2.27331014751653820792359768449 * h, k1), multiple(-1.05344954667372501984066689879e1 * h, k4), multiple(-2.00087205822486249909675718444 * h, k5), multiple(-1.79589318631187989172765950534e1 * h, k6), multiple(2.79488845294199600508499808837e1 * h, k7), multiple(-2.85899827713502369474065508674 * h, k8), multiple(-8.87285693353062954433549289258 * h, k9), multiple(1.23605671757943030647266201528e1 * h, k10), multiple(6.43392746015763530355970484046e-1 * h, k11)), t + h);

        const x_ = adds(x, multiple(5.42937341165687622380535766363e-2 * h, k1), multiple(4.45031289275240888144113950566 * h, k6), multiple(1.89151789931450038304281599044 * h, k7), multiple(-5.8012039600105847814672114227 * h, k8), multiple(3.1116436695781989440891606237e-1 * h, k9), multiple(-1.52160949662516078556178806805e-1 * h, k10), multiple(2.01365400804030348374776537501e-1 * h, k11), multiple(4.47106157277725905176885569043e-2 * h, k12));
        if (!adaptive) return [x_, null];

        const delta5 = adds(multiple(0.1312004499419488073250102996e-1, k1), multiple(-0.1225156446376204440720569753e+1, k6), multiple(-0.4957589496572501915214079952, k7), multiple(0.1664377182454986536961530415e+1, k8), multiple(-0.3503288487499736816886487290, k9), multiple(0.3341791187130174790297318841, k10), multiple(0.8192320648511571246570742613e-1, k11), multiple(-0.2235530786388629525884427845e-1, k12));
        const delta3 = adds(multiple(-0.189800754072407617468755659980, k1), multiple(4.45031289275240888144113950566, k6), multiple(1.89151789931450038304281599044, k7), multiple(-5.8012039600105847814672114227, k8), multiple(-0.422682321323791962932445679177, k9), multiple(-1.52160949662516078556178806805e-1, k10), multiple(2.01365400804030348374776537501e-1, k11), multiple(0.0226517921983608258118062039631, k12));
        let err5 = 0;
        let err3 = 0;
        for (let i = 0; i < x.length; i++) {
            const sc = adaptive.atol[i] + Math.max(Math.abs(x[i]), Math.abs(x_[i])) * adaptive.rtol[i];
            err5 += (delta5[i] / sc) ** 2;
            err3 += (delta3[i] / sc) ** 2;
        }

        const denominator = err5 !== 0 || err3 !== 0 ? err5 + 0.01 * err3 : 1;
        const err = Math.abs(h) * err5 * Math.sqrt(1 / (x.length * denominator))

        return [x_, err]
    };

    const options = adaptive !== false ? (() => {
        const atol: number[] = [];
        const rtol: number[] = [];
        const userOptions = typeof adaptive === "boolean" ? undefined : adaptive
        for (let i = 0; i < x.length; i++) {
            const a = userOptions ? Array.isArray(userOptions.atol) ? userOptions.atol[i] : userOptions.atol : undefined;
            const r = userOptions ? Array.isArray(userOptions.rtol) ? userOptions.rtol[i] : userOptions.rtol : undefined;
            atol.push(((a === 0 && r === 0) || typeof a === "undefined" || a < 0) ? 1e-12 : a);
            rtol.push((typeof r === "undefined" || r < 0) ? 1e-6 : r);
        }
        return {
            atol, rtol, maxiter: userOptions?.maxiter, fac: userOptions?.fac, facmax: userOptions?.facmax, facmin: userOptions?.facmin,
        }
    })() : undefined;
    return _ode(calc, x, t, h, options);
};

実装したのはRunge-Kutta-Fehlberg法のrkf45(4次)、ドルマンプリンス法のdopri5(5次), dopri853(8次)の3種類。
係数をタイポしていないか不安はあるが、後述の例題で大きく計算がずれていないことから大丈夫であろう。

  • ベクトル演算はそのためにわざわざライブラリを使いたくなかったので自前の簡易実装。
  • 丸め誤差や桁落ちへの対策はしていない。必要になったら適宜自分でやってほしい。
  • Javascriptの有効桁数は15桁だが、参考文献の数字(30桁)をママ使っている。15桁に丸めると精度に影響があるかは試していない。
  • 密出力は考慮していない。

例題

const gamma = 0.15;
// eslint-disable-next-line @typescript-eslint/no-unused-vars
const f = (x: number[], t: number): number[] => {
    return [
        x[1],
        -2 * gamma * x[1] - x[0]
    ];
}

const trueValue = (t: number) => {
    return Math.exp(-gamma * t) * Math.cos(t * Math.sqrt(1 - gamma ** 2))
}

刻み幅固定(h=0.2)

let x1 = [1, -0.15];
let x2 = [1, -0.15];
let x3 = [1, -0.15];
let t = 0;
let h = 0.2;
const stream = fs.createWriteStream("out.csv");
while (t <= 20) {
    stream.write(`${t}, ${x1[0]}, ${x2[0]}, ${x3[0]}, ${trueValue(t)}\n`);
    [x1] = rkf45(f, x1, t, h, false);
    [x2] = dopri5(f, x2, t, h, false);
    [x3] = dopri853(f, x3, t, h, false);
    t += h;
}
stream.end();

各ルンゲクッタでの計算値と真値の差分を対数でグラフ化する。

やはりdopri853がすごい。

刻み幅調整

比較のためどのメソッドも、atol: 1e-6, rtol: 1e-3で実行する。

rkf45


t=20に達するまで223点

dopri5


t=20に達するまで22点

dopri853


t=20に達するまで11点

古典的ルンゲクッタばかり使っていた身としては、hが1以上でも誤差がここまで抑えられることが恐ろしい。

脚注
  1. この本はとてもよくまとまっているが現在入手不可能。原著の「Solving Ordinary Differential Equations I」は入手できるが、こういう良著が手軽に日本語で読めなくなったあたり厳しい時代になってしまった。 ↩︎

  2. scipyではReferenceとしてコメントに書いているだけだが、ライセンスを厳密に運用したい人はライセンス表記に加えるとよいだろう。 ↩︎

Discussion