🌟

score 関数による Wasserstein 距離の bound の導出

2024/05/04に公開

Statementと概要

[3]では拡散過程に関してデータの分布p_0とscore関数によって得られる分布q_0間のWasserstein距離

W_2(\mu,\nu) := \inf \left\{ \int_{\mathbb{R}^d \times \mathbb{R}^d} \|x-y\|^2 d \gamma : \gamma \in \Pi (\mu,\nu) \right\}^{\frac{1}{2}}.

がscore関数s_\theta(x,t)とそのLipschiz定数L_f

\|f(x,t) - f(y,t) \| \leq L_f(t) \|x-y\|

と片側Lipschiz定数L_s

(s_\theta(x,t) - s_\theta(y,t))(x-y) \leq L_s(t) \|x-y\|^2

によって

W_2(p_0, q_0) \leq \int_0^T g(t)^2 I(t) \mathbb{E}_{p_t} \left[ \|\nabla \log p_t(x) - s_\theta(x,t)\|^2 \right]^{\frac{1}{2}} dt + I(T) W_2(p_T, q_T)

I(t):=exp(\int_0^t(L_f(r)+L_s(r)g(r)))

とboundされると主張している(g(t)は拡散係数,Theorem 1)。

[3]で引用されている[4]の定理の証明を含めてその導出を解説する

pdf版
https://drive.google.com/file/d/14OhWSc7bG-Gb6E-bGsXj0ZSTzd_6YIbe/view?usp=sharing

証明解説

Wasserstein距離の微分(Theorem 5.24 of [4])

連続の式 \partial_t \rho_t^i +\nabla\cdot(\mathbf{v}_t^i \rho_t^i )=0

に従い時間tによって発展する2つの分布\rho_t^1,\rho_t^2間のp-Wasserstein距離は

\frac{d}{dt}W_p^p(\rho_t^1,\rho_t^2)=\int\nabla \phi_t \mathbf{v}_t^1 \rho_t^1 dx +\int \nabla \psi_t \mathbf{v}_t^2 \rho_t^2 dx

\phi_t,\psi_tはKantorovich potential

証明(と説明):

Wasserstein距離を解とする分布関数\mu,\nu間の最適輸送問題

\inf_\gamma[ \int_{X \times Y} d(x,y)d\gamma(\mu,\nu)]

(\int_X d\gamma=\mu, \int_Y d\gamma=\nu )

の双対問題はKantorovich
potential\phi_t,\psi_t(未定定数に由来)を探す問題として

\max_{\phi,\psi}[\frac{1}{p}W_p^p(\mu,\nu) -\int_Y \phi \mu -\int_X \psi \nu ] \ge 0

(\frac{1}{p}W_p^p(\mu,\nu)=\int \phi d\mu +\int \psi d\nu -(\int_{X Y} (\psi + \phi -d(x,y))d\gamma(x,y))
から) と書くことができることから

\frac{d}{dt}W_p^p(\rho^1_t,\rho^2_t)|_{t=t_0}=\frac{d}{dt}(\int \phi_{t_0} \rho^1_t+\int \psi_{t_0}\rho^2_t)

=\int \phi_{t_0} \partial_t \rho^1_t|_{t_0} +\int \psi_{t_0} \partial_t \rho^2_t|_{t_0}

=-\int \phi_{t_0} \nabla(\rho^1_t\mathbf{v^1_t})|_{t_0}-\int \psi_{t_0} \nabla(\rho^2_t\mathbf{v^2_t})|_{t_0}

=\int \nabla\phi_{t_0} (\rho^1_{t_0}\mathbf{v^1_{t_0}})+\int \nabla\psi_{t_0} (\rho^2_{t_0}\mathbf{v^2_{t_0}})|

最適輸送Tに対して\nabla \phi_t(x)=x-T(x)=x-y, \nabla \psi_t(y)=y-S_t(y)とおくことで

=\int(x-y)(\mathbf{v_t^1}\rho_t^1(x)-\mathbf{v_t^2}\rho_t^2(y)) dx

