🙌

非負値テンソル因子分解(NTF)を理解したい

に公開

非負値テンソル因子分解(Non-negative Tensor Factorization, NTF)は、非負の要素を持つテンソルを、非負の要素を持つテンソルの積に分解する手法です。Non-negative Matrix Factorization (NMF)の拡張手法です。

以下のように定義されます。

\mathbf{F} \simeq \boldsymbol{U} \otimes \boldsymbol{V} \otimes \boldsymbol{W} \\ \mathbf{F} \in \mathbb{R}_{+}^{R \times S \times T}, \; \boldsymbol{U} \in \mathbb{R}_{+}^{R \times M}, \; \boldsymbol{V} \in \mathbb{R}_{+}^{S \times M}, \; \boldsymbol{W} \in \mathbb{R}_{+}^{T \times M} \\ M \lt \min(R, S, T) \\

ここで、\mathbf{F} は観測テンソルです。これを \boldsymbol{U}, \boldsymbol{V}, \boldsymbol{W} の3つのテンソルの積で近似します。

交互最小二乗法(Alternating Least Squares, ALS)によるNTFの解法

\mathbf{F}\boldsymbol{U} \otimes \boldsymbol{V} \otimes \boldsymbol{W} のユークリッド距離を D_{EU} とします。

D_{EU}(\mathbf{U}, \mathbf{V}, \mathbf{W}) = \| \mathbf{F} - \mathbf{U} \otimes \mathbf{V} \otimes \mathbf{W} \|_F^2

導出

\mathbf{F}\mathbf{U} \otimes \mathbf{V} \otimes \mathbf{W} のユークリッド距離を

D_{EU}(\mathbf{U}, \mathbf{V}, \mathbf{W}) = \| \mathbf{F} - \mathbf{U} \otimes \mathbf{V} \otimes \mathbf{W} \|_F^2

と定義した時の

\min_{\mathbf{U}, \mathbf{V}, \mathbf{W}} D_{EU}(\mathbf{U}, \mathbf{V}, \mathbf{W})

を考えます。

\begin{align*} D_{EU}(\mathbf{U}, \mathbf{V}, \mathbf{W}) &= \left( F_{rst} - U_{rm} V_{sm} W_{tm} \right)\left( F_{rst} - U_{rp} V_{sp} W_{tp} \right) \\ &= F_{rst}F_{rst} -2 F_{rst} U_{rm} V_{sm} W_{tm} + (U_{rm} V_{sm} W_{tm})(U_{rp} V_{sp} W_{tp}) \\ \end{align*}

ここで、 \mathbf{U}, \mathbf{V}, \mathbf{W} に関する偏微分を考えます。

第一項の微分

\frac{\partial F_{rst}F_{rst}}{\partial U_{ij}} = \frac{\partial F_{rst}F_{rst}}{\partial V_{ij}} = \frac{\partial F_{rst}F_{rst}}{\partial W_{ij}} = 0

第二項の微分

\begin{align*} -2\frac{\partial F_{rst}U_{rm}V_{sm}W_{tm}}{\partial U_{ij}} &= -2F_{rst}V_{sm}W_{tm}\delta_{ri}\delta_{mj} = -2F_{ist}V_{sj}W_{tj} \\ -2\frac{\partial F_{rst}U_{rm}V_{sm}W_{tm}}{\partial V_{ij}} &= -2F_{rst}U_{rm}W_{tm}\delta_{si}\delta_{mj} = -2F_{rit}U_{rj}W_{tj} \\ -2\frac{\partial F_{rst}U_{rm}V_{sm}W_{tm}}{\partial W_{ij}} &= -2F_{rst}U_{rm}V_{sm}\delta_{ti}\delta_{mj} = -2F_{rsi}U_{rj}V_{sj} \\ \end{align*}

第三項の微分

