📈

Rustでフーリエ変換(FFT)

2022/01/15に公開

モチベーション

音声技術の勉強として、Rustで高速フーリエ変換(FFT)を実装。

また、FFT関連の記事は玄人向けが多かったため、理解の手助けとなるよう、FFT自体の説明も記載しました。
(行列演算やバタフライ演算の知識を省き、記事内の実装に近いカタチで記載しています)

シンプルなフーリエ変換の実装については、前回記事を参照ください。
https://zenn.dev/sadahiroyoshi/articles/e796dc3f4c05ab

概要

この記事では「FFT=Cooley-Tukey型FFTルーチン」とします。
(最も一般的と思われる型を選択。他は機会があれば。。)

Cooley-Tukey型FFTルーチンの高速化理論と、Rustでの実装について説明します。

Cooley-Tukey型FFTルーチン

Cooley-Tukey型FFTルーチンにおいて、フーリエ変換がどう高速化されるのかを計算式で説明します。

基本式

まずフーリエ変換の基本式。

\begin{equation} \begin{split} \begin{array}{cc} X_n = \displaystyle\sum_{k=0}^{N-1}x_kW_N^{kn} & W_N = e^{-i2{\pi}/N} \end{array} \end{split} \end{equation}
  • X_n: フーリエ変換結果(N次元)
  • n: フーリエ変換結果の次元
  • N: 入力フレームの個数(2の累乗(理由は後述))
  • x_k: 入力フレームの値
前回記事との差分

前回記事では、簡単のため「N=f_s(サンプリング周波数)」としていましたが、計算量を考慮し、フレーム数ベースに変更。
(前回実装したDTFTの計算量はO(N^2)。FFTでもO(N{\log}N)のため、サンプリング周波数=16000[Hz]だと膨大な処理に)

周波数[Hz]への変換

本記事でのフーリエ変換の結果は、周波数[Hz]単位ではなくN次元のベクトル。各ベクトルの周波数[Hz]は、n/Nf_s(サンプリング周波数)[Hz]を掛けて取得する必要あり。
(N次元なので〇〇Hzのデータがどうか、ではなく、f_s[Hz]をN個に分割した周波数[Hz]からデータをざっくり見ることになります(厳密にはナイキスト周波数を考慮するためn{\leqq}N/2のデータのみ有効)。また実務では、このベクトルへ、人間の音高知覚を考慮して周波数[Hz]を一定の次元数に圧縮する「メルフィルタバンク」を通し、音声特徴量「MFCC」へ加工するようです (詳しくはこちら))

現状のままだと、nごとにkを0からNまで計算する必要があり、計算回数はN^2回となります。

偶数部と奇数部にわける

計算量を削減するため、nが偶数であると仮定し、偶数部(X_{2m})と奇数部(X_{2m+1})の計算を別々に考えます。

mの範囲は以下になります。

\begin{equation} 0 {\leqq} m {\leqq} {N/2}-1 \end{equation}

以下の式を活用するので、先に示しておきます。

\begin{equation} \begin{split} W_N^N & = e^{-i2{\pi}/N \times N} = e^{-i2{\pi}} = 1 \\ \end{split} \end{equation}
\begin{equation} \begin{split} W_N^{N/2} & = e^{-i2{\pi}/N \times N/2} = e^{-i{\pi}} = -1 \\ \end{split} \end{equation}

( \because オイラーの等式より)

偶数部の計算

\SigmaN/2 までと N/2 からに分割し、(3) を利用しながら計算します。

\begin{equation} \begin{split} X_{2m} & = \displaystyle\sum_{k=0}^{N-1}x_kW_N^{2mk} \\ & = \displaystyle\sum_{k=0}^{{N/2}-1}x_kW_N^{2mk} + \displaystyle\sum_{k={N/2}}^{N-1}x_kW_N^{2mkW} \\ & = \displaystyle\sum_{k=0}^{{N/2}-1}x_kW_N^{2mk} + \displaystyle\sum_{k=0}^{{N/2}-1}x_{k+{N/2}}W_N^{2m(k+{N/2})} \\ & = \displaystyle\sum_{k=0}^{{N/2}-1}x_kW_N^{2mk} + \displaystyle\sum_{k=0}^{{N/2}-1}x_{k+{N/2}}W_N^{2mk+mN} \\ & = \displaystyle\sum_{k=0}^{{N/2}-1}x_kW_N^{2mk} + \displaystyle\sum_{k=0}^{{N/2}-1}x_{k+{N/2}}W_N^{2mk} \quad (\because W_N^{mN} = 1^m = 1 \enspace (3)) \\ & = \displaystyle\sum_{k=0}^{{N/2}-1}(x_k + x_{k+{N/2}})W_N^{2mk} \\ \end{split} \end{equation}

\Sigma の計算回数が N-1 から {N/2}-1 に削減されました!

奇数部の計算

偶数部と同様に、\SigmaN/2 までと N/2 からに分割し、(3),(4) を利用しながら計算します。

