〽️

高階自動微分をスクラッチ実装で理解する

に公開

概要

この章では、本記事の目的と読むことで得られることを簡潔にまとめます。

本記事で学べること

  • 「高階自動微分(高階ジェット/Taylor 係数)」の数学的背景
  • OCaml での最小実装(high_order_dual.ml)を読み解くことで、級数ベースの自動微分をスクラッチで理解する
  • 畳み込み(Cauchy 積)を使った乗算実装、exp/sin/cos/log の再帰計算、数値的・計算量的な扱い

はじめに

  • 自動微分(AD)は機械学習(勾配計算)、感度解析、数値最適化、数値解法(ODE/PDE)など幅広い分野で使われます。
  • 古典的な AD(forward/reverse)は 1 次微分にフォーカスすることが多いですが、高階導関数や Taylor 展開が必要な場面(多項式近似、高次の感度解析、局所近似)もあります。
  • 本記事では「Taylor 係数を(a_n = f^{(n)}(x0) / n! の形で)直接扱う」アプローチを取り、OCaml 実装を通じて原理とアルゴリズムを理解します。

テイラー級数と自動微分の数学的基礎

  • 関数 f の点 x0 の周りのテイラー展開は

    f(x0 + t) = sum_{n=0}^\infty a_n t^n, ここで a_n = f^{(n)}(x0) / n!

  • a_n の形で導関数情報を持つ利点

    • 階乗で割ることで、積や合成の係数計算が combinatorial な係数(nCk 等)を明示的に扱わずに扱える場合が多い(畳み込みや再帰式がシンプルになる)。
    • 実装上、n! を毎回扱わずに済むため数値計算が整然とする。
  • なぜこの正規化が都合が良いか(簡単な説明)

    • 例えば (f g) の n 次係数は sum_{i=0..n} a_i b_{n-i}(Cauchy 積)となり、Leibniz の係数が既に a_i, b_j の定義に吸収されている。

畳み込み積で導関数が計算できる理由

  • Cauchy 積(畳み込み)の定義(冪級数の積):
    (A(t) B(t))_n = sum_{i=0}^n a_i b_{n-i}
  • Leibniz の法則(積の n 次導関数)を a_n = f^{(n)} / n! の正規化で表すと、上の単純な畳み込み結果と一致する。
  • 例:
    • f(x) = sin x, g(x) = exp x の積をテイラー係数で求めると、乗算は畳み込みで済む。具体的な係数計算手順を一つずつ追うことで理解が深まります。

可視化(例)

  • n = 2 の場合、(fg)''(x0)/2! = a_0 b_2 + a_1 b_1 + a_2 b_0(項ごとに対応することを確認)

基本演算の実装

ここでは high_order_dual.ml の実装を参照しつつ、主要関数を見ていきます。

加算とスカラー倍

  • 加算は各係数ごとの加算で正当化されます(線形性)。
  • 実装(概念):
(* 要素ごとの加算(k は truncation order)*)
let add_series k a b =
  let a = pad k 0.0 a in
  let b = pad k 0.0 b in
  List.map2 (+.) a b
  • pad により長さを k+1 に揃え、単純な要素和を取るだけです。

乗算 - 畳み込みの核心

実装(抜粋)

let mul_series k a b =
  let a_arr = Array.of_list (pad k 0.0 a) in
  let b_arr = Array.of_list (pad k 0.0 b) in
  let c = Array.make (k+1) 0.0 in
  for n = 0 to k do
    let sum = ref 0.0 in
    for i = 0 to n do
      sum := !sum +. a_arr.(i) *. b_arr.(n - i)
    done;
    c.(n) <- !sum
  done;
  Array.to_list c

なぜこのループで積の高階導関数が計算できるのか

  • テイラー級数表現 f(t) = sum a_i t^i, g(t) = sum b_j t^j とすると、積の係数 c_n は sum_{i=0..n} a_i b_{n-i}(Cauchy 積)であるため、上の二重ループはそのまま係数を計算しています。
  • a_i, b_j が既に f^{(i)}/i!, g^{(j)}/j! の正規化になっている点が重要です。

計算量の考察

  • naive な畳み込みは O(k^2) の演算回数(k は truncation order)。
  • k が大きくなる場合は FFT ベースの畳み込み(多項式乗算、ただし精度・丸め処理の配慮が要る)や階層的手法を検討できます。

初等関数の再帰公式の導出