\begin{align*} \frac{\partial (U_{rm}V_{sm}W_{tm})(U_{rp}V_{sp}W_{tp})}{\partial U_{ij}} &= \delta_{ri}\delta_{mj} V_{sm}W_{tm}U_{rp}V_{sp}W_{tp} + \delta_{ri}\delta_{pj}U_{rm}V_{sm}W_{tm}V_{sp}W_{tp} \\ &= V_{sj}W_{tj}U_{ip}V_{sp}W_{tp} + U_{im}V_{sm}W_{tm}V_{sj}W_{tj} \end{align*}

ここで、m, p は和を取るためのダミーインデックスなので、m,pのどちらの計算についても同様の結果になります。したがって、

\begin{align*} \frac{\partial (U_{rm}V_{sm}W_{tm})(U_{rp}V_{sp}W_{tp})}{\partial U_{ij}} &= V_{sj}W_{tj}U_{im}V_{sm}W_{tm} + U_{im}V_{sm}W_{tm}V_{sj}W_{tj} \\ &= 2 V_{sj}W_{tj}U_{im}V_{sm}W_{tm} \end{align*}

V_{ij}, W_{ij} についても同様にして、

\begin{align*} \frac{\partial (U_{rm}V_{sm}W_{tm})(U_{rp}V_{sp}W_{tp})}{\partial V_{ij}} &= 2 U_{rj}W_{tj}U_{rm}V_{im}W_{tm} \\ \frac{\partial (U_{rm}V_{sm}W_{tm})(U_{rp}V_{sp}W_{tp})}{\partial W_{ij}} &= 2 U_{rj}V_{sj}U_{rm}V_{sm}W_{im} \end{align*}

となります。第三項には微分した変数が右辺に含まれており、解析的に解くことが困難なことがわかります。そこで、第三項を補助関数(Auxiliary Function)で置き換えます。

補助関数の導出

Non-negative Matrix Factorization (NMF)の時と同様にします。

\begin{align*} \lambda_{k} &= \frac{U_{rk}V_{sk}W_{tk}}{\sum_{m} U_{rm}V_{sm}W_{tm}} \\ \text{subject to: } &\lambda_{k} \geq 0,\;\sum_{k} \lambda_{k} = 1 \end{align*}

とおくと、第三項はx_k = U_{rk}V_{sk}W_{tk} / \lambda_{k} より(Jensenの不等式を参照)

\begin{align*} \sum_{r,s,t}\left(\sum_{k} U_{rk}V_{sk}W_{tk}\right)^2 &\leq \sum_{r,s,t,k} \frac{U_{rk}^2V_{sk}^2W_{tk}^2}{\lambda_{k}} \\ &= \sum_{r,s,t,k} \frac{U_{rk}^2V_{sk}^2W_{tk}^2}{\frac{U_{rk}V_{sk}W_{tk}}{\sum_{m} U_{rm}V_{sm}W_{tm}}} \\ \end{align*}

となります。ここで、\lambda_{k} は定数となるはずですが、U_{rk}, V_{sk}, W_{tk} に依存する形になっています。そこで、交互最適化実行時の各最適化ステップ tにおける\mathbf{U}, \mathbf{V}, \mathbf{W} の値 \mathbf{U}^{(t)}, \mathbf{V}^{(t)}, \mathbf{W}^{(t)} を用いて、\lambda_{k} を固定します。

Uを更新する時(\partial D_{EU}(\mathbf{U}, \mathbf{V}, \mathbf{W}) / \partial \mathbf{U} = 0 のとき)

\begin{align*} \sum_{r,s,t}\left(\sum_{k} U_{rk}V_{sk}W_{tk}\right)^2 &\leq \sum_{r,s,t,k} \frac{U_{rk}^2V_{sk}^2W_{tk}^2}{\lambda_{k}} \\ &= \sum_{r,s,t,k} \frac{U_{rk}^2V_{sk}^2W_{tk}^2}{\frac{U_{rk}^{(t)}V_{sk}W_{tk}}{\sum_{m} U_{rm}^{(t)}V_{sm}W_{tm}}} \\ &= \sum_{r,s,t,k} V_{sk}W_{tk}\frac{U_{rk}^2}{U_{rk}^{(t)}}\left(\sum_{m} V_{sm}W_{tm}U_{rm}^{(t)}\right) \\ \end{align*}

