〽️
高階自動微分をスクラッチ実装で理解する
概要
この章では、本記事の目的と読むことで得られることを簡潔にまとめます。
本記事で学べること
- 「高階自動微分(高階ジェット/Taylor 係数)」の数学的背景
- OCaml での最小実装(high_order_dual.ml)を読み解くことで、級数ベースの自動微分をスクラッチで理解する
- 畳み込み(Cauchy 積)を使った乗算実装、exp/sin/cos/log の再帰計算、数値的・計算量的な扱い
はじめに
- 自動微分(AD)は機械学習(勾配計算)、感度解析、数値最適化、数値解法(ODE/PDE)など幅広い分野で使われます。
- 古典的な AD(forward/reverse)は 1 次微分にフォーカスすることが多いですが、高階導関数や Taylor 展開が必要な場面(多項式近似、高次の感度解析、局所近似)もあります。
- 本記事では「Taylor 係数を(a_n = f^{(n)}(x0) / n! の形で)直接扱う」アプローチを取り、OCaml 実装を通じて原理とアルゴリズムを理解します。
テイラー級数と自動微分の数学的基礎
-
関数 f の点 x0 の周りのテイラー展開は
f(x0 + t) = sum_{n=0}^\infty a_n t^n, ここで a_n = f^{(n)}(x0) / n! -
a_n の形で導関数情報を持つ利点
- 階乗で割ることで、積や合成の係数計算が combinatorial な係数(nCk 等)を明示的に扱わずに扱える場合が多い(畳み込みや再帰式がシンプルになる)。
- 実装上、n! を毎回扱わずに済むため数値計算が整然とする。
-
なぜこの正規化が都合が良いか(簡単な説明)
- 例えば (f g) の n 次係数は sum_{i=0..n} a_i b_{n-i}(Cauchy 積)となり、Leibniz の係数が既に a_i, b_j の定義に吸収されている。
畳み込み積で導関数が計算できる理由
- Cauchy 積(畳み込み)の定義(冪級数の積):
(A(t) B(t))_n = sum_{i=0}^n a_i b_{n-i} - Leibniz の法則(積の n 次導関数)を a_n = f^{(n)} / n! の正規化で表すと、上の単純な畳み込み結果と一致する。
- 例:
- f(x) = sin x, g(x) = exp x の積をテイラー係数で求めると、乗算は畳み込みで済む。具体的な係数計算手順を一つずつ追うことで理解が深まります。
可視化(例)
- n = 2 の場合、
(fg)''(x0)/2! = a_0 b_2 + a_1 b_1 + a_2 b_0(項ごとに対応することを確認)
基本演算の実装
ここでは high_order_dual.ml の実装を参照しつつ、主要関数を見ていきます。
加算とスカラー倍
- 加算は各係数ごとの加算で正当化されます(線形性)。
- 実装(概念):
(* 要素ごとの加算(k は truncation order)*)
let add_series k a b =
let a = pad k 0.0 a in
let b = pad k 0.0 b in
List.map2 (+.) a b
- pad により長さを k+1 に揃え、単純な要素和を取るだけです。
乗算 - 畳み込みの核心
実装(抜粋)
let mul_series k a b =
let a_arr = Array.of_list (pad k 0.0 a) in
let b_arr = Array.of_list (pad k 0.0 b) in
let c = Array.make (k+1) 0.0 in
for n = 0 to k do
let sum = ref 0.0 in
for i = 0 to n do
sum := !sum +. a_arr.(i) *. b_arr.(n - i)
done;
c.(n) <- !sum
done;
Array.to_list c
なぜこのループで積の高階導関数が計算できるのか
- テイラー級数表現 f(t) = sum a_i t^i, g(t) = sum b_j t^j とすると、積の係数 c_n は sum_{i=0..n} a_i b_{n-i}(Cauchy 積)であるため、上の二重ループはそのまま係数を計算しています。
- a_i, b_j が既に f^{(i)}/i!, g^{(j)}/j! の正規化になっている点が重要です。
計算量の考察
- naive な畳み込みは O(k^2) の演算回数(k は truncation order)。
- k が大きくなる場合は FFT ベースの畳み込み(多項式乗算、ただし精度・丸め処理の配慮が要る)や階層的手法を検討できます。
初等関数の再帰公式の導出
指数関数 exp の場合
- 微分方程式 f' = f を冪級数で書くと、係数に対して再帰関係が導かれます。
- 実装(抜粋):
let exp_series k a =
let a_arr = Array.of_list (pad k 0.0 a) in
let b = Array.make (k+1) 0.0 in
b.(0) <- exp a_arr.(0);
for n = 1 to k do
let s = ref 0.0 in
for k1 = 1 to n do
s := !s +. (float_of_int k1) *. a_arr.(k1) *. b.(n - k1)
done;
b.(n) <- !s /. float_of_int n
done;
Array.to_list b
再帰関係式の意図(簡単な導出)
-
と置くと、A(t) = sum a_n t^n, B(t) = exp(A(t)) = sum b_n t^n 。B' = A' * B - 各係数を比べて整理すると、
, そしてb_0 = exp(a_0)
b_n = (1/n) sum_{k=1..n} k * a_k * b_{n-k}
が得られます(実装と一致)。
三角関数 sin/cos の場合
- sin/cos は互いに微分で結びついているため、同時計算で再帰的に求めると効率が良いです。
- 実装(抜粋):
let sin_cos_series k a =
let a_arr = Array.of_list (pad k 0.0 a) in
let s = Array.make (k+1) 0.0 in
let c = Array.make (k+1) 0.0 in
s.(0) <- sin a_arr.(0);
c.(0) <- cos a_arr.(0);
for n = 1 to k do
let ss = ref 0.0 in
let cs = ref 0.0 in
for k1 = 1 to n do
let term = (float_of_int k1) *. a_arr.(k1) in
ss := !ss +. term *. c.(n - k1);
cs := !cs +. term *. s.(n - k1)
done;
s.(n) <- !ss /. float_of_int n;
c.(n) <- -. (!cs) /. float_of_int n
done;
(Array.to_list s, Array.to_list c)
相互再帰関係の数学的背景
-
やS' = C * A'(A は引数の系列) といった微分関係から係数の再帰式が導出されます。C' = -S * A' - sin と cos を同時に求めることで、内部ループの計算を共有でき、精度・効率の面で有利です。
対数関数 log の場合
- log は逆関数的な扱いになり、
から再帰式を導出します。A'(t) = A(t) * B'(t)(ここで B = log A) - 実装(抜粋、バグ修正済み):
let log_series k a =
let a_lst = pad k 0.0 a in
let a0 = List.nth a_lst 0 in
if a0 <= 0.0 then invalid_arg "log_series: a0 must be > 0";
let a_arr = Array.of_list a_lst in
let b = Array.make (k+1) 0.0 in
b.(0) <- log a0;
for n = 1 to k do
let s = ref 0.0 in
for k1 = 1 to n - 1 do
s := !s +. (float_of_int k1) *. b.(k1) *. a_arr.(n - k1)
done;
let n_float = float_of_int n in
b.(n) <- (n_float *. a_arr.(n) -. !s) /. (n_float *. a0)
done;
Array.to_list b
導出の要点(概略)
-
を係数ごとに比較して整理すると、A' = A B' が得られ、これを解いて b_n を決めています。a0 * n * b_n = n * a_n - sum_{k=1..n-1} k * b_k * a_{n-k} - a0 が 0 に近い/負の場合の扱い(複素対応や branch)については注意が必要です。
収束性と数値的安定性
- この実装は冪級数(局所)表現なので、元の関数が展開点の近傍で解析的であることが前提。
- 高階になると丸め誤差や階乗のスケールで数値的に不安定になりやすい(特に浮動小数点で K を大きくしたとき)。
- 安定化のための方策
- 任意精度(Bigfloat)での実装
- 冪級数の保持形式やスケーリング(例えば中心化/正規化)
- 係数の再正規化や誤差伝搬の評価
パフォーマンス改善のポイント
- リスト操作は便利だが遅い部分があるため、内側ループの集約に配列を使っている(high_order_dual.ml の mul_series など)。
- 再帰計算は結果が後の係数で再利用されるため、メモ化(ここでは配列に格納)が効いている。
- アルゴリズム的改善
- 自然に O(k^2) だが、大きな k の場合は高速多項式乗算(FFT)を考慮する。
- 並列化(各 n の計算は一見依存があるが、ブロック分割や階層化で並列化可能)。
実装の応用と検証
- 具体例:
の k 階までの係数計算(high_order_dual.ml の example を活用)f(x) = sin(x) * exp(x) - 高階導関数の抽出方法:a_n を取り出し、
を計算する関数の紹介(実装あり)f^{(n)}(x0) = a_n * n! - 検証方法
- 既知関数(多項式、exp, sin, cos, log の合成)で数値的に比較
- 中心差分など単純な数値微分との比較(小さな n で一致チェック)
- 単位テスト、ランダム化された小範囲の点での一致確認
発展的な話題
- 他の初等関数への拡張(tan, asin, acos, pow, 複合関数の一般的な合成則)
- 複素数対応
- 任意精度(MPFR/Bigfloat)
- 実際の AD ライブラリとの比較(性能、機能、使い勝手)
- 応用例:高次 Taylor 展開を用いた ODE ステッパや、Newton 法の高次拡張
まとめ
- テイラー級数(a_n = f^{(n)}/n!)を基底にした高階自動微分は、数学的に明快で教育的価値が高い。
- 実装をスクラッチで追うことで、畳み込み・再帰式・数値安定性・計算量など AD の本質を深く理解できる。
- high_order_dual.ml は学習用として必要最小限がしっかりまとまっており、拡張/最適化の出発点として適切。
参考文献・リソース
- (実装参照)high_order_dual.ml(このリポジトリの該当ファイル)
- 「自動微分に関する入門書・論文」:基本的な AD の説明が載っている教科書やレビュー記事
- 冪級数・級数操作に関する数学書(解析学のテイラー展開の章など)
- FFT を用いた多項式乗算に関する資料(アルゴリズム最適化のため)
付録:重要なコードスニペット(high_order_dual.ml から抜粋)
- 畳み込み(乗算)
let mul_series k a b =
let a_arr = Array.of_list (pad k 0.0 a) in
let b_arr = Array.of_list (pad k 0.0 b) in
let c = Array.make (k+1) 0.0 in
for n = 0 to k do
let sum = ref 0.0 in
for i = 0 to n do
sum := !sum +. a_arr.(i) *. b_arr.(n - i)
done;
c.(n) <- !sum
done;
Array.to_list c
- exp の級数
let exp_series k a =
let a_arr = Array.of_list (pad k 0.0 a) in
let b = Array.make (k+1) 0.0 in
b.(0) <- exp a_arr.(0);
for n = 1 to k do
let s = ref 0.0 in
for k1 = 1 to n do
s := !s +. (float_of_int k1) *. a_arr.(k1) *. b.(n - k1)
done;
b.(n) <- !s /. float_of_int n
done;
Array.to_list b
- sin, cos の同時計算
let sin_cos_series k a =
let a_arr = Array.of_list (pad k 0.0 a) in
let s = Array.make (k+1) 0.0 in
let c = Array.make (k+1) 0.0 in
s.(0) <- sin a_arr.(0);
c.(0) <- cos a_arr.(0);
for n = 1 to k do
let ss = ref 0.0 in
let cs = ref 0.0 in
for k1 = 1 to n do
let term = (float_of_int k1) *. a_arr.(k1) in
ss := !ss +. term *. c.(n - k1);
cs := !cs +. term *. s.(n - k1)
done;
s.(n) <- !ss /. float_of_int n;
c.(n) <- -. (!cs) /. float_of_int n
done;
(Array.to_list s, Array.to_list c)
- log の級数(バグ修正版)
let log_series k a =
let a_lst = pad k 0.0 a in
let a0 = List.nth a_lst 0 in
if a0 <= 0.0 then invalid_arg "log_series: a0 must be > 0";
let a_arr = Array.of_list a_lst in
let b = Array.make (k+1) 0.0 in
b.(0) <- log a0;
for n = 1 to k do
let s = ref 0.0 in
for k1 = 1 to n - 1 do
s := !s +. (float_of_int k1) *. b.(k1) *. a_arr.(n - k1)
done;
let n_float = float_of_int n in
b.(n) <- (n_float *. a_arr.(n) -. !s) /. (n_float *. a0)
done;
Array.to_list b
#zennfes2025free
Discussion