2️⃣

[Rust] 二重数による高階自動微分

2023/08/20に公開

次回

どうも、最近は自動微分沼にはまっています。

前前回前回ではリバースモードによる自動微分を Rust で実装しました。これらは実装が簡単である分、高階微分に対応していないという欠点があります。機械学習では1次の微分で済む場合が多いですが、一般の最適化問題には2次微分を使うものもありますので、何とかサポートしたいところです。

高階自動微分を行う論文はそれなりにあるのですが、理論的難易度は急に上がる感じがします。また解説しているブログや技術記事も極端に少ないです。
方針としては私が調べた限り3通りあります。

  1. ノードごとの高階微分のルールを手で計算する
  2. 二重数を拡張する
  3. 微分した結果を返すサブグラフを動的に生成する

本稿では 1. と 2. を扱います。 1. は現実的にスケーラブルな解法とはいえませんので、実質的に 2. の話になります。 3. は長くなりそうなので次に回します。

手で高階微分のルールを書き下す方法

これは最初に思いつく方法だと思いますが、階数が有限なら根気で何とかならないこともないです。例えば、積の微分のルールは次のように計算できます。

\frac{d}{dx}fg = f\frac{dg}{dx} + \frac{df}{dx}g

これに再度微分操作をすることによって2次微分になります。

\begin{align} \frac{d^2}{dx^2} fg &= \frac{d}{dx} \left( f\frac{dg}{dx} + \frac{df}{dx}g \right) \\ &= \frac{d^2 f}{dx^2} g + 2 \frac{df}{dx} \frac{dg}{dx} + f\frac{d^2 g}{dx^2} \end{align}

Rust であれば次のように2次微分のルールが書けます。ここで2回以上現れる項は Copy とは限りませんので最後の一回の参照を除いて clone() していることに注意してください。

Mul(lhs, rhs) => {
    let dlhs = derive(nodes, lhs, wrt)?;
    let drhs = derive(nodes, rhs, wrt)?;
    let d2lhs = derive2(nodes, lhs, wrt)?;
    let d2rhs = derive2(nodes, rhs, wrt)?;
    let vrhs = value(nodes, rhs)?;
    let vlhs = value(nodes, lhs)?;
    let cross = dlhs * drhs;
    d2lhs * vrhs + vlhs * d2rhs + cross.clone() + cross
}

これでガウス分布が次のように微分できます。

fn build_model(tape: &Tape) -> (TapeTerm, TapeTerm) {
    let x = tape.term("x", 0.);
    let exp_x2 = (-(x * x)).apply("exp", f64::exp, f64::exp, f64::exp);
    (x, exp_x2)
}

let tape = Tape::new();
let (x, all) = build_model(&tape);
let derive2 = all.derive2(&x).unwrap();

数式を手で計算した結果は次のようになります。

\begin{align*} f(x) &= \exp(-x^2) \\ \frac{df}{dx} &= -2x \exp(-x^2) \\ \frac{d^2 f}{dx} &= (-2 + 4x) \exp(-x^2) \end{align*}

しかし、この方法は明らかにスケールしませんし、フォワードモードでしか動作しませんし、多変数関数の場合にも拡張できませんし、何よりも自動感がありません。

二重数による高階微分

ここで Higher Order Automatic Differentiation with Dual Numbers
という論文を見つけました。他の論文のように圏論とかは出てこないので取っつきやすいのと、 C++ の実装が紹介されているのが良いところです。

結論から言うと次のように Rust で Dvec という型を使って高階自動微分ができるようになりました。

        let d1 = Dvec::new_n(x, 1., 3);
        let d2 = &d1 * &d1;
        let d3 = -&d2;
        let d4 = d3.apply(|x, _| x.exp());
        let res = d4;
        writeln!(f, "{x}, {}, {}, {}, {}", res[0], res[1], res[2], res[3]).unwrap();

しかし、二重数は原理的にフォワードモードでしか動作しないので、入力変数が多い時にはリバースモードにしたくてもできません。そんなときは冒頭で触れた 3. の方針で行くしかないと思います。

理論

ここからは二重数の出番です。二重数とは、複素数のように2つの成分を持ち、 x + y\varepsilon のように書かれます。この \varepsilon は仮想的な数で、無限小のようなものを表していると考えます。無限小なので二乗すると2次の微小量となって消滅します。 \varepsilon^2 = 0

2重数を使うと積が次のように書けます。

(f + f'\varepsilon) (g + g'\varepsilon) = fg + (fg' + gf') \varepsilon + \varepsilon^2