となります。

Vを更新する時(\partial D_{EU}(\mathbf{U}, \mathbf{V}, \mathbf{W}) / \partial \mathbf{V} = 0 のとき)

\begin{align*} \sum_{r,s,t}\left(\sum_{k} U_{rk}V_{sk}W_{tk}\right)^2 &\leq \sum_{r,s,t,k} \frac{U_{rk}^2V_{sk}^2W_{tk}^2}{\lambda_{k}} \\ &= \sum_{r,s,t,k} \frac{U_{rk}^2V_{sk}^2W_{tk}^2}{\frac{U_{rk}V_{sk}^{(t)}W_{tk}}{\sum_{m} U_{rm}V_{sm}^{(t)}W_{tm}}} \\ &= \sum_{r,s,t,k} U_{rk}W_{tk}\frac{V_{sk}^2}{V_{sk}^{(t)}}\left(\sum_{m} U_{rm}W_{tm}V_{sm}^{(t)}\right) \\ \end{align*}

Wを更新する時(\partial D_{EU}(\mathbf{U}, \mathbf{V}, \mathbf{W}) / \partial \mathbf{W} = 0 のとき)

\begin{align*} \sum_{r,s,t}\left(\sum_{k} U_{rk}V_{sk}W_{tk}\right)^2 &\leq \sum_{r,s,t,k} \frac{U_{rk}^2V_{sk}^2W_{tk}^2}{\lambda_{k}} \\ &= \sum_{r,s,t,k} \frac{U_{rk}^2V_{sk}^2W_{tk}^2}{\frac{U_{rk}V_{sk}W_{tk}^{(t)}}{\sum_{m} U_{rm}V_{sm}W_{tm}^{(t)}}} \\ &= \sum_{r,s,t,k} U_{rk}V_{sk}\frac{W_{tk}^2}{W_{tk}^{(t)}}\left(\sum_{m} U_{rm}V_{sm}W_{tm}^{(t)}\right) \\ \end{align*}

更新式の導出

以上より、\mathbf{U}, \mathbf{V}, \mathbf{W} を更新するための補助関数を目的関数全体で考えます。

Uの更新式の導出

もともとの目的関数 D_{EU}(\mathbf{U}, \mathbf{V}, \mathbf{W}) において、第三項を上記で導出した補助関数で置き換えます。

\begin{align*} G(\mathbf{U}, \mathbf{U}^{(t)})_{rk} &= F_{rst}^{2} -2 \sum_{s, t} F_{rst} U_{rk} V_{sk} W_{tk} + \sum_{s, t} V_{sk}W_{tk}\frac{U_{rk}^2}{U_{rk}^{(t)}}\left(\sum_{m} V_{sm}W_{tm}U_{rm}^{(t)}\right) \\ \end{align*}

となります。この補助関数 G を最小化するために、各要素 U_{rk} について微分すると

\begin{align*} \frac{\partial G(\mathbf{U}, \mathbf{U}^{(t)})}{\partial U_{rk}} &= -2 \sum_{s, t} F_{rst}V_{sk}W_{tk} + 2 \sum_{s, t} V_{sk}W_{tk}\left(\sum_{m} V_{sm}W_{tm}U_{rm}^{(t)}\right) \frac{U_{rk}}{U_{rk}^{(t)}} \\ U_{rk}^{(t+1)} &= U_{rk}^{(t)} \circ \frac{\sum_{s, t} F_{rst}V_{sk}W_{tk}}{\sum_{s, t} V_{sk}W_{tk}\left(\sum_{m} V_{sm}W_{tm}U_{rm}^{(t)}\right)} \\ \end{align*}