\begin{equation} \begin{split} X_{2m+1} & = \displaystyle\sum_{k=0}^{N-1}x_kW_N^{(2m+1)k} \\ & = \displaystyle\sum_{k=0}^{{N/2}-1}x_kW_N^{(2m+1)k} + \displaystyle\sum_{k={N/2}}^{N-1}x_kW_N^{(2m+1)k} \\ & = \displaystyle\sum_{k=0}^{{N/2}-1}x_kW_N^{(2m+1)k} + \displaystyle\sum_{k=0}^{{N/2}-1}x_{k+{N/2}}W_N^{(2m+1)(k+{N/2})} \\ & = \displaystyle\sum_{k=0}^{{N/2}-1}x_kW_N^{(2m+1)k} + \displaystyle\sum_{k=0}^{{N/2}-1}x_{k+{N/2}}W_N^{(2m+1)k+mN+{N/2}} \\ & = \displaystyle\sum_{k=0}^{{N/2}-1}x_kW_N^{(2m+1)k} + \displaystyle\sum_{k=0}^{{N/2}-1}-x_{k+{N/2}}W_N^{(2m+1)k} \quad (\because (3),(4)) \\ & = \displaystyle\sum_{k=0}^{{N/2}-1}(x_k - x_{k+{N/2}})W_N^{(2m+1)k} \\ \end{split} \end{equation}

奇数部も、\Sigma の計算回数が N-1 から {N/2}-1 に削減されました!

再帰呼び出しによる高速化

偶数部も奇数部も、計算回数を1/2に削減できました。

さらに計算回数を削減するためには、導き出した値をまた別のフーリエ変換と捉え、同様の計算式を再帰的に適用します。

まず偶数部(X_{2m})に適用します。説明のため、式を簡略化します。
(識別子を色付き(\color{purple}{\scriptscriptstyle 2m})で付与してます)

\begin{equation} \begin{split} \begin{array}{cc} X_{2m} = \displaystyle\sum_{k=0}^{{N/2}-1}x_k^{\color{purple}{\scriptscriptstyle 2m}}W_{N/2}^{mk} & x_k^{\color{purple}{\scriptscriptstyle 2m}} = x_k + x_{k+{N/2}} \\ \end{array} \end{split} \end{equation}

mを偶数と仮定し、偶数部(X_{2(2l))})と奇数部(X_{2(2l+1)})の計算を別々に考えます。

lの範囲は以下になります。

\begin{equation} 0 {\leqq} l {\leqq} N/4-1 \end{equation}

このlを元に偶数部(X_{2(2l))}),奇数部(X_{2(2l+1)})を計算します。

\begin{equation} \begin{split} X_{2(2l)} & = \displaystyle\sum_{k=0}^{N/2-1}x_k^{\color{purple}{\scriptscriptstyle 2m}}W_{N/2}^{2lk} \\ & = \displaystyle\sum_{k=0}^{{N/4}-1}(x_k^{\color{purple}{\scriptscriptstyle 2m}} + x_{k+{N/4}}^{\color{purple}{\scriptscriptstyle 2m}})W_{N/2}^{2lk} \\ X_{2(2l+1)} & = \displaystyle\sum_{k=0}^{N/2-1}x_k^{\color{purple}{\scriptscriptstyle 2m}}W_{N/2}^{(2l+1)k} \\ & = \displaystyle\sum_{k=0}^{{N/4}-1}(x_k^{\color{purple}{\scriptscriptstyle 2m}} - x_{k+{N/4}}^{\color{purple}{\scriptscriptstyle 2m}})W_{N/2}^{(2l+1)k} \\ \end{split} \end{equation}

偶数部(X_{2m})の計算回数が更に1/2になりました。

奇数部(X_{2m+1})も、W_N^{k}を分離して式を簡略化すると、同じ方法で実現可能です。

\begin{equation} \begin{split} \begin{array}{cc} X_{2m+1} = \displaystyle\sum_{k=0}^{N/2-1}x_k^{\color{purple}{\scriptscriptstyle 2m+1}}W_{N/2}^{mk} & x_k^{\color{purple}{\scriptscriptstyle 2m+1}} = (x_k - x_{k+N/2})W_N^{k} \\ \end{array} \end{split} \end{equation}

以下のような結果となります。

\begin{equation} \begin{split} X_{2(2l)+1} & = \displaystyle\sum_{k=0}^{N/2-1}x_k^{\color{purple}{\scriptscriptstyle 2m+1}}W_{N/2}^{2lk} \\ & = \displaystyle\sum_{k=0}^{{N/4}-1}(x_k^{\color{purple}{\scriptscriptstyle 2m+1}} + x_{k+{N/4}}^{\color{purple}{\scriptscriptstyle 2m+1}})W_{N/2}^{2lk} \\ X_{2(2l+1)+1} & = \displaystyle\sum_{k=0}^{N/2-1}x_k^{\color{purple}{\scriptscriptstyle 2m+1}}W_{N/2}^{(2l+1)k} \\ & = \displaystyle\sum_{k=0}^{{N/4}-1}(x_k^{\color{purple}{\scriptscriptstyle 2m+1}} - x_{k+{N/4}}^{\color{purple}{\scriptscriptstyle 2m+1}})W_{N/2}^{(2l+1)k} \\ \end{split} \end{equation}