\varepsilon の係数を抽出すれば導関数になっています。

fg' + gf'

これは次の微分法則に対応しています。

d(fg) = f\, dg + g\, df

二重数とは、次の微分の定義の中に現れる微小量を記号的に扱ったものにすぎません。

f'(x) = \lim_{\varepsilon \rightarrow 0} \frac{f(x + \varepsilon) - f(x)}{\varepsilon}

二重数を拡張することで高階微分にも対応できます。この場合 x + a\varepsilon + b\eta のように2つの仮想的な量の単位を導入します。これはそれぞれ2回適用した微分の微小量です。

f''(x) = \lim_{\eta \rightarrow 0} \frac{ \lim_{\varepsilon \rightarrow 0} \frac{f(x + \varepsilon + \eta) - f(\varepsilon + \eta)}{\varepsilon} - f(x) }{\eta}

このような二重数(三重数?)を使って積の法則を計算すると次のようになります。

(x + a\varepsilon + m\eta) (y + b\varepsilon + n\eta) = \\ xy + (ya + xb) \varepsilon + (ym + xn) \eta + (yam + xbn) \varepsilon + ab\varepsilon^2 + mn\eta^2

ここで、 \varepsilon\eta は別々の微小量を表すので、非対称な変換側を持ちます(というか、微分の法則に合わせるようにこれらの変換側を定めます)。

  • \varepsilon^2 = 2\eta (\frac{d}{dx} \frac{d}{dx} に対応)
  • \varepsilon \eta = 0 (3次以上の微小量になるので無視)
  • \eta^2 = 0 (さらに高次の微小量になるので無視)

これらを使うと最終的に

xy + (ya + xb) \varepsilon + (ym + xn + 2ab) \eta

となります。ここで \eta の係数が2次の微分係数になります。

ym + xn + 2ab

よく見ると式 (2) と合っていることがわかります。

論文では行列で演算規則を考える方法も紹介されています。演算規則の理解のためには役に立ちますが、実際のプログラムでこのような行列計算で演算するのはメモリの無駄なのでないと思います。

1 = \begin{pmatrix} 1 & 0 & 0 \\ 0 & 1 & 0 \\ 0 & 0 & 1 \end{pmatrix}, \varepsilon = \begin{pmatrix} 0 & 2 & 0 \\ 0 & 0 & 2 \\ 0 & 0 & 0 \end{pmatrix}, \eta = \begin{pmatrix} 0 & 0 & 2 \\ 0 & 0 & 0 \\ 0 & 0 & 0 \end{pmatrix}

論文ではさらに一般化して N 次の二重数を次のように定義します。

\mathcal{D}(f) = \sum_{j=0}^N f^{(j)} \mathbf{i}_j

加減算のルールは簡単です。単にそれぞれの要素を加算するだけです。

\mathcal{D}(f_1 \pm f_2) = \sum_{j=0}^N (f_1^{(j)} + f_2^{(j)}) \mathbf{i}_j = \mathcal{D}(f_1) \pm \mathcal{D}(f_2)

掛け算のルールはもう少し複雑です。論文では色々計算して最終的に次のルールを導いています。

\mathcal{D}(f_1 f_2) = \sum_{j=0}^N \sum_{k=0}^{N-j} f_1^{(j)} f_2^{(k)} \begin{pmatrix} j + k \\ k \end{pmatrix} \mathrm{i}_{j+k}

割り算は分母を実数にすることで掛け算に帰着できます。そのためには、次のように分母・分子に分母を掛けるのですが、これだけだと分母が実数になるとは限りません。しかし、分母に出現する虚数部は全て \mathrm{i}_j \mathrm{i}_k といった積になるので、次数が一つ高次側に「押しやられ」ます。これを繰り返すことで N 回以内に分母が実数になります。

\begin{align*} \mathcal{D}\left(\frac{f_1}{f_2}\right) &= \frac{f_1 + \sum_{j=1}^N f^{(j)} \mathrm{i}_j}{f_2 + \sum_{j=1}^N f^{(j)} \mathrm{i}_j} \\ &= \frac{\left(f_1 + \sum_j f_1^{(j)} \mathrm{i}_j\right)\left(f_2 - \sum_j f_2^{(j)} \mathrm{i}_j\right)} {\left(f_2 + \sum_j f_2^{(j)} \mathrm{i}_j\right) \left(f_2 - \sum_j f_2^{(j)} \mathrm{i}_j\right)} \\ &= \frac{\left(f_1 + \sum_j f^{(j)} \mathrm{i}_j\right)\left(f_2 - \sum_j f_2^{(j)} \mathrm{i}_j\right)} {f_2^2 - \left(\sum_j f_2^{(j)} \mathrm{i}_j\right)^2} \end{align*}

