〰️

FFTの⊗的直観

2024/12/04に公開

離散フーリエ変換の高速アルゴリズムであるFFTはなぜ速いのか? について、テンソル積の意味でローカルな演算子に分解しているから、という直観的理解をする。実質的に量子フーリエ変換の回路をエミュレーションしていることになる。

DFT

N次元の離散フーリエを以下とする(規格化を一旦考えない)。

F[f](m)= \sum_{0\le n < N}f(n)\exp\left(-2\pi i \frac{nm}{N}\right)

逆変換は複素共役で

F^{-1}[g](n)= \sum_{0\le m < N}g(m)\exp\left(2\pi i \frac{nm}{N}\right)

規格化を除いてユニタリとなる。

F^{-1}F[f](n) = \sum_{0 \le m,n^\prime < N}f(n^\prime)\exp\left(-2\pi i \frac{m(n^\prime - n)}{N}\right) = \sum_{0 \le n^\prime < N} f(n^\prime)\delta_{nn^\prime}N = Nf(n)

FFT

DFTの係数行列は密なので、そのままだとO(N^2)の計算コストがかかる。FFTはN=N_kN_{k-1}\dots N_1へ因数分解できるとき、O((N_k+N_{k-1}+\dots N_1+k)(N_kN_{k-1}\dots N_1))へ高速化する。

この計算がいまいちピンと来なかったが、これは分割統治を テンソル積の意味で やっているとみなすと直観に適う。これを説明してみる。

N=N_kN_{k-1}\dots N_1の因数分解を前提とするということは、信号空間もまた\mathbb{C}^N = \mathbb{C}^{N_k}\otimes \mathbb{C}^{N_{k-1}}\otimes \dots \mathbb{C}^{N_1}のようなテンソル積分解ができることを意味する。ここで重要なのが、基底のインデックスを工夫することだ。

DFTの積分核には基底の番号が現れるので、\mathbb{C}^{N_k}\otimes \mathbb{C}^{N_{k-1}}\otimes \dots \mathbb{C}^{N_1}の基底ベクトルのインデックスを一つの整数インデックスに対応付ける必要がある。テンソル積分解された信号空間の基底ベクトルは、0 \le n_i < N_iからなるk個の添字で指定できる。そこでこの添字から単一整数添字への単射を作ればよい。

ここで二通りのインデックスを考える。

\begin{align*} \mathrm{IL}(k|n_k,n_{k-1}\dots n_1) = & n_k (N_{k-1}\dots N_1) + n_{k-1}(N_{k-2}\dots N_1) + \dots + n_2 N_1 + n_1 \\ = & \sum_{i=1}^k n_i \prod_{j=1}^{i-1} N_j\\ \mathrm{IR}(k|m_k,m_{k-1}\dots m_1) = & m_k + N_k m_{k-1} + \dots + (N_k\dots N_3)m_2 + (N_k\dots N_2)m_1 \\ = & \sum_{i=1}^k m_i \prod_{j=i+1}^k N_j \end{align*}

直感的には\mathrm{IL}は、各桁がN_k\dots N_1である位取りにしたがって、左をMSB、つまり最大桁としたものとなる。\mathrm{IR}は逆に右が最大桁となる。

このインデックスによるDFTを実行する。ここで、入力側の基底インデックスと出力側の基底インデックスを一旦別の物にする。 これが計算で効いてくる。テンソル積分解された基底添字で表されたインデックスをDFTに代入すると

\begin{align*} & F[f](\mathrm{IR}(k|m_k,\dots m_1)) \\ = & \sum_{\{n_i | 0\le n_i < N_i\}_i} f(\mathrm{IL}(k|n_k,\dots n_1)) \exp\left(-2\pi i \frac{\mathrm{IL}(k|n_k,\dots n_1)\mathrm{IR}(k|m_k,\dots m_1)}{N_k\dots N_1}\right) \end{align*}

となる。これを計算していく。

積分核の分解

積分核の指数にある

\frac{\mathrm{IL}(k|n_k,\dots n_1)\mathrm{IR}(k|m_k,\dots m_1)}{N_k\dots N_1}

の振る舞いが重要になる。これは以下のように再帰的に分解できる。ただし、ここに2\pi iがかかって指数関数に入れられるので、\mod 1で計算する。