そして、(9)(11)をまとめて整理すると以下になります。

\begin{equation} \begin{split} X_{4l} & = \displaystyle\sum_{k=0}^{{N/4}-1}(x_k^{\color{purple}{\scriptscriptstyle 2m}} + x_{k+{N/4}}^{\color{purple}{\scriptscriptstyle 2m}})W_{N/4}^{lk} \\ X_{4l+1} & = \displaystyle\sum_{k=0}^{{N/4}-1}(x_k^{\color{purple}{\scriptscriptstyle 2m+1}} + x_{k+{N/4}}^{\color{purple}{\scriptscriptstyle 2m+1}})W_{N/4}^{lk} \\ X_{4l+2} & = \displaystyle\sum_{k=0}^{{N/4}-1}(x_k^{\color{purple}{\scriptscriptstyle 2m}} - x_{k+{N/4}}^{\color{purple}{\scriptscriptstyle 2m}})W_{N/2}^{k}W_{N/4}^{lk} \\ X_{4l+3} & = \displaystyle\sum_{k=0}^{{N/4}-1}(x_k^{\color{purple}{\scriptscriptstyle 2m+1}} - x_{k+{N/4}}^{\color{purple}{\scriptscriptstyle 2m+1}})W_{N/2}^{k}W_{N/4}^{lk} \\ \end{split} \end{equation}

仮にl=0とすると、(12)N=4のFFTの各次元での計算式となります。
(N/4-1 = 0となり \Sigma が消え、W_{N/4} = W_1 = 1 となりW_{N/4}^{lk}が消えます)

\begin{equation} \begin{split} X_{0} & = x_0^{\color{purple}{\scriptscriptstyle 2m}} + x_1^{\color{purple}{\scriptscriptstyle 2m}} \\ X_{1} & = x_0^{\color{purple}{\scriptscriptstyle 2m+1}} + x_1^{\color{purple}{\scriptscriptstyle 2m+1}} \\ X_{2} & = (x_0^{\color{purple}{\scriptscriptstyle 2m}} - x_1^{\color{purple}{\scriptscriptstyle 2m}})W_{2}^0 \\ X_{3} & = (x_0^{\color{purple}{\scriptscriptstyle 2m+1}} - x_1^{\color{purple}{\scriptscriptstyle 2m+1}})W_{2}^0 \\ \end{split} \end{equation}

実際のプログラムでは (x_k^{\color{purple}{\scriptscriptstyle 2m}}, x_k^{\color{purple}{\scriptscriptstyle 2m+1}}), X_n の順に計算されます。 x_k^{\color{purple}{\scriptscriptstyle 2m}} ではW_N の計算は行われないため、元々4^2=16回必要だった W_N の計算回数は、 x_0^{\color{purple}{\scriptscriptstyle 2m+1}} , x_1^{\color{purple}{\scriptscriptstyle 2m+1}} , X_{2} , X_{3} の4回に削減されます!

そして、(12)の式をさらに変形し、l=0と同様にすると、
N=8のFFTの各次元での計算式となります。

\begin{equation} \begin{split} X_{0} & = x_0^{\color{purple}{\scriptscriptstyle 4l}} + x_1^{\color{purple}{\scriptscriptstyle 4l}} \\ X_{1} & = x_0^{\color{purple}{\scriptscriptstyle 4l+1}} + x_1^{\color{purple}{\scriptscriptstyle 4l+1}} \\ X_{2} & = x_0^{\color{purple}{\scriptscriptstyle 4l+2}} + x_1^{\color{purple}{\scriptscriptstyle 4l+2}} \\ X_{3} & = x_0^{\color{purple}{\scriptscriptstyle 4l+3}} + x_1^{\color{purple}{\scriptscriptstyle 4l+3}} \\ X_{4} & = (x_0^{\color{purple}{\scriptscriptstyle 4l}} - x_1^{\color{purple}{\scriptscriptstyle 4l}})W_{2}^0 \\ X_{5} & = (x_0^{\color{purple}{\scriptscriptstyle 4l+1}} - x_1^{\color{purple}{\scriptscriptstyle 4l+1}})W_{2}^0 \\ X_{6} & = (x_0^{\color{purple}{\scriptscriptstyle 4l+2}} - x_1^{\color{purple}{\scriptscriptstyle 4l+2}})W_{2}^0 \\ X_{7} & = (x_0^{\color{purple}{\scriptscriptstyle 4l+3}} - x_1^{\color{purple}{\scriptscriptstyle 4l+3}})W_{2}^0 \\ \end{split} \end{equation}

x_kについては以下となります。