二項係数を使った実装

これらのルールをコードに書き換えるのに、論文では2通りの方針を示しています。基本的に複雑なのは掛け算なので、掛け算への対処法で方針が分かれます。一つは2重数の積の成分を二項係数で求める方法です。

これは数式の直接の解釈になるので理論から実装への変換という意味では理解しやすいですが、少なくとも論文で示されている C++ 実装では任意の次数の対応はできず、ハードコードされた次数に対応することしかできません(テンプレートを使えばコンパイル時定数の次数になら任意に対応できると思いますが、なぜかそうしていません)。

Rust での実装はこちらにあります。データ構造は単純に二重数の配列を構造体化したものになります。次数は const generics です。

pub struct Dnum<const N: usize> {
    f: [f64; N], // value and derivatives
}

足し算と引き算は自明なので省略しますが、掛け算は次のようになります。

impl<const N: usize> std::ops::Mul for Dnum<N> {
    type Output = Self;
    fn mul(self, rhs: Self) -> Self::Output {
        let mut f = [0.; N];
        for j in 0..N {
            for k in 0..N {
                if j + k < N {
                    f[j + k] += self.f[j] * rhs.f[k] * choose(j + k, j) as f64;
                }
            }
        }
        Self { f }
    }
}

割り算も理論通りの分母を実数にする方法で実装しています。繰り返し計算に再帰を使っているところと、 rhs * rhs.conjugate() はもう少し効率的に計算できそうだとは思いますが、論文からの忠実な翻訳です。ここで、実数の場合は単純に成分ごとの割り算になるので別の impl ブロックを使っているのに注目してください。

impl<const N: usize> std::ops::Div<f64> for Dnum<N> {
    type Output = Self;
    fn div(self, rhs: f64) -> Self::Output {
        let mut f = self.f;
        f.iter_mut().for_each(|j| *j = *j / rhs);
        Self { f }
    }
}

impl<const N: usize> std::ops::Div for Dnum<N> {
    type Output = Self;
    fn div(self, rhs: Self) -> Self::Output {
        if rhs.is_real() {
            self / rhs.f[0]
        } else {
            let crhs = rhs.conjugate();
            (self * crhs) / (rhs * crhs)
        }
    }
}

choose は2項係数です。オーバーフローがちょっと心配ですが次のような乗算の実装にしています。

\binom{n}{k} = \prod_{i=0}^{k-1} \frac{n - i}{i + 1}
pub(crate) fn choose(n: usize, k: usize) -> usize {
    assert!(k <= n);
    let mut res = 1;
    for i in 0..k {
        res *= n - i
    }
    for i in 1..=k {
        res /= i;
    }
    res
}

ただしこの方法には欠点があり、任意の関数の微分を行うにはその高階微分の実装も自前で行う必要があります。論文では2次までの exp や sin や pow の実装が示されていますが、一般にはこれらの関数を合成したときの高次の微分には係数に多項式が含まれます。この計算には次の再帰呼び出しを使った実装のほうが簡単です。

再帰呼び出しを使った実装

もう一つの方針は再帰的に処理するというものです。二項係数を使った方法との大きな違いは、動的なサイズに対応しており、一つの型が任意の次数の微分を扱えるということです。ただし、動的メモリを使うためパフォーマンス面では次数が固定であれば二項係数を使ったほうが有利でしょう。

Rust での実装はこちらです。

重要になるのは微分係数のレベルを落とす演算 D と最後を削る F という演算です。二重数の係数を配列に格納すると、先頭と末尾の両方に追加・削除が行われるので、データ構造としては deque (Double ended queue) を使います。 Rust では VecDeque と呼ばれている型です。

これを使って Rust での型は次のようになります。

pub struct Dvec<T = f64>(VecDeque<T>);

演算 D および F は次のようになります。

impl<T: Tensor> Dvec<T> {
    fn F(&self) -> Self {
        // Front operator
        let mut ffront = self.clone();
        ffront.0.pop_back();
        ffront
    }

    fn D(&self) -> Self {
        // Derivation operator
        let mut fback = self.clone();
        fback.0.pop_front();
        fback
    }
}

さらにここが肝になるところですが、積の場合は再帰的に項を微分して高階微分を求めます。