\begin{align*} & \frac{\mathrm{IL}(k|n_k,\dots n_1)\mathrm{IR}(k|m_k,\dots m_1)}{N_k\dots N_1}\\ = & \frac{1}{N_k\dots N_1}\left(n_k(N_{k-1}\dots N_1)+\mathrm{IL}(k-1|n_{k-1},\dots n_1) \right)\left(m_k + N_k\mathrm{IR}(k-1|m_{k-1},\dots m_1) \right)\\ = & \frac{\mathrm{IL}(k-1|n_{k-1},\dots n_1)\mathrm{IR}(k-1|m_{k-1},\dots m_1)}{N_{k-1}\dots N_1} + \frac{m_k\mathrm{IL}(k-1|n_{k-1},\dots n_1) }{N_k\dots N_1} + \frac{m_kn_k}{N_k} +n_k\mathrm{IR}(k-1|m_{k-1},\dots m_1)\\ = & \frac{\mathrm{IL}(k-1|n_{k-1},\dots n_1)\mathrm{IR}(k-1|m_{k-1},\dots m_1)}{N_{k-1}\dots N_1} + \frac{m_k\mathrm{IL}(k-1|n_{k-1},\dots n_1) }{N_k\dots N_1} + \frac{m_kn_k}{N_k} \end{align*}

最後の初項にkが1ズレたものが出てきた。したがってこれを繰り返すと次のようになる。

\begin{align*} & \frac{\mathrm{IL}(k|n_k,\dots n_1)\mathrm{IR}(k|m_k,\dots m_1)}{N_k\dots N_1}\\ = & \frac{m_1n_1}{N_1} \\ + & \frac{m_2\mathrm{IL}(1|n_1) }{N_2N_1} + \frac{m_2n_2}{N_2} \\ + & \frac{m_3\mathrm{IL}(2|n_2,n_1) }{N_3N_2N_1} + \frac{m_3n_3}{N_3} \\ & \cdots \\ + & \frac{m_{k-1}\mathrm{IL}(k-2|n_{k-2},\dots n_1) }{N_{k-1}\dots N_1} + \frac{m_{k-1}n_{k-1}}{N_{k-1}} \\ + & \frac{m_k\mathrm{IL}(k-1|n_{k-1},\dots n_1) }{N_k\dots N_1} + \frac{m_kn_k}{N_k} \end{align*}

この式は複雑だが、\frac{m_in_i}{N_i}の項の手前ではn_iはもはや出現せず、この項の後にはm_iは出現しない、という良い性質がある。これによってn_iによる足し上げを逐次的に行ってよくなる。

\begin{align*} & F[f](\mathrm{IR}(k|m_k,\dots m_1)) \\ = & \sum_{0 \le n_1 < N_1}\exp\left(-2\pi i \frac{m_1n_1}{N_1} \right)\\ \times & \exp\left(-2\pi i\frac{m_2\mathrm{IL}(1|n_1)}{N_2N_1}\right)\sum_{0 \le n_2 < N_2}\exp\left(-2\pi i \frac{m_2n_2}{N_2} \right) \\ \times & \exp\left(-2\pi i \frac{m_3\mathrm{IL}(2|n_2,n_1) }{N_3N_2N_1} \right) \sum_{0 \le n_3 < N_3}\exp\left(-2\pi i \frac{m_3n_3}{N_3} \right)\\ & \dots \\ \times & \exp\left(-2\pi i \frac{m_{k-1}\mathrm{IL}(k-2|n_{k-2},\dots n_1) }{N_{k-1}\dots N_1}\right) \sum_{0 \le n_{k-1} < N_{k-1}} \exp\left(-2\pi i \frac{m_{k-1}n_{k-1}}{N_{k-1}} \right) \\ \times & \exp\left(-2\pi i \frac{m_k\mathrm{IL}(k-1|n_{k-1},\dots n_1) }{N_k\dots N_1} \right)\sum_{0 \le n_k < N_k} \exp\left(-2\pi i \frac{m_kn_k}{N_k} \right) \\ \times & f(\mathrm{IL}(k|n_k,\dots n_1)) \end{align*}

ここで、総和記号による足し上げは、それ以降のすべての項に掛かっていることに注意。
説明としては\frac{m_in_i}{N_i}N_i次元のDFTであるから、N次元DFTが、その因数サイズのDFT+\alphaに分解できたことになるが、しかしまだ複雑で一体どうやって加速されているのかピンとこない。

StringDiagramでのバタフライ演算