となります。同様に、\mathbf{V}, \mathbf{W} についても更新式を導出することができます。

Vの更新式の導出

\begin{align*} \frac{\partial G(\mathbf{V}, \mathbf{V}^{(t)})}{\partial V_{sk}} &= -2 \sum_{r, t} F_{rst}U_{rk}W_{tk} + 2 \sum_{r, t} U_{rk}W_{tk}\left(\sum_{m} U_{rm}W_{tm}V_{sm}^{(t)}\right) \frac{V_{sk}}{V_{sk}^{(t)}} \\ V_{sk}^{(t+1)} &= V_{sk}^{(t)} \circ \frac{\sum_{r, t} F_{rst}U_{rk}W_{tk}}{\sum_{r, t} U_{rk}W_{tk}\left(\sum_{m} U_{rm}W_{tm}V_{sm}^{(t)}\right)} \\ \end{align*}

Wの更新式の導出

\begin{align*} \frac{\partial G(\mathbf{W}, \mathbf{W}^{(t)})}{\partial W_{tk}} &= -2 \sum_{r, s} F_{rst}U_{rk}V_{sk} + 2 \sum_{r, s} U_{rk}V_{sk}\left(\sum_{m} U_{rm}V_{sm}W_{tm}^{(t)}\right) \frac{W_{tk}}{W_{tk}^{(t)}} \\ W_{tk}^{(t+1)} &= W_{tk}^{(t)} \circ \frac{\sum_{r, s} F_{rst}U_{rk}V_{sk}}{\sum_{r, s} U_{rk}V_{sk}\left(\sum_{m} U_{rm}V_{sm}W_{tm}^{(t)}\right)} \\ \end{align*}

となります。

更新式の行列表現

以上で、更新式が導出できました。しかし、更新式が行列表現になっておらず、このまま実装すると numpy.einsum を使ったテンソル演算を行うことになります。これを避けるために、更新式を行列表現に変換します。

Uの更新式の行列表現

説明のため、 U の更新式を再掲します。

\begin{align*} U_{rk}^{(t+1)} &= U_{rk}^{(t)} \circ \frac{\sum_{s, t} F_{rst}V_{sk}W_{tk}}{\sum_{s, t} V_{sk}W_{tk}\left(\sum_{m} V_{sm}W_{tm}U_{rm}^{(t)}\right)} \\ \end{align*}

分子の行列表現

分子の (r, k) 成分は以下で与えられます。

\begin{align*} N_{rk} &= \sum_{s, t} F_{rst}V_{sk}W_{tk} \end{align*}

ここで、

  1. V_{sk}W_{tk} を Khatri-Rao積 \mathbf{V} \odot \mathbf{W}(k, s) で表現すると ((s-1)T+t, k) 成分は V_{sk}W_{tk} に対応する。
  2. F_{rst} をモード 1 でアンフォールドした \mathbf{F}_{(1)} を考えると (r,(s-1)T+t, t) 成分が F_{rst} に対応する。

ということを利用すると、N_{rk}

\begin{align*} N_{rk} &= \sum_{j=1}^{ST} (F_{(1)})_{rj} (\mathbf{V} \odot \mathbf{W})_{jk} \\ \text{subject to: } j &= (s-1)T+t \\ \end{align*}

すなわち、

\begin{align*} \mathbf{N} &= \mathbf{F}_{(1)} (\mathbf{V} \odot \mathbf{W}) \\ \end{align*}

と書けます。

分母の行列表現

分母の (r, k) 成分は以下で与えられます。

\begin{align*} D_{rk} &= \sum_{s, t} V_{sk}W_{tk}\left(\sum_{m} V_{sm}W_{tm}U_{rm}^{(t)}\right) \end{align*}

ここで、 \sum_{m} V_{sm}W_{tm}U_{rm}^{(t)} は最適化ステップ t における再構成テンソル \widehat{\mathbf{F}}^{(t)}(r, s, t) 成分 \widehat{F}_{rst}^{(t)} です。 V_{sk}W_{tk} については分子の時と同様にして、

