Zenn
🌟

score 関数による Wasserstein 距離の bound の導出

に公開

Statementと概要

[3]では拡散過程に関してデータの分布p0p_0とscore関数によって得られる分布q0q_0間のWasserstein距離

W2(μ,ν):=inf{Rd×Rdxy2dγ:γΠ(μ,ν)}12.W_2(\mu,\nu) := \inf \left\{ \int_{\mathbb{R}^d \times \mathbb{R}^d} \|x-y\|^2 d \gamma : \gamma \in \Pi (\mu,\nu) \right\}^{\frac{1}{2}}.

がscore関数sθ(x,t)s_\theta(x,t)とそのLipschiz定数LfL_f

f(x,t)f(y,t)Lf(t)xy\|f(x,t) - f(y,t) \| \leq L_f(t) \|x-y\|

と片側Lipschiz定数LsL_s

(sθ(x,t)sθ(y,t))(xy)Ls(t)xy2(s_\theta(x,t) - s_\theta(y,t))(x-y) \leq L_s(t) \|x-y\|^2

によって

W2(p0,q0)0Tg(t)2I(t)Ept[logpt(x)sθ(x,t)2]12dt+I(T)W2(pT,qT)W_2(p_0, q_0) \leq \int_0^T g(t)^2 I(t) \mathbb{E}_{p_t} \left[ \|\nabla \log p_t(x) - s_\theta(x,t)\|^2 \right]^{\frac{1}{2}} dt + I(T) W_2(p_T, q_T)

I(t):=exp(0t(Lf(r)+Ls(r)g(r)))I(t):=exp(\int_0^t(L_f(r)+L_s(r)g(r)))

とboundされると主張している(g(t)は拡散係数,Theorem 1)。

[3]で引用されている[4]の定理の証明を含めてその導出を解説する

pdf版
https://drive.google.com/file/d/14OhWSc7bG-Gb6E-bGsXj0ZSTzd_6YIbe/view?usp=sharing

証明解説

Wasserstein距離の微分(Theorem 5.24 of [4])

連続の式 tρti+(vtiρti)=0\partial_t \rho_t^i +\nabla\cdot(\mathbf{v}_t^i \rho_t^i )=0

に従い時間tによって発展する2つの分布ρt1,ρt2\rho_t^1,\rho_t^2間のp-Wasserstein距離は

ddtWpp(ρt1,ρt2)=ϕtvt1ρt1dx+ψtvt2ρt2dx\frac{d}{dt}W_p^p(\rho_t^1,\rho_t^2)=\int\nabla \phi_t \mathbf{v}_t^1 \rho_t^1 dx +\int \nabla \psi_t \mathbf{v}_t^2 \rho_t^2 dx

ϕt,ψt\phi_t,\psi_tはKantorovich potential

証明(と説明):

Wasserstein距離を解とする分布関数μ,ν\mu,\nu間の最適輸送問題

infγ[X×Yd(x,y)dγ(μ,ν)]\inf_\gamma[ \int_{X \times Y} d(x,y)d\gamma(\mu,\nu)]

(Xdγ=μ,Ydγ=ν)(\int_X d\gamma=\mu, \int_Y d\gamma=\nu )

の双対問題はKantorovich
potentialϕt,ψt\phi_t,\psi_t(未定定数に由来)を探す問題として

maxϕ,ψ[1pWpp(μ,ν)YϕμXψν]0\max_{\phi,\psi}[\frac{1}{p}W_p^p(\mu,\nu) -\int_Y \phi \mu -\int_X \psi \nu ] \ge 0

(1pWpp(μ,ν)=ϕdμ+ψdν(XY(ψ+ϕd(x,y))dγ(x,y))\frac{1}{p}W_p^p(\mu,\nu)=\int \phi d\mu +\int \psi d\nu -(\int_{X Y} (\psi + \phi -d(x,y))d\gamma(x,y))
から) と書くことができることから

ddtWpp(ρt1,ρt2)t=t0=ddt(ϕt0ρt1+ψt0ρt2)\frac{d}{dt}W_p^p(\rho^1_t,\rho^2_t)|_{t=t_0}=\frac{d}{dt}(\int \phi_{t_0} \rho^1_t+\int \psi_{t_0}\rho^2_t)