ここで、String Diagramの出番だ。この式をみると、演算の対象となっているのはf(\mathrm{IL}(k|n_k,\dots n_1))で、その成分添字は、\sum_{n_i}の前後でm_iに入れ替わっている。つまり、N_i次元DFTが実行されて添字が入れ替わった。それ以外の

\mathrm{Tw}(j) := \exp\left(-2\pi i \frac{m_j\mathrm{IL}(j-1|n_{j-1},\dots n_1) }{N_j\dots N_1}\right)

は複雑ではあるが、この項の前後で添字は置き換わってない。これはこの項はこの瞬間採用している基底ベクトルに対しては掛け算作用素のようになっていることを意味する。n_i \leftrightarrow m_iN_i次元DFTが実施されているときは、別にN_i次元の演算が起きているわけではない。この瞬間でもn_i,m_i以外の添字はずっと存在しているから、演算そのものはずっとN次元だ。N_i次元DFTがここに出現しているのは、その演算がn_i,m_i以外の添字には依存しないということで、つまりn_i,m_i以外の添字についても均一に作用しているということだ。これはテンソル積空間において、\mathrm{id}\otimes Fのような演算になっていることになる。

そういうことで、Fは、String Diagramで書くなら次のようになっている。

ただし、先のやり方で計算すると計算結果の添字はm_iつまり、\mathrm{IR}によるインデックスとして得られるので、インデックスを揃える写像を追加した。これは成分の並び替えだけなのでO(N)で実施できる。

これは量子フーリエ変換の回路と似ているが、そもそも数学的に同じものである。

よくFFTの演算はバタフライ演算と呼ばれるが、テンソル積分解とString Diagramの観点からは、テンソル積の意味でセパラブル/ローカルな演算子(と、掛け算作用素)への帰着が本質的だとわかる。一般の線形作用素はN^2の計算量がかかるが、セパラブルであればそれぞれのテンソル成分に関してだけ二乗であるような作用素に分解できるからだ。

例えば

F\otimes G = (\mathrm{id}\otimes G)\circ (F\otimes \mathrm{id})

は作用素テンソルの恒等式だが、計算量の観点では

(\mathrm{dim}F\times\mathrm{dim}G)^2 \\ \neq \mathrm{dim}F×(\mathrm{dim}G)^2 + (\mathrm{dim}F)^2×\mathrm{dim}G \\ = (\mathrm{dim}F +\mathrm{dim}G)(\mathrm{dim}F\times\mathrm{dim}G)

と等しくない。この差を利用できるためには、演算子がテンソル積の意味でローカルな演算子に分解できればよい。

最後に計算量見積もりを確認しよう。N_i次元DFTはO(N_i^2)だが、自明な作用が他の成分に作用しているから、これに他のN_jの1次がかかる。つまりそれぞれについてNN_iだ。\mathrm{Tw}(j)は掛け算なので\prod_i N_i = Nである。これがk-1層あるので、オーダーで見ればO((N_k+N_{k-1}+\dots N_1+k)(N_kN_{k-1}\dots N_1))となる。

これは分かりづらいが、よくあるN_i = 2, N=2^kと単純化する場合は

O((N_k+N_{k-1}+\dots N_1+k)(N_kN_{k-1}\dots N_1)) \\ = O((2k+k)(2^k))=O(N\log N)

となり、O(N^2)よりも高速になる。

(追記)実装してみる

N_i = 2で揃え、入力信号のl番目の成分を、まずl=\mathrm{IL}(k|n_k,\dots,n_1)だとする。そうするとn_iを参照するには、lを二進表記してビットにアクセスすればいい。

あとはString Diagramの通りに、jビット目の2d-DFTと、\mathrm{Tw}(j)を交互に実施すればよい。再帰は不要で、ループで処理する。jビット目の2d-DFTと言っても、テンソル積になっているから、すべての添字を舐める間に、添字のjビット目の値だけを置き換えてアクセスすることで実視する。

処理が完了すると、n_i添字がm_i添字に入れ替わるが、m_i添字は\mathrm{IR}(k|m_k,\dots,m_1)のそれなので、そのままアクセスすると順番が崩れている。添字をビット列反転して差し替えることでこれを修正する。

import numpy as np

# size幅のプレーンなビット列としてreverseしたintを返す。
def bitreverse(idx:int,size:int)->int:
    return int(f"{idx:0{size}b}"[::-1],2)