\begin{align*} D_{rk} &= \sum_{s, t} V_{sk}W_{tk}\widehat{F}_{rst}^{(t)} \\ &= \sum_{j=1}^{ST} \left(\widehat{F}_{(1)}^{(t)}\right)_{rj}(\mathbf{V} \odot \mathbf{W})_{jk} \\ \text{subject to: } j &= (s-1)T+t \end{align*}

すなわち、

\begin{align*} \mathbf{D} &= \widehat{\mathbf{F}}^{(t)}_{(1)} (\mathbf{V} \odot \mathbf{W}) \\ &= \mathbf{U}^{(t)}(\mathbf{V} \odot \mathbf{W})^{T}(\mathbf{V} \odot \mathbf{W}) \\ \end{align*}

と書けます。

以上より、Uの更新式は

\begin{align*} \mathbf{U}^{(t+1)} &= \mathbf{U}^{(t)} \circ \frac{\mathbf{F}_{(1)} (\mathbf{V} \odot \mathbf{W})}{\mathbf{U}^{(t)}(\mathbf{V} \odot \mathbf{W})^{T}(\mathbf{V} \odot \mathbf{W})} \\ \end{align*}

と書けます。

Vの更新式の行列表現

Uと同様にして、Vの更新式は

\begin{align*} \mathbf{V}^{(t+1)} &= \mathbf{V}^{(t)} \circ \frac{\mathbf{F}_{(2)} (\mathbf{W} \odot \mathbf{U})}{\mathbf{V}^{(t)}(\mathbf{W} \odot \mathbf{U})^{T}(\mathbf{W} \odot \mathbf{U})} \\ \end{align*}

と書ける。

Wの更新式の行列表現

Uと同様にして、Wの更新式は

\begin{align*} \mathbf{W}^{(t+1)} &= \mathbf{W}^{(t)} \circ \frac{\mathbf{F}_{(3)} (\mathbf{U} \odot \mathbf{V})}{\mathbf{W}^{(t)}(\mathbf{U} \odot \mathbf{V})^{T}(\mathbf{U} \odot \mathbf{V})} \\ \end{align*}

と書ける。

実装例

実装は以下のようになります。

from __future__ import annotations

from itertools import product

import numpy as np
from scipy.linalg import khatri_rao
import matplotlib.pyplot as plt


class NTF:
    def __init__(
        self,
        n_components: int,
        max_iter: int = 100,
        eps: float = 1e-8,
        random_state: int | None = None,
    ) -> None:
        self.n_components = n_components
        self.max_iter = max_iter
        self.eps = eps
        self.U: np.ndarray[float] | None = None
        self.V: np.ndarray[float] | None = None
        self.W: np.ndarray[float] | None = None
        self.random_state = random_state

    def fit(self, X: np.ndarray[float]) -> NTF:
        np.random.seed(self.random_state)
        R, S, T = X.shape
        U = np.random.rand(R, self.n_components)
        V = np.random.rand(S, self.n_components)
        W = np.random.rand(T, self.n_components)
        
        for _ in range(self.max_iter):
            X1 = X.reshape(R, S*T)
            VW_kr = khatri_rao(V, W)
            U = U * (X1 @ VW_kr) / (U @ VW_kr.T @ VW_kr + self.eps)

            X2 = np.transpose(X, (1, 2, 0)).reshape(S, T*R)
            WU_kr = khatri_rao(W, U)
            V = V * (X2 @ WU_kr) / (V @ WU_kr.T @ WU_kr + self.eps)

            X3 = np.transpose(X, (2, 0, 1)).reshape(T, R*S)
            UV_kr = khatri_rao(U, V)
            W = W * (X3 @ UV_kr) / (W @ UV_kr.T @ UV_kr + self.eps)
        self.U = U
        self.V = V
        self.W = W
            
        return self