=ϕt0tρt1t0+ψt0tρt2t0=\int \phi_{t_0} \partial_t \rho^1_t|_{t_0} +\int \psi_{t_0} \partial_t \rho^2_t|_{t_0}

=ϕt0(ρt1vt1)t0ψt0(ρt2vt2)t0=-\int \phi_{t_0} \nabla(\rho^1_t\mathbf{v^1_t})|_{t_0}-\int \psi_{t_0} \nabla(\rho^2_t\mathbf{v^2_t})|_{t_0}

=ϕt0(ρt01vt01)+ψt0(ρt02vt02)=\int \nabla\phi_{t_0} (\rho^1_{t_0}\mathbf{v^1_{t_0}})+\int \nabla\psi_{t_0} (\rho^2_{t_0}\mathbf{v^2_{t_0}})|

最適輸送Tに対してϕt(x)=xT(x)=xy,ψt(y)=ySt(y)\nabla \phi_t(x)=x-T(x)=x-y, \nabla \psi_t(y)=y-S_t(y)とおくことで

=(xy)(vt1ρt1(x)vt2ρt2(y))dx=\int(x-y)(\mathbf{v_t^1}\rho_t^1(x)-\mathbf{v_t^2}\rho_t^2(y)) dx

と書ける。([3]論文の式(27))

この結果は連続の式がコンパクトな領域で成り立ち、任意のtに対してρti\rho^i_tが可測(ρti<<Ld)\rho^i_t<<\mathscr{L}^d)、絶対連続の場合であり、
より一般にρti\rho^i_tがLipschizの場合に関して[@OTAM-cvgmt]では論じられている。

Lemma 2

Eπt[(xy)(logqt(y)logpt(x))]E_{\pi_t}[(x-y)\cdot(\nabla \log q_t(y)-\nabla \log p_t(x))]は非正

説明:

Breinerの定理からptp_tからqtq_tへの最適輸送写像の凸関数TtT_tに関するKantorovichポテンシャルϕ=Tt\nabla\phi=T_tが存在して
連続な増加関数fとに対して半径Rの超球BRB_Rに対して

BRxBR(xy)(1qt(y)f(qt)(y)1pt(x)f(pt)(x))dπt(x,y)\int_{B_R x B_R}(x-y)(\frac{1}{q_t(y)}\nabla f(q_t)(y)-\frac{1}{p_t(x)}\nabla f(p_t)(x)) d\pi_t(x,y)

と書け、[1]のTheorem 1の証明の式(7)から

=BR(f(qt))(y)(yϕt)dy+BR(f(pt))(x)(xϕt)dx=\int_{B_R}\nabla(f(q_t))(y) (y-\nabla \phi^*_t)dy+\int_{B_R}\nabla(f(p_t))(x) (x-\nabla \phi_t)dx

=BRf(qt)(Δϕtd)dyy=Rf(qt)(ϕt(y)yydy)BRf(pt)(Δϕtd)dxx=Rf(pt)(ϕt(x)xxdx)=\int_{B_R}f(q_t)(\Delta \phi^*_t-d)dy -\int_{|y|=R} f(q_t)(\nabla \phi^*_t(y)\frac{y}{|y|}dy) \int_{B_R}f(p_t)(\Delta \phi_t-d)dx -\int_{|x|=R} f(p_t)(\nabla \phi_t(x)\frac{x}{|x|}dx)
(RR\rightarrow\inftyとする)
=BRf(qt)(Δϕtd)dy+BRf(pt)(Δϕtd)dx=\int_{B_R}f(q_t)(\Delta \phi^*_t-d)dy +\int_{B_R}f(p_t)(\Delta \phi_t-d)dx

fとして恒等関数をとると1qt(y)f(qt)=logqt(y)\frac{1}{q_t(y)}\nabla f(q_t)=\log q_t(y)となり

Eπt[(xy)(logqt(y)logpt(x))]=Ept[Δϕt+Δϕt(ϕt)2d]E_{\pi_t}[(x-y)\cdot(\nabla \log q_t(y)-\nabla \log p_t(x))]=-E_{p_t}[\Delta \phi_t+\Delta\phi_t^*(\nabla \phi_t)-2d]