\begin{equation} \begin{split} x_k^{\color{purple}{\scriptscriptstyle 4l}} & = x_k^{\color{purple}{\scriptscriptstyle 2m}} + x_{k+2}^{\color{purple}{\scriptscriptstyle 2m}} \\ x_k^{\color{purple}{\scriptscriptstyle 4l+1}} & = x_k^{\color{purple}{\scriptscriptstyle 2m+1}} + x_{k+2}^{\color{purple}{\scriptscriptstyle 2m+1}} \\ x_k^{\color{purple}{\scriptscriptstyle 4l+2}} & = (x_k^{\color{purple}{\scriptscriptstyle 2m}} - x_{k+2}^{\color{purple}{\scriptscriptstyle 2m}})W_{4}^k \\ x_k^{\color{purple}{\scriptscriptstyle 4l+3}} & = (x_k^{\color{purple}{\scriptscriptstyle 2m+1}} - x_{k+2}^{\color{purple}{\scriptscriptstyle 2m+1}})W_{4}^k \\ \end{split} \end{equation}

N=8では、元々8^2=64回必要だったW_N の計算回数は X_{4},X_{5},X_{6},X_{7},x_0^{\color{purple}{\scriptscriptstyle 4l+2}},x_1^{\color{purple}{\scriptscriptstyle 4l+2}},x_0^{\color{purple}{\scriptscriptstyle 4l+3}},x_1^{\color{purple}{\scriptscriptstyle 4l+3}},x_0^{\color{purple}{\scriptscriptstyle 2m+1}},x_2^{\color{purple}{\scriptscriptstyle 2m+1}},x_1^{\color{purple}{\scriptscriptstyle 2m+1}},x_3^{\color{purple}{\scriptscriptstyle 2m+1}}の12回となります...!

おおよそ規則性が理解できてきたかと思いますが、
X_n を奇数部と偶数部で繰り返し式変形することで、以下の計算量削減効果が得られます。

  • \Sigma の計算回数が式変形するごとに1/2となる
    • 2の累乗であれば、繰り返し行うことで最終的に計算回数が1回となり \Sigma が消える
      (FFTの入力フレーム数を2の累乗とするのはこのためです)
      • 式変形を繰り返すごとに変形前の計算結果を利用するため、 N 回の計算を各変形ごとに行う。変形回数は {\log_2}N であるため、計算回数の合計は N{\log_2}N 回となる
  • 偶数部の W_N の計算が不要になる
    • 式変形を繰り返す際、変形前の計算結果を利用することで、W_N の計算回数は各変形で一定して 1/2 される

よって、再帰的に式変形を続け、最終的に X_{N\circ} (e.g. X_{4l}) となるまで変形することで、計算回数をN/2{\log_2}N回にまで削減することができます。

これで、Cooley-Tukey型FFTルーチンによる高速化の説明は終わりです。ポイントは「偶数部と奇数部に分割して再帰的に計算回数を削減していく」です。

Rustでの実装

ここからは実装の説明です。

概要

説明に合わせて実装を行うと、以下のようになります。

  • x_k^{\color{purple}}を先に計算
    • x_k^{\color{purple}{\scriptscriptstyle 2m}},x_k^{\color{purple}{\scriptscriptstyle 2m+1}}を計算
    • 計算結果を利用して、x_k^{\color{purple}{\scriptscriptstyle 2(2l)}},x_k^{\color{purple}{\scriptscriptstyle 2(2l+1)}},x_k^{\color{purple}{\scriptscriptstyle 2(2l)+1}},x_k^{\color{purple}{\scriptscriptstyle 2(2l+1)+1}}を計算
      :
    • 2^\circ = Nとなるまで \circ 回繰り返す
  • x_k^{\color{purple}{\scriptscriptstyle \circ}}X_nに適用

ただ、そのまま実装すると冗長となるため、バタフライ演算で実装します。

バタフライ演算

バタフライ演算とは、下図のように、x_kから順に他インデックスの値を加算/減算し、最終的な結果X_kを得る演算方式です。

詳しくは引用元を参照いただければと思いますが、本実装で重要なのは、バタフライ演算の「特定のインデックスを交差して加算/減算するため、その計算結果を計算元のインデックスに代入しても計算結果に影響が出ない」という性質です。この性質を利用して元値を上書きしていく処理を「in-place演算」と呼び、追加メモリがほぼ不要となるため、処理の高速化も期待できます。

引用
引用元:https://www.onosokki.co.jp/HP-WK/eMM_back/emm140.pdf

実装例

以下が実装例になります。