指数関数 exp の場合

  • 微分方程式 f' = f を冪級数で書くと、係数に対して再帰関係が導かれます。
  • 実装(抜粋):
let exp_series k a =
  let a_arr = Array.of_list (pad k 0.0 a) in
  let b = Array.make (k+1) 0.0 in
  b.(0) <- exp a_arr.(0);
  for n = 1 to k do
    let s = ref 0.0 in
    for k1 = 1 to n do
      s := !s +. (float_of_int k1) *. a_arr.(k1) *. b.(n - k1)
    done;
    b.(n) <- !s /. float_of_int n
  done;
  Array.to_list b

再帰関係式の意図(簡単な導出)

  • A(t) = sum a_n t^n, B(t) = exp(A(t)) = sum b_n t^nと置くと、B' = A' * B
  • 各係数を比べて整理すると、b_0 = exp(a_0), そして
    b_n = (1/n) sum_{k=1..n} k * a_k * b_{n-k}
    が得られます(実装と一致)。

三角関数 sin/cos の場合

  • sin/cos は互いに微分で結びついているため、同時計算で再帰的に求めると効率が良いです。
  • 実装(抜粋):
let sin_cos_series k a =
  let a_arr = Array.of_list (pad k 0.0 a) in
  let s = Array.make (k+1) 0.0 in
  let c = Array.make (k+1) 0.0 in
  s.(0) <- sin a_arr.(0);
  c.(0) <- cos a_arr.(0);
  for n = 1 to k do
    let ss = ref 0.0 in
    let cs = ref 0.0 in
    for k1 = 1 to n do
      let term = (float_of_int k1) *. a_arr.(k1) in
      ss := !ss +. term *. c.(n - k1);
      cs := !cs +. term *. s.(n - k1)
    done;
    s.(n) <- !ss /. float_of_int n;
    c.(n) <- -. (!cs) /. float_of_int n
  done;
  (Array.to_list s, Array.to_list c)

相互再帰関係の数学的背景

  • S' = C * A'(A は引数の系列)C' = -S * A' といった微分関係から係数の再帰式が導出されます。
  • sin と cos を同時に求めることで、内部ループの計算を共有でき、精度・効率の面で有利です。

対数関数 log の場合

  • log は逆関数的な扱いになり、A'(t) = A(t) * B'(t)(ここで B = log A)から再帰式を導出します。
  • 実装(抜粋、バグ修正済み):
let log_series k a =
  let a_lst = pad k 0.0 a in
  let a0 = List.nth a_lst 0 in
  if a0 <= 0.0 then invalid_arg "log_series: a0 must be > 0";
  
  let a_arr = Array.of_list a_lst in
  let b = Array.make (k+1) 0.0 in
  b.(0) <- log a0;
  
  for n = 1 to k do
    let s = ref 0.0 in
    for k1 = 1 to n - 1 do
      s := !s +. (float_of_int k1) *. b.(k1) *. a_arr.(n - k1)
    done;
    let n_float = float_of_int n in
    b.(n) <- (n_float *. a_arr.(n) -. !s) /. (n_float *. a0)
  done;
  Array.to_list b

導出の要点(概略)

  • A' = A B' を係数ごとに比較して整理すると、a0 * n * b_n = n * a_n - sum_{k=1..n-1} k * b_k * a_{n-k} が得られ、これを解いて b_n を決めています。
  • a0 が 0 に近い/負の場合の扱い(複素対応や branch)については注意が必要です。

収束性と数値的安定性

  • この実装は冪級数(局所)表現なので、元の関数が展開点の近傍で解析的であることが前提。
  • 高階になると丸め誤差や階乗のスケールで数値的に不安定になりやすい(特に浮動小数点で K を大きくしたとき)。
  • 安定化のための方策
    • 任意精度(Bigfloat)での実装
    • 冪級数の保持形式やスケーリング(例えば中心化/正規化)
    • 係数の再正規化や誤差伝搬の評価

パフォーマンス改善のポイント

  • リスト操作は便利だが遅い部分があるため、内側ループの集約に配列を使っている(high_order_dual.ml の mul_series など)。
  • 再帰計算は結果が後の係数で再利用されるため、メモ化(ここでは配列に格納)が効いている。
  • アルゴリズム的改善
    • 自然に O(k^2) だが、大きな k の場合は高速多項式乗算(FFT)を考慮する。
    • 並列化(各 n の計算は一見依存があるが、ブロック分割や階層化で並列化可能)。