とかける(ϕt\phi_t^*ϕt\phi_tのconvex condugete)。この最後の式は[@BOLLEY20122430]のLemma3.2では2ϕ(x)\nabla^2\phi(x)が(n次元の)直交行列OOと正定値対角行列DDODOODO^*と書けることから

ϕ(ϕ(x))=x\nabla\phi^*(\nabla\phi(x))=x (最大化引数)

2ϕ(ϕ(x))2ϕ(x)=Id\nabla^2\phi^*(\nabla\phi(x))\nabla^2\phi(x)=Id

2ϕ(ϕ(x))=(2ϕ(x))1=OD1O\nabla^2\phi^*(\nabla\phi(x))=(\nabla^2\phi(x))^{-1}=OD^{-1}O^*

Δϕ(x)+Δϕ(x)(ϕ(x))2n=idi+i1di2n=i(di+1di)0\Delta\phi(x)+\Delta\phi^*(x)(\nabla \phi(x))-2n=\sum_i d_i + \sum_i \frac{1}{d_i}-2n =\sum_i (d_i +\frac{1}{d_i}) \ge 0

となることから負になると説明されている。(次元が登場するところが曲率次元条件に似ている)

Lemma 1

Eπt[(xy)(vqt(y)vpt(x))]W2(pt,qt){(Lf+Lsg2)W2(pt,qt)+g2b1/2}E_{\pi_t}[(x-y)(v_{q_t}(y)-v_{p_t}(x))]\le W_2(p_t,q_t)\{(L_f+L_sg^2)W_2(p_t,q_t)+g^2b^{1/2}\}

b:=Ept[logpt(x)sθ(x,t)2]b:=E_{p_t}[|\nabla \log p_t(x)-s_\theta (x,t)|^2]

証明解説

Fokker-Plank方程式から
Eπt[(xy)(vqt(y)vpt(x))]=Eπt[(xy)(f(y,t)f(x,t))]E_{\pi_t}[(x-y)(v_{q_t}(y)-v_{p_t}(x))]=E{\pi_t}[(x-y)(f(y,t)-f(x,t))]
+g2Eπt[(xy)(logpt(x)sθ(y,t))]+g22Eπt[(xy)(logqt(y)logpt(x))]+g^2 E_{\pi_t}[(x-y)(\nabla \log p_t(x)-s_\theta(y,t))]+\frac{g^2}{2}E{\pi_t}[(x-y)(\nabla \log q_t(y)-\nabla \log p_t(x))]

Lemma 2から3番目の項は0以下になる。最初の項はfのLipschitz性から

Eπt[(xy)(f(y,t)f(x,t))]LfEπt[xy2]=LfW22(pt,qt)E{\pi_t}[(x-y)(f(y,t)-f(x,t))]\le L_f E{\pi_t}[|x-y|^2]=L_f W_2^2 (p_t,q_t)

2番めの項 g2Eπt[(xy)(logpt(x)sθ(y,t))]g^2 E_{\pi_t}[(x-y)(\nabla \log p_t(x)-s_\theta(y,t))]

I1:=g2Eπt[(xy)(sθ(x,t)sθ(y,t))]I_1:=g^2 E_{\pi_t}[(x-y)(s_\theta(x,t)-s_\theta(y,t))]

I2:=g2Eπt[(xy)(logpt(x)sθ(x,t))]I_2:=g^2 E_{\pi_t}[(x-y)(\nabla \log p_t(x)-s_\theta(x,t))]

コーシー・シュワルツの不等式から

I2g2Eπt[xy2]12Eπt[logpt(x)sθ(x,t)2]12I_2\le g^2 E_{\pi_t}[|x-y|^2]^{\frac{1}{2}} E_{\pi_t}[|\nabla \log p_t(x)-s_\theta(x,t)|^2]^\frac{1}{2}

さらに

Eπt[logpt(x)sθ(x,t)2]=Ept[logpt(x)sθ(x,t)2]E_{\pi_t}[|\nabla \log p_t(x)-s_\theta(x,t)|^2]=E_{p_t}[|\nabla \log p_t(x)-s_\theta(x,t)|^2]

なので