impl<T: Tensor> std::ops::Mul for &Dvec<T> {
    type Output = Dvec<T>;
    fn mul(self, rhs: &Dvec<T>) -> Self::Output {
        if self.is_real() || rhs.is_real() {
            single(self.0[0].clone() * rhs.0[0].clone())
        } else {
            Dvec::new(
                self.0[0].clone() * rhs.0[0].clone(),
                &(&self.D() * &rhs.F()) + &(&self.F() * &rhs.D()),
            )
        }
    }
}

積ができれば、割り算は理論で述べたように変換できますので、次のように書けます。

impl<T: Tensor> std::ops::Div for Dvec<T> {
    type Output = Dvec<T>;
    fn div(self, rhs: Dvec<T>) -> Self::Output {
        if self.is_real() || rhs.is_real() {
            single(self.0[0].clone() / rhs.0[0].clone())
        } else {
            Dvec::new(
                self.0[0].clone() / rhs.0[0].clone(),
                (&(&self.d() * &rhs) - &(&self * &rhs.d())) / (&rhs * &rhs),
            )
        }
    }
}

任意の関数を適用するメソッドも用意します。ここで、なぜか論文では示されていませんが、再帰呼び出しを使えば任意の次数の微分も簡単に書けます。ただし、関数自体の高階微分は計算できる必要があります。引数の f は値と微分の次数を引数に取る関数ポインタです。

impl<T: Tensor> Dvec<T> {
    pub fn apply(&self, f: fn(T, usize) -> T) -> Self {
        self.apply_rec(f, 0)
    }

    fn apply_rec(&self, f: fn(T, usize) -> T, n: usize) -> Self {
        if self.is_real() {
            single(f(self.0[0].clone(), n))
        } else {
            Self::new(
                f(self.0[0].clone(), n),
                &self.f().apply_rec(f, n + 1) * &self.d(),
            )
        }
    }
}

論文では関数とその1次、2次微分を別々の引数で渡していましたが、ここでは一般化して微分階数を引数に取る関数を使います。例えば三角関数であれば次のようなものです。

d1.apply(|x, n| {
    match n % 4 {
    0 => x.sin(),
    1 => x.cos(),
    2 => -x.sin(),
    3 => -x.cos(),
    _ => unreachable!(),
    }
});

試しにガウス関数を高次微分してみます。

高階多変数微分

高階微分の中でも、変数の数が2以上になってくると、二重数の実装は複雑になってきます。交差項 \partial x \partial y などを区別する必要が出てくるからです。論文では2変数までの実装が示されています。

Rust での実装はこちらです。

この実装では高次の交差項を覚えておくために三角行列を使います。しかし残念ながら Rust では const generic の値は配列の大きさを式で表すのには使えません。つまり次のようには書けません。

pub struct Dnum2<const N: usize> {
    f: [f64; (N + 1) * (N + 2) / 2],
}

動的メモリを使いたくはないのですが、 Vec を使わざるを得ません。

pub struct Dnum2<const N: usize> {
    f: Vec<f64>, // Dynamic array instead of fixed, since Rust can't have const expression in array size
}

この実装では掛け算はなんと4重ループになります。

impl<const N: usize> std::ops::Mul for Dnum2<N> {
    type Output = Self;
    fn mul(self, rhs: Self) -> Self::Output {
        let mut f = Self {
            f: vec![0.; Self::N2],
        };
        for j in 0..=N {
            for l in 0..=N - j {
                for k in 0..=N {
                    for m in 0..=N - k {
                        if j + k + l + m < N {
                            f[(j + k, l + m)] += self[(j, l)]
                                * rhs[(k, m)]
                                * choose(j + k, j) as f64
                                * choose(l + m, l) as f64;
                        }
                    }
                }
            }
        }
        f
    }
}

任意の関数の微分は論文でも面倒になったのか、 exp や sin といった関数のものすら示されていません。

まとめ

高階自動微分のアプローチとして二重数の拡張を実装してみました。二重数はメモリ上にツリーを構築する必要がないなどの利点があります。しかしながら、多変数への拡張は簡単ではなく、任意の関数を式の途中に挟むのも高次微分の計算を自分でする必要があって、あまりご利益感がありません。多変数のヘッシアンを求めたいと思ったら結構大変です。

一般的にはサブグラフを生成する方法が柔軟性に優れ、好きな次数の好きな変数の微分を計算できる点で有利だと思います。これについてはまた次回。

サブグラフを生成する方法が速度的に不満だったら二重数を見直してみるのもよいのではないでしょうか。

Discussion