pub fn cooley_tukey_fft(frames: &[Complex<f64>]) -> Vec<Complex<f64>> {
    let n = frames.len();
    // 先頭bitから2進数中のゼロを取得
    // nが2の累乗の場合"ゼロの個数=2の累乗時の指数"となる
    let n_power2 = n.trailing_zeros();
    // nが2の累乗であることを確認
    assert_eq!(n, 2usize.pow(n_power2));

    // 処理結果を格納するVector
    // 最初は入力フレームの値(x_k)を代入
    let mut x = frames.to_vec();

    // whileループごとのインデックス数(x_kにおけるkを決める)
    let mut calc_n = n;
    // Wにおける指数の固定値を保持
    let mut w_power = -2.0 * PI / (n as f64);
    while calc_n > 1 {
        let before_n = calc_n;
        calc_n >>= 1;
        for i in 0..calc_n {
            // W_N^kの算出
            let w = Complex::new(0.0, w_power * (i as f64)).exp();
            // バタフライ演算(=偶数部と奇数部の演算)
            //// whileループごとに処理パターンが増える(変形ごとに集約されたxを利用するため)
            //// e.g. n = 8
            ////  [calc_n = 4]
            ////    1パターン: x_k
            ////    i = k, 0 <= k <= N/2-1
            ////  [calc_n = 2]
            ////    2パターン: x^(2m), x^(2m+1)
            ////    i = m, 0 <= m <= N/4-1
            ////  [calc_n = 1]
            ////    4パターン: x^(2(2l)), x^(2(2l+1)+1), x^(2(2l)), x^(2(2l+1)+1)
            ////    i = l, 0 <= l <= N/8-1
            for begin_i in (0..n).step_by(before_n) {
                // 偶数部と奇数部のインデックスにはcalc_n分の間が空く
                let even_i = begin_i + i;
                let odd_i = begin_i + i + calc_n;
		// xの値を計算
                let x_even = x[even_i] + x[odd_i];
                let x_odd = (x[even_i] - x[odd_i]) * w;
                // 変形前の値を上書き(=in-place演算)
                x[even_i] = x_even;
                x[odd_i] = x_odd;
            }
        }
        w_power *= 2.0;
    }

    // xのindexをlog2(n)ビットでビット反転することで正しい順序を取得
    return bit_reversed_indexes(n_power2)
        .iter()
        .map(|&r_i| x[r_i])
        .collect();
}

実は、バタフライ演算を実施すると、計算結果が入力元の順番通りにはなりません。
正しい順序とするため、ビット反転インデックス作成メソッド(bit_reversed_indexes)を追加します。

fn bit_reversed_indexes(power: u32) -> Vec<usize> {
    let mut r_indexes: Vec<usize> = Vec::new();

    // 最初に /0b0*10*/ を代入 ((power+1)番目のビット数が1)
    let mut r_bit = 1 << power;

    // power番目以下のビットを扱う
    // 最初に /0b0*/ を代入
    r_indexes.push(0);
    while r_bit > 1 {
        // ビット数を1つ右にずらす
        r_bit >>= 1;
        for j in 0..r_indexes.len() {
            // whileループごとに、処理中の配列にr_bitを加算した値をすべてpushすることで、
            // indexに対してpower番目までのビットを反転させたindex配列が作成される
            // e.g. power = 3, r_bit = 2^3 = 8
            //  [r_bit = 8 = 0b1000]
            //    r[0] = r[0b000] = 0b000 ++
            //  [r_bit = 4 = 0b0100]
            //    r[0] = r[0b000] = 0b000
            //    r[1] = r[0b001] = 0b100 ++
            //  [r_bit = 2 = 0b0010]
            //    r[0] = r[0b000] = 0b000
            //    r[1] = r[0b001] = 0b100
            //    r[2] = r[0b010] = 0b010 ++
            //    r[3] = r[0b011] = 0b110 ++
            //  [r_bit = 1 = 0b0001]
            //    r[0] = r[0b000] = 0b000
            //    r[1] = r[0b001] = 0b100
            //    r[2] = r[0b010] = 0b010
            //    r[3] = r[0b011] = 0b110
            //    r[4] = r[0b100] = 0b001 ++
            //    r[5] = r[0b101] = 0b101 ++
            //    r[6] = r[0b110] = 0b011 ++
            //    r[7] = r[0b111] = 0b111 ++
            r_indexes.push(r_indexes[j] | r_bit);
        }
    }
    return r_indexes;
}

実際に動くFFTプログラムを作成できました!

周波数[Hz]の誤差

高速化の過程で、求める周波数[Hz]によっては数値がズレることがあります。各ベクトルの周波数[Hz]は、n/Nf_s(サンプリング周波数)[Hz]を掛けたものとなるので、 ±f_s/N[Hz]くらいを見ておけばいいかと思います。

さらなる高速化

バタフライ演算を適用したことで速くはなりましたが、さらなる高速化も可能です。

高速化(1): W_N^{k}の事前計算

W_N^{k}の値は周期性(=一定の周期で同じ値となる)を持っています。

W_N^{k+mN} = W_N^{k}e^{-i2{\pi}m} = W_N^{k} \times 1^m = W_N^{k}

この周期性を利用して、W_N^{k}をその都度計算するのではなく、
あらかじめ1周期分を計算し、その計算結果を各インデックスに周期性に合わせて格納することで計算量を削減します。

高速化(2): 事前計算のインスタンス化

実際の音声技術プログラムでは、単一フレーム処理を行うのではなく、1音声データから同一サイズの複数フレームを作成して変換します。そのため、サイズが既にわかっている場合、事前計算結果をインスタンス変数として各フレームで共用することで高速化が期待できます。

高速化(3): 入力フレームの配列自体をin-place演算

入力フレーム配列自体をin-place演算し、元データを保持しないことで高速化が期待できます。

高速化(4): 4の累乗での高速化

ここまでは、W_N^{N}=1W_N^{N/2}=-1のみの活用でしたが、オイラーの等式をさらに活用すると、追加で以下の2パターンの簡易計算が実現できます。(マイナスにより符号が逆なので注意!)