と書ける。([3]論文の式(27))

この結果は連続の式がコンパクトな領域で成り立ち、任意のtに対して\rho^i_tが可測(\rho^i_t<<\mathscr{L}^d)、絶対連続の場合であり、
より一般に\rho^i_tがLipschizの場合に関して[@OTAM-cvgmt]では論じられている。

Lemma 2

E_{\pi_t}[(x-y)\cdot(\nabla \log q_t(y)-\nabla \log p_t(x))]は非正

説明:

Breinerの定理からp_tからq_tへの最適輸送写像の凸関数T_tに関するKantorovichポテンシャル\nabla\phi=T_tが存在して
連続な増加関数fとに対して半径Rの超球B_Rに対して

\int_{B_R x B_R}(x-y)(\frac{1}{q_t(y)}\nabla f(q_t)(y)-\frac{1}{p_t(x)}\nabla f(p_t)(x)) d\pi_t(x,y)

と書け、[1]のTheorem 1の証明の式(7)から

=\int_{B_R}\nabla(f(q_t))(y) (y-\nabla \phi^*_t)dy+\int_{B_R}\nabla(f(p_t))(x) (x-\nabla \phi_t)dx

=\int_{B_R}f(q_t)(\Delta \phi^*_t-d)dy -\int_{|y|=R} f(q_t)(\nabla \phi^*_t(y)\frac{y}{|y|}dy) \int_{B_R}f(p_t)(\Delta \phi_t-d)dx -\int_{|x|=R} f(p_t)(\nabla \phi_t(x)\frac{x}{|x|}dx)
(R\rightarrow\inftyとする)
=\int_{B_R}f(q_t)(\Delta \phi^*_t-d)dy +\int_{B_R}f(p_t)(\Delta \phi_t-d)dx

fとして恒等関数をとると\frac{1}{q_t(y)}\nabla f(q_t)=\log q_t(y)となり

E_{\pi_t}[(x-y)\cdot(\nabla \log q_t(y)-\nabla \log p_t(x))]=-E_{p_t}[\Delta \phi_t+\Delta\phi_t^*(\nabla \phi_t)-2d]

とかける(\phi_t^*\phi_tのconvex
condugete)。この最後の式は[@BOLLEY20122430]のLemma3.2では\nabla^2\phi(x)が(n次元の)直交行列Oと正定値対角行列DODO^*と書けることから

\nabla\phi^*(\nabla\phi(x))=x

\nabla^2\phi^*(\nabla\phi(x))\nabla^2\phi(x)=Id

\nabla^2\phi^*(\nabla\phi(x))=(\nabla^2\phi(x))^{-1}=OD^{-1}O^*

\Delta\phi(x)+\Delta\phi^*(x)(\nabla \phi(x))-2n=\sum_i d_i + \sum_i \frac{1}{d_i}-2n =\sum_i (d_i +\frac{1}{d_i}) \ge 0

となることから負になると説明されている。(次元が登場するところが曲率次元条件に似ている)

Lemma 1

E_{\pi_t}[(x-y)(v_{q_t}(y)-v_{p_t}(x))\le W_2(p_t,q_t)\{(L_f+L_sg^2)W_2(p_t,q_t)+g^2b^{1/2}\}]

b:=E_{p_t}[|\nabla \log p_t(x)-s_\theta (x,t)|^2]

証明解説

Fokker-Plank方程式から
E_{\pi_t}[(x-y)(v_{q_t}(y)-v_{p_t}(x))]=E{\pi_t}[(x-y)(f(y,t)-f(x,t))]
+g^2 E_{\pi_t}[(x-y)(\nabla \log p_t(x)-s_\theta(y,t))]+\frac{g^2}{2}E{\pi_t}[(x-y)(\nabla \log q_t(y)-\nabla \log p_t(x))]

Lemma 2から3番目の項は0以下になる。最初の項はfのLipschitz性から

E{\pi_t}[(x-y)(f(y,t)-f(x,t))]\le L_f E{\pi_t}[|x-y|^2]=L_f W_2^2 (p_t,q_t)