# データ準備
xx = np.linspace(1, 10, 10, endpoint=True)
X = np.array(list(product(xx, xx, xx)))
G = X.reshape(10,10,30)

fig = plt.figure()
ax = fig.add_subplot(projection='3d')
ax.scatter(X[:, 0], X[:, 1], X[:, 2], c=X[:, 0], cmap='viridis')
plt.show()

# NTFによる再構成
R, S, T = 10, 10, 30
xx = np.linspace(1, 10, 10, endpoint=True)
X = np.array(list(product(xx, xx, xx)))
G = X.reshape(10,10,30)

ntf = NTF(n_components=3, max_iter=1000, eps=1e-8, random_state=42)
Gmax = np.max(G)
Gmin = np.min(G)
G = (G - Gmin) / (Gmax - Gmin)
ntf.fit(G)
Ghat = ntf.U @ khatri_rao(ntf.V, ntf.W).T
Ghat = Ghat * (Gmax - Gmin) + Gmin
Xhat = Ghat.reshape(-1, 3)

fig = plt.figure()
ax = fig.add_subplot(projection='3d')
ax.scatter(Xhat[:, 0], Xhat[:, 1], Xhat[:, 2], c=Xhat[:, 0], cmap='viridis')
plt.show()


元テンソルの3Dプロット


NTFによる再構成の3Dプロット

ポアソン分布の尤度最大化によるNTFの解法

NMFの場合と同様に、観測テンソル \mathbf{F} をポアソン分布に従うカウントデータとみなし、その平均テンソル(レート)\boldsymbol{\Lambda} を CP 分解 \boldsymbol{\Lambda} \simeq \sum_{m=1}^{M} \boldsymbol{u}_{:m} \circ \boldsymbol{v}_{:m} \circ \boldsymbol{w}_{:m} でモデル化します。

目的関数(KLダイバージェンス)

負のポアソン対数尤度はテンソル版の一般化 KL ダイバージェンスに等価で、

\begin{align*} J &= - \log \mathcal{L}(\boldsymbol{U}, \boldsymbol{V}, \boldsymbol{W}) \\ &= \sum_{r,s,t} \left( -\mathbf{F}_{rst} \log \boldsymbol{\Lambda}_{rst} + \boldsymbol{\Lambda}_{rst} + \log(\mathbf{F}_{rst}!) \right), \end{align*}

ここで \boldsymbol{\Lambda}_{rst} = \sum_{m=1}^{M} U_{r m} V_{s m} W_{t m} です。

以下、要素ごとの積を \circ、要素ごとの除算を \oslash とします。また、\odot は Khatri–Rao 積(列ごとの Kronecker 積)、\mathbf{F}_{(n)} はモード n での展開行列(unfolding)を表します。

モード展開の形での再構成は

\begin{align*} \boldsymbol{\Lambda}_{(1)} = \boldsymbol{U} (\boldsymbol{W} \odot \boldsymbol{V})^{\top} \\ \boldsymbol{\Lambda}_{(2)} = \boldsymbol{V} (\boldsymbol{W} \odot \boldsymbol{U})^{\top} \\ \boldsymbol{\Lambda}_{(3)} = \boldsymbol{W} (\boldsymbol{V} \odot \boldsymbol{U})^{\top} \end{align*}

乗法更新式(Multiplicative Update Rules)

NMFの場合と同様、勾配を正負の項に分解して乗法更新を行うことで、非負性を保ちつつ (J) を減少させる更新が得られます。

偏微分の導出(要素形)

\boldsymbol{\Lambda}_{rst} = \sum_{m} U_{rm} V_{sm} W_{tm} とします。まず U_{rm} に関して、