実装の応用と検証

  • 具体例:f(x) = sin(x) * exp(x) の k 階までの係数計算(high_order_dual.ml の example を活用)
  • 高階導関数の抽出方法:a_n を取り出し、f^{(n)}(x0) = a_n * n! を計算する関数の紹介(実装あり)
  • 検証方法
    • 既知関数(多項式、exp, sin, cos, log の合成)で数値的に比較
    • 中心差分など単純な数値微分との比較(小さな n で一致チェック)
    • 単位テスト、ランダム化された小範囲の点での一致確認

発展的な話題

  • 他の初等関数への拡張(tan, asin, acos, pow, 複合関数の一般的な合成則)
  • 複素数対応
  • 任意精度(MPFR/Bigfloat)
  • 実際の AD ライブラリとの比較(性能、機能、使い勝手)
  • 応用例:高次 Taylor 展開を用いた ODE ステッパや、Newton 法の高次拡張

まとめ

  • テイラー級数(a_n = f^{(n)}/n!)を基底にした高階自動微分は、数学的に明快で教育的価値が高い。
  • 実装をスクラッチで追うことで、畳み込み・再帰式・数値安定性・計算量など AD の本質を深く理解できる。
  • high_order_dual.ml は学習用として必要最小限がしっかりまとまっており、拡張/最適化の出発点として適切。

参考文献・リソース

  • (実装参照)high_order_dual.ml(このリポジトリの該当ファイル)
  • 「自動微分に関する入門書・論文」:基本的な AD の説明が載っている教科書やレビュー記事
  • 冪級数・級数操作に関する数学書(解析学のテイラー展開の章など)
  • FFT を用いた多項式乗算に関する資料(アルゴリズム最適化のため)

付録:重要なコードスニペット(high_order_dual.ml から抜粋)

  • 畳み込み(乗算)
let mul_series k a b =
  let a_arr = Array.of_list (pad k 0.0 a) in
  let b_arr = Array.of_list (pad k 0.0 b) in
  let c = Array.make (k+1) 0.0 in
  for n = 0 to k do
    let sum = ref 0.0 in
    for i = 0 to n do
      sum := !sum +. a_arr.(i) *. b_arr.(n - i)
    done;
    c.(n) <- !sum
  done;
  Array.to_list c
  • exp の級数
let exp_series k a =
  let a_arr = Array.of_list (pad k 0.0 a) in
  let b = Array.make (k+1) 0.0 in
  b.(0) <- exp a_arr.(0);
  for n = 1 to k do
    let s = ref 0.0 in
    for k1 = 1 to n do
      s := !s +. (float_of_int k1) *. a_arr.(k1) *. b.(n - k1)
    done;
    b.(n) <- !s /. float_of_int n
  done;
  Array.to_list b
  • sin, cos の同時計算
let sin_cos_series k a =
  let a_arr = Array.of_list (pad k 0.0 a) in
  let s = Array.make (k+1) 0.0 in
  let c = Array.make (k+1) 0.0 in
  s.(0) <- sin a_arr.(0);
  c.(0) <- cos a_arr.(0);
  for n = 1 to k do
    let ss = ref 0.0 in
    let cs = ref 0.0 in
    for k1 = 1 to n do
      let term = (float_of_int k1) *. a_arr.(k1) in
      ss := !ss +. term *. c.(n - k1);
      cs := !cs +. term *. s.(n - k1)
    done;
    s.(n) <- !ss /. float_of_int n;
    c.(n) <- -. (!cs) /. float_of_int n
  done;
  (Array.to_list s, Array.to_list c)
  • log の級数(バグ修正版)
let log_series k a =
  let a_lst = pad k 0.0 a in
  let a0 = List.nth a_lst 0 in
  if a0 <= 0.0 then invalid_arg "log_series: a0 must be > 0";
  
  let a_arr = Array.of_list a_lst in
  let b = Array.make (k+1) 0.0 in
  b.(0) <- log a0;
  
  for n = 1 to k do
    let s = ref 0.0 in
    for k1 = 1 to n - 1 do
      s := !s +. (float_of_int k1) *. b.(k1) *. a_arr.(n - k1)
    done;
    let n_float = float_of_int n in
    b.(n) <- (n_float *. a_arr.(n) -. !s) /. (n_float *. a0)
  done;
  Array.to_list b

#zennfes2025free
https://github.com/Yoshyhyrro/how_to_create_-/blob/Automatic-differentiation_implementation_experiment/high_order_dual.ml

よかったらコーヒーおごってください

Discussion