2番めの項 g^2 E_{\pi_t}[(x-y)(\nabla \log p_t(x)-s_\theta(y,t))]

I_1:=g^2 E_{\pi_t}[(x-y)(s_\theta(x,t)-s_\theta(y,t))]

I_2:=g^2 E_{\pi_t}[(x-y)(\nabla \log p_t(x)-s_\theta(x,t))]

コーシー・シュワルツの不等式から

I_2\le g^2 E_{\pi_t}[|x-y|^2]^{\frac{1}{2}} E_{\pi_t}[|\nabla \log p_t(x)-s_\theta(x,t)|^2]^\frac{1}{2}

さらに

E_{\pi_t}[|\nabla \log p_t(x)-s_\theta(x,t)|^2]=E_{p_t}[|\nabla \log p_t(x)-s_\theta(x,t)|^2]

なので

I_1+I_2\le g(t)W_2(p_t,q_t)\{L_sW_2(p_t,q_t)+b(t)^\frac{1}{2}\}

Theorem 1(主定理)

W_2(p_0, q_0) \leq \int_0^T g(t)^2 I(t) \mathbb{E}_{p_t} \left[ \|\nabla \log p_t(x) - s_\theta(x,t)\|^2 \right]^{\frac{1}{2}} dt + I(T) W_2(p_T, q_T)

証明:

Wasserstein距離の微分(Theorem 5.24 of [4])とLemma
1の両辺からW_2(p_t,q_t)を割って

-\frac{d}{dt}W_2(p_t,q_t)\le (L_f+L_sg^2)W_2(p_t,q_t)+g^2b^{1/2}

ここで

I(t):=\exp(\int_0^t (L_f+L_sg^2)dr)

b(t):=E_{p_t}[|\nabla \log p_t(x)-s_\theta(x,t)|^2]

と置くと\frac{d}{dt}I(t)=(L_f+L_sg^2)I(t)

-\frac{d}{dt}(I(t)W_2(p_t,q_t)) \le g^2b^{1/2}

これを積分して

I(0)W_2(p_0,q_0)-I(T)W_2(p_T,q_T) \le \int_0^T g(t)^2b(t)^{1/2}I(t)dt

I(0)=1とすると

W_2(p_0, q_0) \leq \int_0^T g(t)^2 I(t) \mathbb{E}_{p_t} \left[ \|\nabla \log p_t(x) - s_\theta(x,t)\|^2 \right]^{\frac{1}{2}} dt + I(T) W_2(p_T, q_T)

が結論付けられる。

注意点

Lipschiz定数の推定はNP hard問題らしい

感想

Wasserstein距離がscore関数の2乗誤差で抑えられるのは統一的な見方ができるのかもしれないが、
一方でWasserstein score
functionというものも提唱されていて[5]どういう関係があるのだろうか。

次元が絡むところから曲率次元条件と何らかの関係があるのかもしれない。

参考文献

[1] Fran¸cois Bolley and Jos´e A. Carrillo. Nonlinear diffusion: Geodesic convexity is equivalent
to wasserstein contraction, 2014.
https://arxiv.org/abs/1309.1932

[2] Fran¸cois Bolley, Ivan Gentil, and Arnaud Guillin. Convergence to equilibrium in wasserstein distance for fokker–planck equations. Journal of Functional Analysis, Vol. 263, No. 8,
pp. 2430–2457, 2012.
https://arxiv.org/abs/1110.3606

[3] Dohyun Kwon, Ying Fan, and Kangwook Lee. Score-based generative modeling secretly minimizes the wasserstein distance, 2022.
https://arxiv.org/abs/2212.06359

[4] Filippo Santambrogio. Optimal transport for applied mathematicians. calculus of variations, pdes and modeling. 2015.
https://link.springer.com/book/10.1007/978-3-319-20828-2

[5] Amari Shun-ichi and Matsuda Takeru. Wasserstein statistics in one-dimensional location
scale models, 2022.
https://link.springer.com/article/10.1007/s10463-021-00788-1#citeas

Discussion