\frac{\partial J}{\partial U_{rm}} = \sum_{s,t} \left( - \frac{\mathbf{F}_{rst}}{\boldsymbol{\Lambda}_{rst}} \cdot \frac{\partial \boldsymbol{\Lambda}_{rst}}{\partial U_{rm}} + \frac{\partial \boldsymbol{\Lambda}_{rst}}{\partial U_{rm}} \right) = \sum_{s,t} \left(1 - \frac{\mathbf{F}_{rst}}{\boldsymbol{\Lambda}_{rst}}\right) V_{sm} W_{tm}.

同様に、V_{sm}W_{tm} に関しては、

\frac{\partial J}{\partial V_{sm}} = \sum_{r,t} \left(1 - \frac{\mathbf{F}_{rst}}{\boldsymbol{\Lambda}_{rst}}\right) U_{rm} W_{tm},\quad \frac{\partial J}{\partial W_{tm}} = \sum_{r,s} \left(1 - \frac{\mathbf{F}_{rst}}{\boldsymbol{\Lambda}_{rst}}\right) U_{rm} V_{sm}.

これらを「正の項 − 負の項」に分解すると、例えば U では

\frac{\partial J}{\partial U_{rm}} = \underbrace{\sum_{s,t} V_{sm} W_{tm}}_{\text{正}}\; -\; \underbrace{\sum_{s,t} V_{sm} W_{tm} \; \frac{\mathbf{F}_{rst}}{\boldsymbol{\Lambda}_{rst}}}_{\text{負}}.

乗法更新は

\begin{align*} U_{rm} \leftarrow U_{rm} \times \frac{\text{負}}{\text{正}}, \\ V_{sm} \leftarrow V_{sm} \times \frac{\text{負}}{\text{正}}, \\ W_{tm} \leftarrow W_{tm} \times \frac{\text{負}}{\text{正}} \end{align*}

により得られます。

要素表現

\begin{align*} U_{r m}^{(t+1)} &= U_{r m}^{(t)} \;\circ\; \frac{\displaystyle \sum_{s,t} V_{s m} W_{t m} \; \frac{\mathbf{F}_{rst}}{\boldsymbol{\Lambda}_{rst}}}{\displaystyle \sum_{s,t} V_{s m} W_{t m}}, \\ V_{s m}^{(t+1)} &= V_{s m}^{(t)} \;\circ\; \frac{\displaystyle \sum_{r,t} U_{r m} W_{t m} \; \frac{\mathbf{F}_{rst}}{\boldsymbol{\Lambda}_{rst}}}{\displaystyle \sum_{r,t} U_{r m} W_{t m}}, \\ W_{t m}^{(t+1)} &= W_{t m}^{(t)} \;\circ\; \frac{\displaystyle \sum_{r,s} U_{r m} V_{s m} \; \frac{\mathbf{F}_{rst}}{\boldsymbol{\Lambda}_{rst}}}{\displaystyle \sum_{r,s} U_{r m} V_{s m}}. \end{align*}

モード展開・Khatri–Rao による行列表現

\mathbf{1}_R, \mathbf{1}_S, \mathbf{1}_T は全1ベクトル、\mathbf{1}_{ST} 等は対応する長さの全1ベクトル。

\begin{align*} \boldsymbol{U}^{(t+1)} &= \boldsymbol{U}^{(t)} \;\circ\; \frac{\left( \mathbf{F}_{(1)} \;\oslash\; \boldsymbol{\Lambda}_{(1)} \right) \; (\boldsymbol{W} \odot \boldsymbol{V})}{\;\mathbf{1}_R\, \left( (\boldsymbol{W} \odot \boldsymbol{V})^{\top} \mathbf{1}_{ST} \right)^{\top}}, \\ \boldsymbol{V}^{(t+1)} &= \boldsymbol{V}^{(t)} \;\circ\; \frac{\left( \mathbf{F}_{(2)} \;\oslash\; \boldsymbol{\Lambda}_{(2)} \right) \; (\boldsymbol{W} \odot \boldsymbol{U})}{\;\mathbf{1}_S\, \left( (\boldsymbol{W} \odot \boldsymbol{U})^{\top} \mathbf{1}_{RT} \right)^{\top}}, \\ \boldsymbol{W}^{(t+1)} &= \boldsymbol{W}^{(t)} \;\circ\; \frac{\left( \mathbf{F}_{(3)} \;\oslash\; \boldsymbol{\Lambda}_{(3)} \right) \; (\boldsymbol{V} \odot \boldsymbol{U})}{\;\mathbf{1}_T\, \left( (\boldsymbol{V} \odot \boldsymbol{U})^{\top} \mathbf{1}_{RS} \right)^{\top}}. \end{align*}

