非負値テンソル因子分解(Non-negative Tensor Factorization, NTF)は、非負の要素を持つテンソルを、非負の要素を持つテンソルの積に分解する手法です。Non-negative Matrix Factorization (NMF)の拡張手法です。
以下のように定義されます。
\mathbf{F} \simeq \boldsymbol{U} \otimes \boldsymbol{V} \otimes \boldsymbol{W} \\
\mathbf{F} \in \mathbb{R}_{+}^{R \times S \times T}, \; \boldsymbol{U} \in \mathbb{R}_{+}^{R \times M}, \; \boldsymbol{V} \in \mathbb{R}_{+}^{S \times M}, \; \boldsymbol{W} \in \mathbb{R}_{+}^{T \times M} \\
M \lt \min(R, S, T) \\
ここで、\mathbf{F} は観測テンソルです。これを \boldsymbol{U}, \boldsymbol{V}, \boldsymbol{W} の3つのテンソルの積で近似します。
交互最小二乗法(Alternating Least Squares, ALS)によるNTFの解法
\mathbf{F} と \boldsymbol{U} \otimes \boldsymbol{V} \otimes \boldsymbol{W} のユークリッド距離を D_{EU} とします。
D_{EU}(\mathbf{U}, \mathbf{V}, \mathbf{W}) = \| \mathbf{F} - \mathbf{U} \otimes \mathbf{V} \otimes \mathbf{W} \|_F^2
導出
\mathbf{F} と \mathbf{U} \otimes \mathbf{V} \otimes \mathbf{W} のユークリッド距離を
D_{EU}(\mathbf{U}, \mathbf{V}, \mathbf{W}) = \| \mathbf{F} - \mathbf{U} \otimes \mathbf{V} \otimes \mathbf{W} \|_F^2
と定義した時の
\min_{\mathbf{U}, \mathbf{V}, \mathbf{W}} D_{EU}(\mathbf{U}, \mathbf{V}, \mathbf{W})
を考えます。
\begin{align*}
D_{EU}(\mathbf{U}, \mathbf{V}, \mathbf{W}) &= \left( F_{rst} - U_{rm} V_{sm} W_{tm} \right)\left( F_{rst} - U_{rp} V_{sp} W_{tp} \right) \\
&= F_{rst}F_{rst} -2 F_{rst} U_{rm} V_{sm} W_{tm} + (U_{rm} V_{sm} W_{tm})(U_{rp} V_{sp} W_{tp}) \\
\end{align*}
ここで、 \mathbf{U}, \mathbf{V}, \mathbf{W} に関する偏微分を考えます。
第一項の微分
\frac{\partial F_{rst}F_{rst}}{\partial U_{ij}} = \frac{\partial F_{rst}F_{rst}}{\partial V_{ij}} = \frac{\partial F_{rst}F_{rst}}{\partial W_{ij}} = 0
第二項の微分
\begin{align*}
-2\frac{\partial F_{rst}U_{rm}V_{sm}W_{tm}}{\partial U_{ij}}
&= -2F_{rst}V_{sm}W_{tm}\delta_{ri}\delta_{mj} = -2F_{ist}V_{sj}W_{tj} \\
-2\frac{\partial F_{rst}U_{rm}V_{sm}W_{tm}}{\partial V_{ij}}
&= -2F_{rst}U_{rm}W_{tm}\delta_{si}\delta_{mj} = -2F_{rit}U_{rj}W_{tj} \\
-2\frac{\partial F_{rst}U_{rm}V_{sm}W_{tm}}{\partial W_{ij}}
&= -2F_{rst}U_{rm}V_{sm}\delta_{ti}\delta_{mj} = -2F_{rsi}U_{rj}V_{sj} \\
\end{align*}
第三項の微分
\begin{align*}
\frac{\partial (U_{rm}V_{sm}W_{tm})(U_{rp}V_{sp}W_{tp})}{\partial U_{ij}}
&= \delta_{ri}\delta_{mj} V_{sm}W_{tm}U_{rp}V_{sp}W_{tp} + \delta_{ri}\delta_{pj}U_{rm}V_{sm}W_{tm}V_{sp}W_{tp} \\
&= V_{sj}W_{tj}U_{ip}V_{sp}W_{tp} + U_{im}V_{sm}W_{tm}V_{sj}W_{tj}
\end{align*}
ここで、m, p は和を取るためのダミーインデックスなので、m,pのどちらの計算についても同様の結果になります。したがって、
\begin{align*}
\frac{\partial (U_{rm}V_{sm}W_{tm})(U_{rp}V_{sp}W_{tp})}{\partial U_{ij}}
&= V_{sj}W_{tj}U_{im}V_{sm}W_{tm} + U_{im}V_{sm}W_{tm}V_{sj}W_{tj} \\
&= 2 V_{sj}W_{tj}U_{im}V_{sm}W_{tm}
\end{align*}
V_{ij}, W_{ij} についても同様にして、
\begin{align*}
\frac{\partial (U_{rm}V_{sm}W_{tm})(U_{rp}V_{sp}W_{tp})}{\partial V_{ij}}
&= 2 U_{rj}W_{tj}U_{rm}V_{im}W_{tm} \\
\frac{\partial (U_{rm}V_{sm}W_{tm})(U_{rp}V_{sp}W_{tp})}{\partial W_{ij}}
&= 2 U_{rj}V_{sj}U_{rm}V_{sm}W_{im}
\end{align*}
となります。第三項には微分した変数が右辺に含まれており、解析的に解くことが困難なことがわかります。そこで、第三項を補助関数(Auxiliary Function)で置き換えます。
補助関数の導出
Non-negative Matrix Factorization (NMF)の時と同様にします。
\begin{align*}
\lambda_{k} &= \frac{U_{rk}V_{sk}W_{tk}}{\sum_{m} U_{rm}V_{sm}W_{tm}} \\
\text{subject to: } &\lambda_{k} \geq 0,\;\sum_{k} \lambda_{k} = 1
\end{align*}
とおくと、第三項はx_k = U_{rk}V_{sk}W_{tk} / \lambda_{k} より(Jensenの不等式を参照)
\begin{align*}
\sum_{r,s,t}\left(\sum_{k} U_{rk}V_{sk}W_{tk}\right)^2 &\leq \sum_{r,s,t,k} \frac{U_{rk}^2V_{sk}^2W_{tk}^2}{\lambda_{k}} \\
&= \sum_{r,s,t,k} \frac{U_{rk}^2V_{sk}^2W_{tk}^2}{\frac{U_{rk}V_{sk}W_{tk}}{\sum_{m} U_{rm}V_{sm}W_{tm}}} \\
\end{align*}
となります。ここで、\lambda_{k} は定数となるはずですが、U_{rk}, V_{sk}, W_{tk} に依存する形になっています。そこで、交互最適化実行時の各最適化ステップ tにおける\mathbf{U}, \mathbf{V}, \mathbf{W} の値 \mathbf{U}^{(t)}, \mathbf{V}^{(t)}, \mathbf{W}^{(t)} を用いて、\lambda_{k} を固定します。
Uを更新する時(\partial D_{EU}(\mathbf{U}, \mathbf{V}, \mathbf{W}) / \partial \mathbf{U} = 0 のとき)
\begin{align*}
\sum_{r,s,t}\left(\sum_{k} U_{rk}V_{sk}W_{tk}\right)^2 &\leq \sum_{r,s,t,k} \frac{U_{rk}^2V_{sk}^2W_{tk}^2}{\lambda_{k}} \\
&= \sum_{r,s,t,k} \frac{U_{rk}^2V_{sk}^2W_{tk}^2}{\frac{U_{rk}^{(t)}V_{sk}W_{tk}}{\sum_{m} U_{rm}^{(t)}V_{sm}W_{tm}}} \\
&= \sum_{r,s,t,k} V_{sk}W_{tk}\frac{U_{rk}^2}{U_{rk}^{(t)}}\left(\sum_{m} V_{sm}W_{tm}U_{rm}^{(t)}\right) \\
\end{align*}
となります。
Vを更新する時(\partial D_{EU}(\mathbf{U}, \mathbf{V}, \mathbf{W}) / \partial \mathbf{V} = 0 のとき)
\begin{align*}
\sum_{r,s,t}\left(\sum_{k} U_{rk}V_{sk}W_{tk}\right)^2 &\leq \sum_{r,s,t,k} \frac{U_{rk}^2V_{sk}^2W_{tk}^2}{\lambda_{k}} \\
&= \sum_{r,s,t,k} \frac{U_{rk}^2V_{sk}^2W_{tk}^2}{\frac{U_{rk}V_{sk}^{(t)}W_{tk}}{\sum_{m} U_{rm}V_{sm}^{(t)}W_{tm}}} \\
&= \sum_{r,s,t,k} U_{rk}W_{tk}\frac{V_{sk}^2}{V_{sk}^{(t)}}\left(\sum_{m} U_{rm}W_{tm}V_{sm}^{(t)}\right) \\
\end{align*}
Wを更新する時(\partial D_{EU}(\mathbf{U}, \mathbf{V}, \mathbf{W}) / \partial \mathbf{W} = 0 のとき)
\begin{align*}
\sum_{r,s,t}\left(\sum_{k} U_{rk}V_{sk}W_{tk}\right)^2 &\leq \sum_{r,s,t,k} \frac{U_{rk}^2V_{sk}^2W_{tk}^2}{\lambda_{k}} \\
&= \sum_{r,s,t,k} \frac{U_{rk}^2V_{sk}^2W_{tk}^2}{\frac{U_{rk}V_{sk}W_{tk}^{(t)}}{\sum_{m} U_{rm}V_{sm}W_{tm}^{(t)}}} \\
&= \sum_{r,s,t,k} U_{rk}V_{sk}\frac{W_{tk}^2}{W_{tk}^{(t)}}\left(\sum_{m} U_{rm}V_{sm}W_{tm}^{(t)}\right) \\
\end{align*}
更新式の導出
以上より、\mathbf{U}, \mathbf{V}, \mathbf{W} を更新するための補助関数を目的関数全体で考えます。
Uの更新式の導出
もともとの目的関数 D_{EU}(\mathbf{U}, \mathbf{V}, \mathbf{W}) において、第三項を上記で導出した補助関数で置き換えます。
\begin{align*}
G(\mathbf{U}, \mathbf{U}^{(t)})_{rk} &= F_{rst}^{2} -2 \sum_{s, t} F_{rst} U_{rk} V_{sk} W_{tk} + \sum_{s, t} V_{sk}W_{tk}\frac{U_{rk}^2}{U_{rk}^{(t)}}\left(\sum_{m} V_{sm}W_{tm}U_{rm}^{(t)}\right) \\
\end{align*}
となります。この補助関数 G を最小化するために、各要素 U_{rk} について微分すると
\begin{align*}
\frac{\partial G(\mathbf{U}, \mathbf{U}^{(t)})}{\partial U_{rk}}
&= -2 \sum_{s, t} F_{rst}V_{sk}W_{tk} + 2 \sum_{s, t} V_{sk}W_{tk}\left(\sum_{m} V_{sm}W_{tm}U_{rm}^{(t)}\right) \frac{U_{rk}}{U_{rk}^{(t)}} \\
U_{rk}^{(t+1)} &= U_{rk}^{(t)} \circ \frac{\sum_{s, t} F_{rst}V_{sk}W_{tk}}{\sum_{s, t} V_{sk}W_{tk}\left(\sum_{m} V_{sm}W_{tm}U_{rm}^{(t)}\right)} \\
\end{align*}
となります。同様に、\mathbf{V}, \mathbf{W} についても更新式を導出することができます。
Vの更新式の導出
\begin{align*}
\frac{\partial G(\mathbf{V}, \mathbf{V}^{(t)})}{\partial V_{sk}}
&= -2 \sum_{r, t} F_{rst}U_{rk}W_{tk} + 2 \sum_{r, t} U_{rk}W_{tk}\left(\sum_{m} U_{rm}W_{tm}V_{sm}^{(t)}\right) \frac{V_{sk}}{V_{sk}^{(t)}} \\
V_{sk}^{(t+1)} &= V_{sk}^{(t)} \circ \frac{\sum_{r, t} F_{rst}U_{rk}W_{tk}}{\sum_{r, t} U_{rk}W_{tk}\left(\sum_{m} U_{rm}W_{tm}V_{sm}^{(t)}\right)} \\
\end{align*}
Wの更新式の導出
\begin{align*}
\frac{\partial G(\mathbf{W}, \mathbf{W}^{(t)})}{\partial W_{tk}}
&= -2 \sum_{r, s} F_{rst}U_{rk}V_{sk} + 2 \sum_{r, s} U_{rk}V_{sk}\left(\sum_{m} U_{rm}V_{sm}W_{tm}^{(t)}\right) \frac{W_{tk}}{W_{tk}^{(t)}} \\
W_{tk}^{(t+1)} &= W_{tk}^{(t)} \circ \frac{\sum_{r, s} F_{rst}U_{rk}V_{sk}}{\sum_{r, s} U_{rk}V_{sk}\left(\sum_{m} U_{rm}V_{sm}W_{tm}^{(t)}\right)} \\
\end{align*}
となります。
更新式の行列表現
以上で、更新式が導出できました。しかし、更新式が行列表現になっておらず、このまま実装すると numpy.einsum を使ったテンソル演算を行うことになります。これを避けるために、更新式を行列表現に変換します。
Uの更新式の行列表現
説明のため、 U の更新式を再掲します。
\begin{align*}
U_{rk}^{(t+1)} &= U_{rk}^{(t)} \circ \frac{\sum_{s, t} F_{rst}V_{sk}W_{tk}}{\sum_{s, t} V_{sk}W_{tk}\left(\sum_{m} V_{sm}W_{tm}U_{rm}^{(t)}\right)} \\
\end{align*}
分子の行列表現
分子の (r, k) 成分は以下で与えられます。
\begin{align*}
N_{rk} &= \sum_{s, t} F_{rst}V_{sk}W_{tk}
\end{align*}
ここで、
-
V_{sk}W_{tk} を Khatri-Rao積 \mathbf{V} \odot \mathbf{W} の (k, s) で表現すると ((s-1)T+t, k) 成分は V_{sk}W_{tk} に対応する。
-
F_{rst} をモード 1 でアンフォールドした \mathbf{F}_{(1)} を考えると (r,(s-1)T+t, t) 成分が F_{rst} に対応する。
ということを利用すると、N_{rk} は
\begin{align*}
N_{rk} &= \sum_{j=1}^{ST} (F_{(1)})_{rj} (\mathbf{V} \odot \mathbf{W})_{jk} \\
\text{subject to: } j &= (s-1)T+t \\
\end{align*}
すなわち、
\begin{align*}
\mathbf{N} &= \mathbf{F}_{(1)} (\mathbf{V} \odot \mathbf{W}) \\
\end{align*}
と書けます。
分母の行列表現
分母の (r, k) 成分は以下で与えられます。
\begin{align*}
D_{rk} &= \sum_{s, t} V_{sk}W_{tk}\left(\sum_{m} V_{sm}W_{tm}U_{rm}^{(t)}\right)
\end{align*}
ここで、 \sum_{m} V_{sm}W_{tm}U_{rm}^{(t)} は最適化ステップ t における再構成テンソル \widehat{\mathbf{F}}^{(t)} の (r, s, t) 成分 \widehat{F}_{rst}^{(t)} です。 V_{sk}W_{tk} については分子の時と同様にして、
\begin{align*}
D_{rk} &= \sum_{s, t} V_{sk}W_{tk}\widehat{F}_{rst}^{(t)} \\
&= \sum_{j=1}^{ST} \left(\widehat{F}_{(1)}^{(t)}\right)_{rj}(\mathbf{V} \odot \mathbf{W})_{jk} \\
\text{subject to: } j &= (s-1)T+t
\end{align*}
すなわち、
\begin{align*}
\mathbf{D} &= \widehat{\mathbf{F}}^{(t)}_{(1)} (\mathbf{V} \odot \mathbf{W}) \\
&= \mathbf{U}^{(t)}(\mathbf{V} \odot \mathbf{W})^{T}(\mathbf{V} \odot \mathbf{W}) \\
\end{align*}
と書けます。
以上より、Uの更新式は
\begin{align*}
\mathbf{U}^{(t+1)} &= \mathbf{U}^{(t)} \circ \frac{\mathbf{F}_{(1)} (\mathbf{V} \odot \mathbf{W})}{\mathbf{U}^{(t)}(\mathbf{V} \odot \mathbf{W})^{T}(\mathbf{V} \odot \mathbf{W})} \\
\end{align*}
と書けます。
Vの更新式の行列表現
Uと同様にして、Vの更新式は
\begin{align*}
\mathbf{V}^{(t+1)} &= \mathbf{V}^{(t)} \circ \frac{\mathbf{F}_{(2)} (\mathbf{W} \odot \mathbf{U})}{\mathbf{V}^{(t)}(\mathbf{W} \odot \mathbf{U})^{T}(\mathbf{W} \odot \mathbf{U})} \\
\end{align*}
と書ける。
Wの更新式の行列表現
Uと同様にして、Wの更新式は
\begin{align*}
\mathbf{W}^{(t+1)} &= \mathbf{W}^{(t)} \circ \frac{\mathbf{F}_{(3)} (\mathbf{U} \odot \mathbf{V})}{\mathbf{W}^{(t)}(\mathbf{U} \odot \mathbf{V})^{T}(\mathbf{U} \odot \mathbf{V})} \\
\end{align*}
と書ける。
実装例
実装は以下のようになります。
from __future__ import annotations
from itertools import product
import numpy as np
from scipy.linalg import khatri_rao
import matplotlib.pyplot as plt
class NTF:
def __init__(
self,
n_components: int,
max_iter: int = 100,
eps: float = 1e-8,
random_state: int | None = None,
) -> None:
self.n_components = n_components
self.max_iter = max_iter
self.eps = eps
self.U: np.ndarray[float] | None = None
self.V: np.ndarray[float] | None = None
self.W: np.ndarray[float] | None = None
self.random_state = random_state
def fit(self, X: np.ndarray[float]) -> NTF:
np.random.seed(self.random_state)
R, S, T = X.shape
U = np.random.rand(R, self.n_components)
V = np.random.rand(S, self.n_components)
W = np.random.rand(T, self.n_components)
for _ in range(self.max_iter):
X1 = X.reshape(R, S*T)
VW_kr = khatri_rao(V, W)
U = U * (X1 @ VW_kr) / (U @ VW_kr.T @ VW_kr + self.eps)
X2 = np.transpose(X, (1, 2, 0)).reshape(S, T*R)
WU_kr = khatri_rao(W, U)
V = V * (X2 @ WU_kr) / (V @ WU_kr.T @ WU_kr + self.eps)
X3 = np.transpose(X, (2, 0, 1)).reshape(T, R*S)
UV_kr = khatri_rao(U, V)
W = W * (X3 @ UV_kr) / (W @ UV_kr.T @ UV_kr + self.eps)
self.U = U
self.V = V
self.W = W
return self
xx = np.linspace(1, 10, 10, endpoint=True)
X = np.array(list(product(xx, xx, xx)))
G = X.reshape(10,10,30)
fig = plt.figure()
ax = fig.add_subplot(projection='3d')
ax.scatter(X[:, 0], X[:, 1], X[:, 2], c=X[:, 0], cmap='viridis')
plt.show()
R, S, T = 10, 10, 30
xx = np.linspace(1, 10, 10, endpoint=True)
X = np.array(list(product(xx, xx, xx)))
G = X.reshape(10,10,30)
ntf = NTF(n_components=3, max_iter=1000, eps=1e-8, random_state=42)
Gmax = np.max(G)
Gmin = np.min(G)
G = (G - Gmin) / (Gmax - Gmin)
ntf.fit(G)
Ghat = ntf.U @ khatri_rao(ntf.V, ntf.W).T
Ghat = Ghat * (Gmax - Gmin) + Gmin
Xhat = Ghat.reshape(-1, 3)
fig = plt.figure()
ax = fig.add_subplot(projection='3d')
ax.scatter(Xhat[:, 0], Xhat[:, 1], Xhat[:, 2], c=Xhat[:, 0], cmap='viridis')
plt.show()