\begin{split} W_N^{N/4} & = e^{-i2{\pi}/N \times N/4} = e^{-i{\pi}/2} = -i \\ W_N^{3N/4} & = e^{-i2{\pi}/N \times 3N/4} = e^{-i3{\pi}/2} = i \\ \end{split}

これらを活用すれば、[n]\Rightarrow[4m,4m+1,4m+2,4m+3]\Rightarrow...と、計算回数をN/2{\log_2}N回からN/4{\log_4}N回と削減でき、更なる高速化を見込めます。

ただし、4の累乗にフレーム数を合わせると桁数が大幅に変わる可能性があるため、「2の累乗を、4の累乗と2の累乗に分割してそれぞれ高速化する」という、若干トリッキーな方法で実装します。

高速化(5): 最後のバタフライ演算の省略

最後のバタフライ演算では、計算部が必ずW_{N}^0 = 1となるため省略して計算可能です。

その他(6): floatのgenerics化

インスタンス化に併せて、float値のビット数を柔軟に変更できるようにします。

含めていないもの

コード簡略化のため、以下の内容は含めていません。

  • フーリエ逆変換
  • 2(4)の累乗以外での変換
  • 2次元配列の変換

実装

これらを踏まえた実装がこちらです。(Githubにも公開しています)

use num::Complex;
use num_traits::cast;
use num_traits::float::{Float, FloatConst};
use num_traits::identities::{one, zero};

pub struct Fft<T> {
    // フーリエ変換結果の次元数
    n: usize,
    // ビットリバースしたインデックスを付与
    r_indexes: Vec<usize>,
    // 事前計算したW_N^kを保持
    w: Vec<Complex<T>>,
}

// Float + FloatConst でfloat型を許可
impl<T: Float + FloatConst + std::fmt::Debug> Fft<T> {
    pub fn new() -> Self {
        Self {
            n: 0,
            r_indexes: vec![],
            w: vec![],
        }
    }

    pub fn setup(&mut self, n: usize) {
        let n_power2 = n.trailing_zeros();
        // nは2の累乗である前提
        assert_eq!(n, 1 << n_power2, "len of n should be 2^x");

        self.n = n;
        self.r_indexes = self.calc_bit_reversed_indexes(n);
        self.w = self.calc_w(n);
    }

    fn calc_bit_reversed_indexes(&mut self, n: usize) -> Vec<usize> {
        let n_power2 = n.trailing_zeros();
        // このメソッドは、nが2の累乗である前提で作成
        assert_eq!(n, 1 << n_power2, "len of n should be 2^x");

        let mut r_indexes = Vec::<usize>::with_capacity(n);

        // 最初に /0b0*10*/ を代入 ((power+1)番目のビット数が1)
        let mut r_bit = 1 << n_power2;

        // n_power2番目以下のビットを扱う
        // 2bitを1単位として扱い、最後に2が残る場合1bitで演算

        // e.g. n_power2 = 5, r_bit = 2^5 = 32
        //  [r_bit = 32 = 0b100000]
        //    r[ 0] = r[0b00000] = 0b00000 ++
        //  [r_bit = 8 = 0b001000]
        //    r[ 0] = r[0b00000] = 0b00000
        //    r[ 1] = r[0b000(01)] = 0b(01)000 ++
        //    r[ 2] = r[0b000(10)] = 0b(10)000 ++
        //    r[ 3] = r[0b000(11)] = 0b(11)000 ++
        //  [r_bit = 2 = 0b000(01)0]
        //    r[ 0] = r[0b00000] = 0b00000
        //    r[ 1] = r[0b000(01)] = 0b(01)000
        //    r[ 2] = r[0b000(10)] = 0b(10)000
        //    r[ 3] = r[0b000(11)] = 0b(11)000
        //    r[ 4] = r[0b0(01)(00)] = 0b(00)(01)0 ++
        //    r[ 5] = r[0b0(01)(01)] = 0b(01)(01)0 ++
        //    r[ 6] = r[0b0(01)(10)] = 0b(10)(01)0 ++
        //    r[ 7] = r[0b0(01)(11)] = 0b(11)(01)0 ++
        //    r[ 8] = r[0b0(10)(00)] = 0b(00)(10)0 ++
        //    r[ 9] = r[0b0(10)(01)] = 0b(01)(10)0 ++
        //     :
        //    r[15] = r[0b0(11)(11)] = 0b(11)(11)0 ++
        //  [r_bit = 1 = 0b000001]
        //    r[ 0] = r[0b00000] = 0b00000
        //    r[ 1] = r[0b000(01)] = 0b(01)000
        //     :
        //    r[15] = r[0b0(11)(11)] = 0b(11)(11)0
        //    r[16] = r[0b(1)(00)(00)] = 0b(00)(00)(1) ++
        //    r[17] = r[0b(1)(00)(01)] = 0b(01)(00)(1) ++
        //     :
        //    r[31] = r[0b(1)(11)(11)] = 0b(11)(11)(1) ++

        // 最初に /0b0*/ を追加
        r_indexes.push(0);
        // 4の累乗部として計算
        while r_bit > 2 {
            // ビット数を2つ右にずらす
            r_bit >>= 2;
            // 2bitを1単位とし、(01),(10),(11)を追加
            let len = r_indexes.len();
            for j in 0..len {
                r_indexes.push(r_indexes[j] | r_bit);
            }
            for j in 0..len {
                r_indexes.push(r_indexes[j] | r_bit << 1);
            }
            for j in 0..len {
                r_indexes.push(r_indexes[j] | r_bit | r_bit << 1);
            }
        }
        // 2が残る場合、最後に基底数2でバタフライ演算を行うため、indexも最後に算出
        if r_bit == 2 {
            for j in 0..r_indexes.len() {
                r_indexes.push(r_indexes[j] | 1);
            }
        }

        // in-place演算用にインデックスを加工
        return self.convert_indexes_as_inplace(r_indexes);
    }