各分子は観測と再構成の要素比(\mathbf{F}_{(n)} \oslash \boldsymbol{\Lambda}_{(n)})を Khatri–Rao 積で重み付け集約したもので、分母は対応する正規化係数(Khatri–Rao 行列の列和)です。形状はそれぞれ \boldsymbol{U}\!: R\times M, \boldsymbol{V}\!: S\times M, \boldsymbol{W}\!: T\times M に一致します。

乗法更新の正当化

NMFと同様、x \mapsto -f\log x に対する接線不等式(Jensen/Youngの不等式)を用いるか、あるいは KL の補助関数(auxiliary function)を構成して、上界を最小化する MM(Majorization–Minimization)により、上記の比形式の更新が J を非増加にすることが示せます。要点は、比 \mathbf{F}/\boldsymbol{\Lambda} による重み付き加重平均と、Khatri–Rao 行列の列和による正規化が、凸上界の最小化解として現れる点にあります。

実装例

from __future__ import annotations

import numpy as np
from scipy.linalg import khatri_rao


class NTF:
    def __init__(
        self,
        n_components: int,
        max_iter: int = 100,
        eps: float = 1e-8,
        random_state: int | None = None,
    ) -> None:
        self.n_components = n_components
        self.max_iter = max_iter
        self.eps = eps
        self.U: np.ndarray[float] | None = None
        self.V: np.ndarray[float] | None = None
        self.W: np.ndarray[float] | None = None
        self.random_state = random_state

    def fit(self, X: np.ndarray[float]) -> NTF:
        np.random.seed(self.random_state)
        R, S, T = X.shape
        U = np.random.rand(R, self.n_components)
        V = np.random.rand(S, self.n_components)
        W = np.random.rand(T, self.n_components)
        
        for _ in range(self.max_iter):
            VW_kr = khatri_rao(V, W)
            X1 = X.reshape(R, S*T)
            Lambda1 = (U @ VW_kr.T).reshape(R, S*T) + self.eps
            numerator = (X1 / Lambda1) @ VW_kr
            denominator = np.ones((R, 1)) @ (VW_kr.T @ np.ones((S*T, 1))).T
            U = U * numerator / (denominator + self.eps)

            WU_kr = khatri_rao(W, U)
            X2 = np.transpose(X, (1, 2, 0)).reshape(S, T*R)
            Lambda2 = (V @ WU_kr.T).reshape(S, T*R) + self.eps
            numerator = (X2 / Lambda2) @ WU_kr
            denominator = np.ones((S, 1)) @ (WU_kr.T @ np.ones((T*R, 1))).T
            V = V * numerator / (denominator + self.eps)

            UV_kr = khatri_rao(U, V)
            X3 = np.transpose(X, (2, 0, 1)).reshape(T, R*S)
            Lambda3 = (W @ UV_kr.T).reshape(T, R*S) + self.eps
            numerator = (X3 / Lambda3) @ UV_kr
            denominator = np.ones((T, 1)) @ (UV_kr.T @ np.ones((R*S, 1))).T
            W = W * numerator / (denominator + self.eps)
        self.U = U
        self.V = V
        self.W = W
            
        return self

参考

  1. https://qiita.com/K_Noguchi/items/35d8ff52d3dc87bb61d0
    • こちらの記事ではeinsumを使ったテンソル演算が行われています

Discussion