# 下位atビットまでを抽出した整数を返す。0-origin
def lowermasked(idx:int, at:int)->int:
    mask = (1 << (at+1)) - 1
    return idx & mask

# k-bit(0-origin)の値を返す。なお0がfalse
def getBit(idx:int,at:int) -> bool:
    return (idx & (1 << at)) != 0

# テンソル成分に対する2ddft(実質hadamard)
def hadamard(input_array,at:int,bitSize:int):
    new_array = np.zeros(input_array.size, dtype=np.complex128)
    for i in range(0,1<<bitSize):
        # アダマール(dim2DFT)を実施
        if getBit(i,at): # 立ってる
            new_array[i] = input_array[i & ~(1<<at)] - input_array[i | (1<<at)]
        else: # 折れてる
            new_array[i] = input_array[i & ~(1<<at)] + input_array[i | (1<<at)]
    return new_array

# 位相適用
def twidle(input_array,at:int,bitSize:int,conjecture:bool):
    new_array = np.zeros(input_array.size, dtype=np.complex128)
    for i in range(0,1<<bitSize):
        if getBit(i,at):
            # twidleを適用
            twidleFactor = (-1 if conjecture else 1)* lowermasked(i,at) / (1<<(at+1)) # at は0-origin
            new_array[i] = input_array[i] * np.exp(-2*np.pi*1j*twidleFactor)
        else:
            new_array[i] = input_array[i]
    return new_array

# 添字のbitreverse
def arrangeIndex(input_array,bitSize:int):
    new_array = np.zeros(input_array.size, dtype=np.complex128)
    for i in range(0,1<<bitSize):
        new_array[i] = input_array[bitreverse(i,bitSize)]
    return new_array

# fft実施。サブルーチンはO(N),繰り返しはbitSize = log NなのでNlogN
def fft(input_array,bitSize:int,conjecture:bool):
    array = input_array.copy()
    for j in reversed(range(0,bitSize)):
        array = hadamard(array,j,bitSize)
        if j != 0:
            array = twidle(array,j,bitSize,conjecture)
    return arrangeIndex(array,bitSize)

chatgpt氏に適当にテストコードを書いてもらおう。

import matplotlib.pyplot as plt

# 以下はGPT氏にテスト関数を書いてもらう
def generate_sine_wave(frequency, sample_rate, duration):
    """
    正弦波データを生成
    :param frequency: 周波数(Hz)
    :param sample_rate: サンプリング周波数(Hz)
    :param duration: 信号の長さ(秒)
    :return: 時間軸, 正弦波信号
    """
    t = np.linspace(0, duration, int(sample_rate * duration), endpoint=False)
    signal = np.sin(2 * np.pi * frequency * t)
    return t, signal

def plot_spectrum(signal, sample_rate, fft_func, title="Spectrum"):
    """
    信号のスペクトルをプロット
    :param signal: 入力信号
    :param sample_rate: サンプリング周波数(Hz)
    :param fft_func: FFT関数
    :param title: プロットのタイトル
    """
    N = len(signal)
    fft_result = fft_func(signal, int(np.log2(N)), conjecture=False)
    freq = np.fft.fftfreq(N, 1 / sample_rate)

    plt.figure(figsize=(10, 6))
    plt.plot(freq[:N // 2], np.abs(fft_result)[:N // 2])  # 正の周波数のみプロット
    plt.title(title)
    plt.xlabel("Frequency (Hz)")
    plt.ylabel("Amplitude")
    plt.grid()
    plt.show()

sample_rate = 1024  # サンプリング周波数 (Hz)
duration = 1.0  # 信号の長さ (秒)
freq1, freq2 = 50, 200  # 正弦波の周波数 (Hz)

t, sine_wave1 = generate_sine_wave(freq1, sample_rate, duration)
_, sine_wave2 = generate_sine_wave(freq2, sample_rate, duration)
signal = sine_wave1 + sine_wave2  # 2つの正弦波を加算

# 自作FFTの結果をプロット
plot_spectrum(signal, sample_rate, fft, title="Spectrum (Custom FFT)") # 自作の

# NumPy FFTの結果をプロット
def numpy_fft_wrapper(signal, bitSize, conjecture):
    return np.fft.fft(signal) # npについてるやつ

plot_spectrum(signal, sample_rate, numpy_fft_wrapper, title="Spectrum (NumPy FFT)") 


Discussion