    fn convert_indexes_as_inplace(&mut self, r_indexes: Vec<usize>) -> Vec<usize> {
        let mut nums = (0..r_indexes.len()).collect::<Vec<_>>();

        // 整数列のインデックスへr_indexesでswapを実行
        // その際のインデックスを保持して実データのswapに利用することで、結果的にr_indexesで整頓したことになる
        return (0..r_indexes.len())
            .map(|i| {
                let r_i = r_indexes[i];
                let swapped_r_i = (0..nums.len()).find(|&j| nums[j] == r_i).unwrap();
                nums.swap(i, swapped_r_i);
                swapped_r_i
            })
            .collect();
    }

    fn calc_w(&mut self, n: usize) -> Vec<Complex<T>> {
        let n_power2 = n.trailing_zeros();
        // このメソッドは、nが2の累乗である前提で作成
        assert_eq!(n, 1 << n_power2, "len of n should be 2^x");

        let mut w = Vec::with_capacity(n + 1);

        // W_N^0=1
        w.push(one());

        // nが2以下の場合
        if n <= 2 {
            if n == 2 {
                // W_N^N/2=-1
                w.push(cast(-1.0).unwrap());
            }
            // W_N^N=1
            w.push(one());
            return w;
        }

        // nが2以下ではない場合(=nが4で割り切れる場合)
        let q = n >> 2;
        let h = n >> 1;
        // 0~N/4(実直に計算)
        for i in 1..q {
            w.push(self.calc_part_w(n, i));
        }
        // W_N^N/4=-i
        w.push(-Complex::i());
        // N/4~N/2(計算結果を流用)
        for i in q + 1..h {
            let tmp = w[i - q];
            w.push(Complex::new(tmp.im, -tmp.re));
        }
        // W_N^N/2=-1
        w.push(cast(-1.0).unwrap());
        // N/2~N(計算結果を流用)
        for i in h + 1..n {
            let tmp = w[i - h];
            w.push(Complex::new(-tmp.re, -tmp.im));
        }
        // W_N^N=1
        w.push(one());

        return w;
    }

    fn calc_part_w(&mut self, n: usize, seq: usize) -> Complex<T> {
        // e^(-i2π)をN分割。seqごとの値を取得
        Complex::new(
            zero(),
            cast::<_, T>(-2.0).unwrap() * T::PI() / cast(n).unwrap() * cast(seq).unwrap(),
        )
        .exp()
    }

    pub fn process(&mut self, frames: &mut [Complex<T>]) {
        let len = frames.len();
        let len_power2 = len.trailing_zeros();
        // このメソッドは、フレーム数が2の累乗である前提で作成
        assert_eq!(len, 1 << len_power2, "len of frames should be 2^x");

        // 1フレーム以下の場合、そのまま返す
        if len <= 1 {
            return;
        }

        // フレーム数がnと異なる場合、nをlenに置き換える
        if len != self.n {
            self.setup(len);
        }

        self.inner_process(frames);
    }

