Rustでフーリエ変換(FFT)
モチベーション
音声技術の勉強として、Rustで高速フーリエ変換(FFT)を実装。
また、FFT関連の記事は玄人向けが多かったため、理解の手助けとなるよう、FFT自体の説明も記載しました。
(行列演算やバタフライ演算の知識を省き、記事内の実装に近いカタチで記載しています)
シンプルなフーリエ変換の実装については、前回記事を参照ください。
概要
この記事では「FFT=Cooley-Tukey型FFTルーチン」とします。
(最も一般的と思われる型を選択。他は機会があれば。。)
Cooley-Tukey型FFTルーチンの高速化理論と、Rustでの実装について説明します。
Cooley-Tukey型FFTルーチン
Cooley-Tukey型FFTルーチンにおいて、フーリエ変換がどう高速化されるのかを計算式で説明します。
基本式
まずフーリエ変換の基本式。
-
: フーリエ変換結果(N次元)X_n -
: フーリエ変換結果の次元n -
: 入力フレームの個数(2の累乗(理由は後述))N -
: 入力フレームの値x_k
前回記事との差分
前回記事では、簡単のため「N=
(前回実装したDTFTの計算量は
周波数[Hz]への変換
本記事でのフーリエ変換の結果は、周波数[Hz]単位ではなく
(
現状のままだと、
偶数部と奇数部にわける
計算量を削減するため、
以下の式を活用するので、先に示しておきます。
(
偶数部の計算
奇数部の計算
偶数部と同様に、
奇数部も、
再帰呼び出しによる高速化
偶数部も奇数部も、計算回数を
さらに計算回数を削減するためには、導き出した値をまた別のフーリエ変換と捉え、同様の計算式を再帰的に適用します。
まず偶数部(
(識別子を色付き(
この
偶数部(
奇数部(
以下のような結果となります。
そして、
仮に
(
実際のプログラムでは (
そして、
おおよそ規則性が理解できてきたかと思いますが、
-
の計算回数が式変形するごとに\Sigma となる1/2 - 2の累乗であれば、繰り返し行うことで最終的に計算回数が1回となり
が消える\Sigma
(FFTの入力フレーム数を2の累乗とするのはこのためです)- 式変形を繰り返すごとに変形前の計算結果を利用するため、
回の計算を各変形ごとに行う。変形回数はN であるため、計算回数の合計は{\log_2}N 回となるN{\log_2}N
- 式変形を繰り返すごとに変形前の計算結果を利用するため、
- 2の累乗であれば、繰り返し行うことで最終的に計算回数が1回となり
- 偶数部の
の計算が不要になるW_N - 式変形を繰り返す際、変形前の計算結果を利用することで、
の計算回数は各変形で一定してW_N される1/2
- 式変形を繰り返す際、変形前の計算結果を利用することで、
よって、再帰的に式変形を続け、最終的に
これで、Cooley-Tukey型FFTルーチンによる高速化の説明は終わりです。ポイントは「偶数部と奇数部に分割して再帰的に計算回数を削減していく」です。
Rustでの実装
ここからは実装の説明です。
概要
説明に合わせて実装を行うと、以下のようになります。
-
を先に計算x_k^{\color{purple}} -
x_k^{\color{purple}{\scriptscriptstyle 2m}} を計算,x_k^{\color{purple}{\scriptscriptstyle 2m+1}} - 計算結果を利用して、
,x_k^{\color{purple}{\scriptscriptstyle 2(2l)}} ,x_k^{\color{purple}{\scriptscriptstyle 2(2l+1)}} ,x_k^{\color{purple}{\scriptscriptstyle 2(2l)+1}} を計算x_k^{\color{purple}{\scriptscriptstyle 2(2l+1)+1}}
: -
となるまで2^\circ = N 回繰り返す\circ
-
-
をx_k^{\color{purple}{\scriptscriptstyle \circ}} に適用X_n
ただ、そのまま実装すると冗長となるため、バタフライ演算で実装します。
バタフライ演算
バタフライ演算とは、下図のように、
詳しくは引用元を参照いただければと思いますが、本実装で重要なのは、バタフライ演算の「特定のインデックスを交差して加算/減算するため、その計算結果を計算元のインデックスに代入しても計算結果に影響が出ない」という性質です。この性質を利用して元値を上書きしていく処理を「in-place演算」と呼び、追加メモリがほぼ不要となるため、処理の高速化も期待できます。
引用元:https://www.onosokki.co.jp/HP-WK/eMM_back/emm140.pdf
実装例
以下が実装例になります。
pub fn cooley_tukey_fft(frames: &[Complex<f64>]) -> Vec<Complex<f64>> {
let n = frames.len();
// 先頭bitから2進数中のゼロを取得
// nが2の累乗の場合"ゼロの個数=2の累乗時の指数"となる
let n_power2 = n.trailing_zeros();
// nが2の累乗であることを確認
assert_eq!(n, 2usize.pow(n_power2));
// 処理結果を格納するVector
// 最初は入力フレームの値(x_k)を代入
let mut x = frames.to_vec();
// whileループごとのインデックス数(x_kにおけるkを決める)
let mut calc_n = n;
// Wにおける指数の固定値を保持
let mut w_power = -2.0 * PI / (n as f64);
while calc_n > 1 {
let before_n = calc_n;
calc_n >>= 1;
for i in 0..calc_n {
// W_N^kの算出
let w = Complex::new(0.0, w_power * (i as f64)).exp();
// バタフライ演算(=偶数部と奇数部の演算)
//// whileループごとに処理パターンが増える(変形ごとに集約されたxを利用するため)
//// e.g. n = 8
//// [calc_n = 4]
//// 1パターン: x_k
//// i = k, 0 <= k <= N/2-1
//// [calc_n = 2]
//// 2パターン: x^(2m), x^(2m+1)
//// i = m, 0 <= m <= N/4-1
//// [calc_n = 1]
//// 4パターン: x^(2(2l)), x^(2(2l+1)+1), x^(2(2l)), x^(2(2l+1)+1)
//// i = l, 0 <= l <= N/8-1
for begin_i in (0..n).step_by(before_n) {
// 偶数部と奇数部のインデックスにはcalc_n分の間が空く
let even_i = begin_i + i;
let odd_i = begin_i + i + calc_n;
// xの値を計算
let x_even = x[even_i] + x[odd_i];
let x_odd = (x[even_i] - x[odd_i]) * w;
// 変形前の値を上書き(=in-place演算)
x[even_i] = x_even;
x[odd_i] = x_odd;
}
}
w_power *= 2.0;
}
// xのindexをlog2(n)ビットでビット反転することで正しい順序を取得
return bit_reversed_indexes(n_power2)
.iter()
.map(|&r_i| x[r_i])
.collect();
}
実は、バタフライ演算を実施すると、計算結果が入力元の順番通りにはなりません。
正しい順序とするため、ビット反転インデックス作成メソッド(bit_reversed_indexes)を追加します。
fn bit_reversed_indexes(power: u32) -> Vec<usize> {
let mut r_indexes: Vec<usize> = Vec::new();
// 最初に /0b0*10*/ を代入 ((power+1)番目のビット数が1)
let mut r_bit = 1 << power;
// power番目以下のビットを扱う
// 最初に /0b0*/ を代入
r_indexes.push(0);
while r_bit > 1 {
// ビット数を1つ右にずらす
r_bit >>= 1;
for j in 0..r_indexes.len() {
// whileループごとに、処理中の配列にr_bitを加算した値をすべてpushすることで、
// indexに対してpower番目までのビットを反転させたindex配列が作成される
// e.g. power = 3, r_bit = 2^3 = 8
// [r_bit = 8 = 0b1000]
// r[0] = r[0b000] = 0b000 ++
// [r_bit = 4 = 0b0100]
// r[0] = r[0b000] = 0b000
// r[1] = r[0b001] = 0b100 ++
// [r_bit = 2 = 0b0010]
// r[0] = r[0b000] = 0b000
// r[1] = r[0b001] = 0b100
// r[2] = r[0b010] = 0b010 ++
// r[3] = r[0b011] = 0b110 ++
// [r_bit = 1 = 0b0001]
// r[0] = r[0b000] = 0b000
// r[1] = r[0b001] = 0b100
// r[2] = r[0b010] = 0b010
// r[3] = r[0b011] = 0b110
// r[4] = r[0b100] = 0b001 ++
// r[5] = r[0b101] = 0b101 ++
// r[6] = r[0b110] = 0b011 ++
// r[7] = r[0b111] = 0b111 ++
r_indexes.push(r_indexes[j] | r_bit);
}
}
return r_indexes;
}
実際に動くFFTプログラムを作成できました!
周波数[Hz]の誤差
高速化の過程で、求める周波数[Hz]によっては数値がズレることがあります。各ベクトルの周波数[Hz]は、
さらなる高速化
バタフライ演算を適用したことで速くはなりましたが、さらなる高速化も可能です。
W_N^{k} の事前計算
高速化(1): この周期性を利用して、
あらかじめ1周期分を計算し、その計算結果を各インデックスに周期性に合わせて格納することで計算量を削減します。
高速化(2): 事前計算のインスタンス化
実際の音声技術プログラムでは、単一フレーム処理を行うのではなく、1音声データから同一サイズの複数フレームを作成して変換します。そのため、サイズが既にわかっている場合、事前計算結果をインスタンス変数として各フレームで共用することで高速化が期待できます。
高速化(3): 入力フレームの配列自体をin-place演算
入力フレーム配列自体をin-place演算し、元データを保持しないことで高速化が期待できます。
高速化(4): 4の累乗での高速化
ここまでは、
これらを活用すれば、
ただし、4の累乗にフレーム数を合わせると桁数が大幅に変わる可能性があるため、「2の累乗を、4の累乗と2の累乗に分割してそれぞれ高速化する」という、若干トリッキーな方法で実装します。
高速化(5): 最後のバタフライ演算の省略
最後のバタフライ演算では、計算部が必ず
その他(6): floatのgenerics化
インスタンス化に併せて、float値のビット数を柔軟に変更できるようにします。
含めていないもの
コード簡略化のため、以下の内容は含めていません。
- フーリエ逆変換
- 2(4)の累乗以外での変換
- 2次元配列の変換
実装
これらを踏まえた実装がこちらです。(Githubにも公開しています)
use num::Complex;
use num_traits::cast;
use num_traits::float::{Float, FloatConst};
use num_traits::identities::{one, zero};
pub struct Fft<T> {
// フーリエ変換結果の次元数
n: usize,
// ビットリバースしたインデックスを付与
r_indexes: Vec<usize>,
// 事前計算したW_N^kを保持
w: Vec<Complex<T>>,
}
// Float + FloatConst でfloat型を許可
impl<T: Float + FloatConst + std::fmt::Debug> Fft<T> {
pub fn new() -> Self {
Self {
n: 0,
r_indexes: vec![],
w: vec![],
}
}
pub fn setup(&mut self, n: usize) {
let n_power2 = n.trailing_zeros();
// nは2の累乗である前提
assert_eq!(n, 1 << n_power2, "len of n should be 2^x");
self.n = n;
self.r_indexes = self.calc_bit_reversed_indexes(n);
self.w = self.calc_w(n);
}
fn calc_bit_reversed_indexes(&mut self, n: usize) -> Vec<usize> {
let n_power2 = n.trailing_zeros();
// このメソッドは、nが2の累乗である前提で作成
assert_eq!(n, 1 << n_power2, "len of n should be 2^x");
let mut r_indexes = Vec::<usize>::with_capacity(n);
// 最初に /0b0*10*/ を代入 ((power+1)番目のビット数が1)
let mut r_bit = 1 << n_power2;
// n_power2番目以下のビットを扱う
// 2bitを1単位として扱い、最後に2が残る場合1bitで演算
// e.g. n_power2 = 5, r_bit = 2^5 = 32
// [r_bit = 32 = 0b100000]
// r[ 0] = r[0b00000] = 0b00000 ++
// [r_bit = 8 = 0b001000]
// r[ 0] = r[0b00000] = 0b00000
// r[ 1] = r[0b000(01)] = 0b(01)000 ++
// r[ 2] = r[0b000(10)] = 0b(10)000 ++
// r[ 3] = r[0b000(11)] = 0b(11)000 ++
// [r_bit = 2 = 0b000(01)0]
// r[ 0] = r[0b00000] = 0b00000
// r[ 1] = r[0b000(01)] = 0b(01)000
// r[ 2] = r[0b000(10)] = 0b(10)000
// r[ 3] = r[0b000(11)] = 0b(11)000
// r[ 4] = r[0b0(01)(00)] = 0b(00)(01)0 ++
// r[ 5] = r[0b0(01)(01)] = 0b(01)(01)0 ++
// r[ 6] = r[0b0(01)(10)] = 0b(10)(01)0 ++
// r[ 7] = r[0b0(01)(11)] = 0b(11)(01)0 ++
// r[ 8] = r[0b0(10)(00)] = 0b(00)(10)0 ++
// r[ 9] = r[0b0(10)(01)] = 0b(01)(10)0 ++
// :
// r[15] = r[0b0(11)(11)] = 0b(11)(11)0 ++
// [r_bit = 1 = 0b000001]
// r[ 0] = r[0b00000] = 0b00000
// r[ 1] = r[0b000(01)] = 0b(01)000
// :
// r[15] = r[0b0(11)(11)] = 0b(11)(11)0
// r[16] = r[0b(1)(00)(00)] = 0b(00)(00)(1) ++
// r[17] = r[0b(1)(00)(01)] = 0b(01)(00)(1) ++
// :
// r[31] = r[0b(1)(11)(11)] = 0b(11)(11)(1) ++
// 最初に /0b0*/ を追加
r_indexes.push(0);
// 4の累乗部として計算
while r_bit > 2 {
// ビット数を2つ右にずらす
r_bit >>= 2;
// 2bitを1単位とし、(01),(10),(11)を追加
let len = r_indexes.len();
for j in 0..len {
r_indexes.push(r_indexes[j] | r_bit);
}
for j in 0..len {
r_indexes.push(r_indexes[j] | r_bit << 1);
}
for j in 0..len {
r_indexes.push(r_indexes[j] | r_bit | r_bit << 1);
}
}
// 2が残る場合、最後に基底数2でバタフライ演算を行うため、indexも最後に算出
if r_bit == 2 {
for j in 0..r_indexes.len() {
r_indexes.push(r_indexes[j] | 1);
}
}
// in-place演算用にインデックスを加工
return self.convert_indexes_as_inplace(r_indexes);
}
fn convert_indexes_as_inplace(&mut self, r_indexes: Vec<usize>) -> Vec<usize> {
let mut nums = (0..r_indexes.len()).collect::<Vec<_>>();
// 整数列のインデックスへr_indexesでswapを実行
// その際のインデックスを保持して実データのswapに利用することで、結果的にr_indexesで整頓したことになる
return (0..r_indexes.len())
.map(|i| {
let r_i = r_indexes[i];
let swapped_r_i = (0..nums.len()).find(|&j| nums[j] == r_i).unwrap();
nums.swap(i, swapped_r_i);
swapped_r_i
})
.collect();
}
fn calc_w(&mut self, n: usize) -> Vec<Complex<T>> {
let n_power2 = n.trailing_zeros();
// このメソッドは、nが2の累乗である前提で作成
assert_eq!(n, 1 << n_power2, "len of n should be 2^x");
let mut w = Vec::with_capacity(n + 1);
// W_N^0=1
w.push(one());
// nが2以下の場合
if n <= 2 {
if n == 2 {
// W_N^N/2=-1
w.push(cast(-1.0).unwrap());
}
// W_N^N=1
w.push(one());
return w;
}
// nが2以下ではない場合(=nが4で割り切れる場合)
let q = n >> 2;
let h = n >> 1;
// 0~N/4(実直に計算)
for i in 1..q {
w.push(self.calc_part_w(n, i));
}
// W_N^N/4=-i
w.push(-Complex::i());
// N/4~N/2(計算結果を流用)
for i in q + 1..h {
let tmp = w[i - q];
w.push(Complex::new(tmp.im, -tmp.re));
}
// W_N^N/2=-1
w.push(cast(-1.0).unwrap());
// N/2~N(計算結果を流用)
for i in h + 1..n {
let tmp = w[i - h];
w.push(Complex::new(-tmp.re, -tmp.im));
}
// W_N^N=1
w.push(one());
return w;
}
fn calc_part_w(&mut self, n: usize, seq: usize) -> Complex<T> {
// e^(-i2π)をN分割。seqごとの値を取得
Complex::new(
zero(),
cast::<_, T>(-2.0).unwrap() * T::PI() / cast(n).unwrap() * cast(seq).unwrap(),
)
.exp()
}
pub fn process(&mut self, frames: &mut [Complex<T>]) {
let len = frames.len();
let len_power2 = len.trailing_zeros();
// このメソッドは、フレーム数が2の累乗である前提で作成
assert_eq!(len, 1 << len_power2, "len of frames should be 2^x");
// 1フレーム以下の場合、そのまま返す
if len <= 1 {
return;
}
// フレーム数がnと異なる場合、nをlenに置き換える
if len != self.n {
self.setup(len);
}
self.inner_process(frames);
}
// 処理結果を格納するVectorはin-palce演算なので直接もらう
fn inner_process(&mut self, x: &mut [Complex<T>]) {
let n = x.len();
// whileループごとのインデックス数(x_kにおけるkを決める)
let mut calc_n = n;
let mut calc_w_bit = 0;
// 4 or 2になるまでバタフライ演算(2bit)
while calc_n > 4 {
let before_n = calc_n;
calc_n >>= 2;
for i in 0..calc_n {
let w_i = i << calc_w_bit;
let (w1, w2, w3) = (self.w[w_i], self.w[w_i << 1], self.w[w_i * 3]);
for begin_i in (0..n).step_by(before_n) {
// 各インデックスにはcalc_n分の間が空く
let i0 = begin_i + i;
let i1 = i0 + calc_n;
let i2 = i1 + calc_n;
let i3 = i2 + calc_n;
// xの値を計算
// 以下の計算を短縮し高速化
// let x0 = x[i0] + x[i1] + x[i2] + x[i3];
// let x1 = (x[i0] = i() * x[i1] - x[i2] + i() * x[i3]) * w1;
// let x2 = (x[i0] - x[i1] + x[i2] - x[i3]) * w2;
// let x3 = (x[i0] + i() * x[i1] - x[i2] - i() * x[i3]) * w3;
let xi0_plus_xi2 = x[i0] + x[i2];
let xi0_minus_xi2 = x[i0] - x[i2];
let xi1_plus_xi3 = x[i1] + x[i3];
let xi1_minus_xi3 = x[i1] - x[i3];
let xi1_minus_xi3_i = Complex::new(-xi1_minus_xi3.im, xi1_minus_xi3.re);
// 変形前の値を上書き(=in-place演算)
x[i0] = xi0_plus_xi2 + xi1_plus_xi3;
x[i1] = (xi0_minus_xi2 - xi1_minus_xi3_i) * w1;
x[i2] = (xi0_plus_xi2 - xi1_plus_xi3) * w2;
x[i3] = (xi0_minus_xi2 + xi1_minus_xi3_i) * w3;
}
}
calc_w_bit += 2;
}
// 最後のバタフライ演算は処理を簡略化
if calc_n == 4 {
for i0 in (0..n).step_by(calc_n) {
let i1 = i0 + 1;
let i2 = i1 + 1;
let i3 = i2 + 1;
// xの値を計算
// 以下の計算を短縮し高速化
// let x0 = x[i0] + x[i1] + x[i2] + x[i3];
// let x1 = (x[i0] = i() * x[i1] - x[i2] + i() * x[i3]) * w1;
// let x2 = (x[i0] - x[i1] + x[i2] - x[i3]) * w2;
// let x3 = (x[i0] + i() * x[i1] - x[i2] - i() * x[i3]) * w3;
let xi0_plus_xi2 = x[i0] + x[i2];
let xi0_minus_xi2 = x[i0] - x[i2];
let xi1_plus_xi3 = x[i1] + x[i3];
let xi1_minus_xi3 = x[i1] - x[i3];
let xi1_minus_xi3_i = Complex::new(-xi1_minus_xi3.im, xi1_minus_xi3.re);
// 変形前の値を上書き(=in-place演算)
x[i0] = xi0_plus_xi2 + xi1_plus_xi3;
x[i1] = xi0_minus_xi2 - xi1_minus_xi3_i;
x[i2] = xi0_plus_xi2 - xi1_plus_xi3;
x[i3] = xi0_minus_xi2 + xi1_minus_xi3_i;
}
} else {
// calc_n = 2
for i in (0..n).step_by(calc_n) {
let even_i = i;
let odd_i = i + 1;
// xの値を計算
let x_even = x[even_i] + x[odd_i];
let x_odd = x[even_i] - x[odd_i];
// 変形前の値を上書き(=in-place演算)
x[even_i] = x_even;
x[odd_i] = x_odd;
}
}
// in-place演算として高速化するため、ビット反転をメモリスワップで行う
for (i, &r) in self.r_indexes.iter().enumerate() {
x.swap(i, r);
}
}
}
速度計測
おおむね速度は以下のようになりました。
単純なdtftから約1000倍速くなってます。いいですね。
dtft:
2,349,916 ns (≒ 2ms)
cooley_tukey_fft:
13,333ns (≒ 13µs)
Fft:
2,750ns (≒ 2µs)
※単一sin波(フレーム数:512, サンプリング周波数:16000[Hz])での計測
最後に
ここまで記事を読んでいただきありがとうございます。
FFTの記事を実際に取り組んで感じたのは、「FFTの説明を簡潔にするのは難しい」ということと、「高速化をオリジナルで思いつくのは難しい」ということです。
FFTがとっつきにくく見えるのは、奇数と偶数を交差したり、計算データを結果から逆算するといった「データフローの矢印が複雑である」ことと、バタフライ演算よりオイラーの公式のほうが理解に役立つのに、それがわかりにくいところかと思っています。計算式に苦手意識があるひとは、N=16あたりまで自分の手で書いてみると、理解が進むと思います。(自分も紙で計算しながら理解を進めました。)
また実装面では、速くしようとすればするほど既存記事やライブラリのロジックに寄ってしまい、自分の成果物とは言えないなと感じたため、ライブラリとして公開せず、簡潔な説明となるレベルに留めました。あわよくば、もっと頑張って実装してライブラリ公開を...!と思っていましたが、別の機会となりそうです。
それでは。
参考
Discussion