I1+I2g(t)W2(pt,qt){LsW2(pt,qt)+b(t)12}I_1+I_2\le g(t)W_2(p_t,q_t)\{L_sW_2(p_t,q_t)+b(t)^\frac{1}{2}\}

Theorem 1(主定理)

W2(p0,q0)0Tg(t)2I(t)Ept[logpt(x)sθ(x,t)2]12dt+I(T)W2(pT,qT)W_2(p_0, q_0) \leq \int_0^T g(t)^2 I(t) \mathbb{E}_{p_t} \left[ \|\nabla \log p_t(x) - s_\theta(x,t)\|^2 \right]^{\frac{1}{2}} dt + I(T) W_2(p_T, q_T)

証明:

Wasserstein距離の微分(Theorem 5.24 of [4])とLemma
1の両辺からW2(pt,qt)W_2(p_t,q_t)を割って

ddtW2(pt,qt)(Lf+Lsg2)W2(pt,qt)+g2b1/2-\frac{d}{dt}W_2(p_t,q_t)\le (L_f+L_sg^2)W_2(p_t,q_t)+g^2b^{1/2}

ここで

I(t):=exp(0t(Lf+Lsg2)dr)I(t):=\exp(\int_0^t (L_f+L_sg^2)dr)

b(t):=Ept[logpt(x)sθ(x,t)2]b(t):=E_{p_t}[|\nabla \log p_t(x)-s_\theta(x,t)|^2]

と置くとddtI(t)=(Lf+Lsg2)I(t)\frac{d}{dt}I(t)=(L_f+L_sg^2)I(t)

ddt(I(t)W2(pt,qt))g2b1/2-\frac{d}{dt}(I(t)W_2(p_t,q_t)) \le g^2b^{1/2}

これを積分して

I(0)W2(p0,q0)I(T)W2(pT,qT)0Tg(t)2b(t)1/2I(t)dtI(0)W_2(p_0,q_0)-I(T)W_2(p_T,q_T) \le \int_0^T g(t)^2b(t)^{1/2}I(t)dt

I(0)=1とすると

W2(p0,q0)0Tg(t)2I(t)Ept[logpt(x)sθ(x,t)2]12dt+I(T)W2(pT,qT)W_2(p_0, q_0) \leq \int_0^T g(t)^2 I(t) \mathbb{E}_{p_t} \left[ \|\nabla \log p_t(x) - s_\theta(x,t)\|^2 \right]^{\frac{1}{2}} dt + I(T) W_2(p_T, q_T)

が結論付けられる。

注意点

Lipschiz定数の推定はNP hard問題らしい

感想

Wasserstein距離がscore関数の2乗誤差で抑えられるのは統一的な見方ができるのかもしれないが、一方でWasserstein score functionというものも提唱されていて[5]どういう関係があるのだろうか。

次元が絡むところから曲率次元条件と何らかの関係があるのかもしれない。

参考文献

[1] Fran¸cois Bolley and Jos´e A. Carrillo. Nonlinear diffusion: Geodesic convexity is equivalent
to wasserstein contraction, 2014.
https://arxiv.org/abs/1309.1932

[2] Fran¸cois Bolley, Ivan Gentil, and Arnaud Guillin. Convergence to equilibrium in wasserstein distance for fokker–planck equations. Journal of Functional Analysis, Vol. 263, No. 8,
pp. 2430–2457, 2012.
https://arxiv.org/abs/1110.3606

[3] Dohyun Kwon, Ying Fan, and Kangwook Lee. Score-based generative modeling secretly minimizes the wasserstein distance, 2022.
https://arxiv.org/abs/2212.06359

[4] Filippo Santambrogio. Optimal transport for applied mathematicians. calculus of variations, pdes and modeling. 2015.
https://link.springer.com/book/10.1007/978-3-319-20828-2

[5] Amari Shun-ichi and Matsuda Takeru. Wasserstein statistics in one-dimensional location
scale models, 2022.
https://link.springer.com/article/10.1007/s10463-021-00788-1#citeas
https://arxiv.org/abs/2007.11401

https://arxiv.org/abs/1910.11248
https://arxiv.org/abs/2307.12508

Discussion

ログインするとコメントできます