    // 処理結果を格納するVectorはin-palce演算なので直接もらう
    fn inner_process(&mut self, x: &mut [Complex<T>]) {
        let n = x.len();

        // whileループごとのインデックス数(x_kにおけるkを決める)
        let mut calc_n = n;
        let mut calc_w_bit = 0;
        // 4 or 2になるまでバタフライ演算(2bit)
        while calc_n > 4 {
            let before_n = calc_n;
            calc_n >>= 2;
            for i in 0..calc_n {
                let w_i = i << calc_w_bit;
                let (w1, w2, w3) = (self.w[w_i], self.w[w_i << 1], self.w[w_i * 3]);
                for begin_i in (0..n).step_by(before_n) {
                    // 各インデックスにはcalc_n分の間が空く
                    let i0 = begin_i + i;
                    let i1 = i0 + calc_n;
                    let i2 = i1 + calc_n;
                    let i3 = i2 + calc_n;
                    // xの値を計算
                    // 以下の計算を短縮し高速化
                    // let x0 = x[i0] + x[i1] + x[i2] + x[i3];
                    // let x1 = (x[i0] = i() * x[i1] - x[i2] + i() * x[i3]) * w1;
                    // let x2 = (x[i0] - x[i1] + x[i2] - x[i3]) * w2;
                    // let x3 = (x[i0] + i() * x[i1] - x[i2] - i() * x[i3]) * w3;
                    let xi0_plus_xi2 = x[i0] + x[i2];
                    let xi0_minus_xi2 = x[i0] - x[i2];
                    let xi1_plus_xi3 = x[i1] + x[i3];
                    let xi1_minus_xi3 = x[i1] - x[i3];
                    let xi1_minus_xi3_i = Complex::new(-xi1_minus_xi3.im, xi1_minus_xi3.re);
                    // 変形前の値を上書き(=in-place演算)
                    x[i0] = xi0_plus_xi2 + xi1_plus_xi3;
                    x[i1] = (xi0_minus_xi2 - xi1_minus_xi3_i) * w1;
                    x[i2] = (xi0_plus_xi2 - xi1_plus_xi3) * w2;
                    x[i3] = (xi0_minus_xi2 + xi1_minus_xi3_i) * w3;
                }
            }
            calc_w_bit += 2;
        }

        // 最後のバタフライ演算は処理を簡略化
        if calc_n == 4 {
            for i0 in (0..n).step_by(calc_n) {
                let i1 = i0 + 1;
                let i2 = i1 + 1;
                let i3 = i2 + 1;
                // xの値を計算
                // 以下の計算を短縮し高速化
                // let x0 = x[i0] + x[i1] + x[i2] + x[i3];
                // let x1 = (x[i0] = i() * x[i1] - x[i2] + i() * x[i3]) * w1;
                // let x2 = (x[i0] - x[i1] + x[i2] - x[i3]) * w2;
                // let x3 = (x[i0] + i() * x[i1] - x[i2] - i() * x[i3]) * w3;
                let xi0_plus_xi2 = x[i0] + x[i2];
                let xi0_minus_xi2 = x[i0] - x[i2];
                let xi1_plus_xi3 = x[i1] + x[i3];
                let xi1_minus_xi3 = x[i1] - x[i3];
                let xi1_minus_xi3_i = Complex::new(-xi1_minus_xi3.im, xi1_minus_xi3.re);
                // 変形前の値を上書き(=in-place演算)
                x[i0] = xi0_plus_xi2 + xi1_plus_xi3;
                x[i1] = xi0_minus_xi2 - xi1_minus_xi3_i;
                x[i2] = xi0_plus_xi2 - xi1_plus_xi3;
                x[i3] = xi0_minus_xi2 + xi1_minus_xi3_i;
            }
        } else {
            // calc_n = 2
            for i in (0..n).step_by(calc_n) {
                let even_i = i;
                let odd_i = i + 1;
                // xの値を計算
                let x_even = x[even_i] + x[odd_i];
                let x_odd = x[even_i] - x[odd_i];
                // 変形前の値を上書き(=in-place演算)
                x[even_i] = x_even;
                x[odd_i] = x_odd;
            }
        }

        // in-place演算として高速化するため、ビット反転をメモリスワップで行う
        for (i, &r) in self.r_indexes.iter().enumerate() {
            x.swap(i, r);
        }
    }
}

速度計測

おおむね速度は以下のようになりました。
単純なdtftから約1000倍速くなってます。いいですね。

dtft:
 2,349,916 ns (≒ 2ms)
cooley_tukey_fft:
 13,333ns (≒ 13µs)
Fft:
 2,750ns (≒ 2µs)

※単一sin波(フレーム数:512, サンプリング周波数:16000[Hz])での計測

最後に

ここまで記事を読んでいただきありがとうございます。

FFTの記事を実際に取り組んで感じたのは、「FFTの説明を簡潔にするのは難しい」ということと、「高速化をオリジナルで思いつくのは難しい」ということです。

FFTがとっつきにくく見えるのは、奇数と偶数を交差したり、計算データを結果から逆算するといった「データフローの矢印が複雑である」ことと、バタフライ演算よりオイラーの公式のほうが理解に役立つのに、それがわかりにくいところかと思っています。計算式に苦手意識があるひとは、N=16あたりまで自分の手で書いてみると、理解が進むと思います。(自分も紙で計算しながら理解を進めました。)

また実装面では、速くしようとすればするほど既存記事やライブラリのロジックに寄ってしまい、自分の成果物とは言えないなと感じたため、ライブラリとして公開せず、簡潔な説明となるレベルに留めました。あわよくば、もっと頑張って実装してライブラリ公開を...!と思っていましたが、別の機会となりそうです。

それでは。

参考

https://cognicull.com/ja/f5q2jl62
https://ja.wikipedia.org/wiki/高速フーリエ変換
https://qiita.com/chalharu/items/14ee4bd8396792d5175c
https://github.com/chalharu/chfft
https://github.com/ejmahler/RustFFT
https://qiita.com/kob58im/items/c082a7904926aa8479e4
https://birdhouse.hateblo.jp/entry/2020/11/19/074342
https://www.kurims.kyoto-u.ac.jp/~ooura/fftman/index.html
https://www.gfd-dennou.org/arch/prepri/2002/hokudai/kazto/fft.htm
https://qiita.com/termoshtt/items/c24e7d67f3c4a9016a7b

Discussion