元テンソルの3Dプロット

NTFによる再構成の3Dプロット
ポアソン分布の尤度最大化によるNTFの解法
NMFの場合と同様に、観測テンソル \mathbf{F} をポアソン分布に従うカウントデータとみなし、その平均テンソル(レート)\boldsymbol{\Lambda} を CP 分解 \boldsymbol{\Lambda} \simeq \sum_{m=1}^{M} \boldsymbol{u}_{:m} \circ \boldsymbol{v}_{:m} \circ \boldsymbol{w}_{:m} でモデル化します。
目的関数(KLダイバージェンス)
負のポアソン対数尤度はテンソル版の一般化 KL ダイバージェンスに等価で、
\begin{align*}
J
&= - \log \mathcal{L}(\boldsymbol{U}, \boldsymbol{V}, \boldsymbol{W}) \\
&= \sum_{r,s,t} \left( -\mathbf{F}_{rst} \log \boldsymbol{\Lambda}_{rst} + \boldsymbol{\Lambda}_{rst} + \log(\mathbf{F}_{rst}!) \right),
\end{align*}
ここで \boldsymbol{\Lambda}_{rst} = \sum_{m=1}^{M} U_{r m} V_{s m} W_{t m} です。
以下、要素ごとの積を \circ、要素ごとの除算を \oslash とします。また、\odot は Khatri–Rao 積(列ごとの Kronecker 積)、\mathbf{F}_{(n)} はモード n での展開行列(unfolding)を表します。
モード展開の形での再構成は
\begin{align*}
\boldsymbol{\Lambda}_{(1)} = \boldsymbol{U} (\boldsymbol{W} \odot \boldsymbol{V})^{\top} \\
\boldsymbol{\Lambda}_{(2)} = \boldsymbol{V} (\boldsymbol{W} \odot \boldsymbol{U})^{\top} \\
\boldsymbol{\Lambda}_{(3)} = \boldsymbol{W} (\boldsymbol{V} \odot \boldsymbol{U})^{\top}
\end{align*}
乗法更新式(Multiplicative Update Rules)
NMFの場合と同様、勾配を正負の項に分解して乗法更新を行うことで、非負性を保ちつつ (J) を減少させる更新が得られます。
偏微分の導出(要素形)
\boldsymbol{\Lambda}_{rst} = \sum_{m} U_{rm} V_{sm} W_{tm} とします。まず U_{rm} に関して、
\frac{\partial J}{\partial U_{rm}}
= \sum_{s,t} \left( - \frac{\mathbf{F}_{rst}}{\boldsymbol{\Lambda}_{rst}} \cdot \frac{\partial \boldsymbol{\Lambda}_{rst}}{\partial U_{rm}} + \frac{\partial \boldsymbol{\Lambda}_{rst}}{\partial U_{rm}} \right)
= \sum_{s,t} \left(1 - \frac{\mathbf{F}_{rst}}{\boldsymbol{\Lambda}_{rst}}\right) V_{sm} W_{tm}.
同様に、V_{sm} と W_{tm} に関しては、
\frac{\partial J}{\partial V_{sm}} = \sum_{r,t} \left(1 - \frac{\mathbf{F}_{rst}}{\boldsymbol{\Lambda}_{rst}}\right) U_{rm} W_{tm},\quad
\frac{\partial J}{\partial W_{tm}} = \sum_{r,s} \left(1 - \frac{\mathbf{F}_{rst}}{\boldsymbol{\Lambda}_{rst}}\right) U_{rm} V_{sm}.
これらを「正の項 − 負の項」に分解すると、例えば U では
\frac{\partial J}{\partial U_{rm}} = \underbrace{\sum_{s,t} V_{sm} W_{tm}}_{\text{正}}\; -\; \underbrace{\sum_{s,t} V_{sm} W_{tm} \; \frac{\mathbf{F}_{rst}}{\boldsymbol{\Lambda}_{rst}}}_{\text{負}}.
乗法更新は
\begin{align*}
U_{rm} \leftarrow U_{rm} \times \frac{\text{負}}{\text{正}}, \\
V_{sm} \leftarrow V_{sm} \times \frac{\text{負}}{\text{正}}, \\
W_{tm} \leftarrow W_{tm} \times \frac{\text{負}}{\text{正}}
\end{align*}
により得られます。
要素表現
\begin{align*}
U_{r m}^{(t+1)} &= U_{r m}^{(t)} \;\circ\; \frac{\displaystyle \sum_{s,t} V_{s m} W_{t m} \; \frac{\mathbf{F}_{rst}}{\boldsymbol{\Lambda}_{rst}}}{\displaystyle \sum_{s,t} V_{s m} W_{t m}}, \\
V_{s m}^{(t+1)} &= V_{s m}^{(t)} \;\circ\; \frac{\displaystyle \sum_{r,t} U_{r m} W_{t m} \; \frac{\mathbf{F}_{rst}}{\boldsymbol{\Lambda}_{rst}}}{\displaystyle \sum_{r,t} U_{r m} W_{t m}}, \\
W_{t m}^{(t+1)} &= W_{t m}^{(t)} \;\circ\; \frac{\displaystyle \sum_{r,s} U_{r m} V_{s m} \; \frac{\mathbf{F}_{rst}}{\boldsymbol{\Lambda}_{rst}}}{\displaystyle \sum_{r,s} U_{r m} V_{s m}}.
\end{align*}
モード展開・Khatri–Rao による行列表現
\mathbf{1}_R, \mathbf{1}_S, \mathbf{1}_T は全1ベクトル、\mathbf{1}_{ST} 等は対応する長さの全1ベクトル。
\begin{align*}
\boldsymbol{U}^{(t+1)}
&= \boldsymbol{U}^{(t)} \;\circ\;
\frac{\left( \mathbf{F}_{(1)} \;\oslash\; \boldsymbol{\Lambda}_{(1)} \right) \; (\boldsymbol{W} \odot \boldsymbol{V})}{\;\mathbf{1}_R\, \left( (\boldsymbol{W} \odot \boldsymbol{V})^{\top} \mathbf{1}_{ST} \right)^{\top}}, \\
\boldsymbol{V}^{(t+1)}
&= \boldsymbol{V}^{(t)} \;\circ\;
\frac{\left( \mathbf{F}_{(2)} \;\oslash\; \boldsymbol{\Lambda}_{(2)} \right) \; (\boldsymbol{W} \odot \boldsymbol{U})}{\;\mathbf{1}_S\, \left( (\boldsymbol{W} \odot \boldsymbol{U})^{\top} \mathbf{1}_{RT} \right)^{\top}}, \\
\boldsymbol{W}^{(t+1)}
&= \boldsymbol{W}^{(t)} \;\circ\;
\frac{\left( \mathbf{F}_{(3)} \;\oslash\; \boldsymbol{\Lambda}_{(3)} \right) \; (\boldsymbol{V} \odot \boldsymbol{U})}{\;\mathbf{1}_T\, \left( (\boldsymbol{V} \odot \boldsymbol{U})^{\top} \mathbf{1}_{RS} \right)^{\top}}.
\end{align*}
各分子は観測と再構成の要素比(\mathbf{F}_{(n)} \oslash \boldsymbol{\Lambda}_{(n)})を Khatri–Rao 積で重み付け集約したもので、分母は対応する正規化係数(Khatri–Rao 行列の列和)です。形状はそれぞれ \boldsymbol{U}\!: R\times M, \boldsymbol{V}\!: S\times M, \boldsymbol{W}\!: T\times M に一致します。
乗法更新の正当化
NMFと同様、x \mapsto -f\log x に対する接線不等式(Jensen/Youngの不等式)を用いるか、あるいは KL の補助関数(auxiliary function)を構成して、上界を最小化する MM(Majorization–Minimization)により、上記の比形式の更新が J を非増加にすることが示せます。要点は、比 \mathbf{F}/\boldsymbol{\Lambda} による重み付き加重平均と、Khatri–Rao 行列の列和による正規化が、凸上界の最小化解として現れる点にあります。
実装例
from __future__ import annotations
import numpy as np
from scipy.linalg import khatri_rao
class NTF:
def __init__(
self,
n_components: int,
max_iter: int = 100,
eps: float = 1e-8,
random_state: int | None = None,
) -> None:
self.n_components = n_components
self.max_iter = max_iter
self.eps = eps
self.U: np.ndarray[float] | None = None
self.V: np.ndarray[float] | None = None
self.W: np.ndarray[float] | None = None
self.random_state = random_state
def fit(self, X: np.ndarray[float]) -> NTF:
np.random.seed(self.random_state)
R, S, T = X.shape
U = np.random.rand(R, self.n_components)
V = np.random.rand(S, self.n_components)
W = np.random.rand(T, self.n_components)
for _ in range(self.max_iter):
VW_kr = khatri_rao(V, W)
X1 = X.reshape(R, S*T)
Lambda1 = (U @ VW_kr.T).reshape(R, S*T) + self.eps
numerator = (X1 / Lambda1) @ VW_kr
denominator = np.ones((R, 1)) @ (VW_kr.T @ np.ones((S*T, 1))).T
U = U * numerator / (denominator + self.eps)
WU_kr = khatri_rao(W, U)
X2 = np.transpose(X, (1, 2, 0)).reshape(S, T*R)
Lambda2 = (V @ WU_kr.T).reshape(S, T*R) + self.eps
numerator = (X2 / Lambda2) @ WU_kr
denominator = np.ones((S, 1)) @ (WU_kr.T @ np.ones((T*R, 1))).T
V = V * numerator / (denominator + self.eps)
UV_kr = khatri_rao(U, V)
X3 = np.transpose(X, (2, 0, 1)).reshape(T, R*S)
Lambda3 = (W @ UV_kr.T).reshape(T, R*S) + self.eps
numerator = (X3 / Lambda3) @ UV_kr
denominator = np.ones((T, 1)) @ (UV_kr.T @ np.ones((R*S, 1))).T
W = W * numerator / (denominator + self.eps)
self.U = U
self.V = V
self.W = W
return self
参考
-
https://qiita.com/K_Noguchi/items/35d8ff52d3dc87bb61d0
- こちらの記事ではeinsumを使ったテンソル演算が行われています
Discussion