行列積状態について考える (12) — テンソルネットワークで MNIST 分類
目的
文献 [E] arXiv:1906.06329「TensorNetwork for Machine Learning」を読んで PyTorch で実装してみたので記事にする。
テンソルネットワーク画像分類器
以下のようなテンソルトレインを用いた画像分類モデルを考える。
論文の FIG. 2. のテンソルのノードに適当に名前を付けるとこのような感じになる。
ここで、
を想定しているので
例えば 6 画素の画像は平坦化することで
となる。これに対して、先ほどのテンソル分類器
ここからは機械学習勢おなじみ
後は機械学習でよくあるように、softmax を通して各ラベルごとの確率を出して、正解ラベルとのクロスエントロピー損失をとれば良い ということになる。
ここまでの内容は文献 [S] で扱われいるものであるが、計算方法が 1992 年に S. White によって開発された DMRG (密度行列繰り込み群) アルゴリズムという計算物理の手法によっているため、文献 [E](解説が文献 [N] 第 8 章に詳しい)では自動微分を用いることで機械学習の実践者に優しい内容にしたということである。オリジナルは TensorFlow を用いているが、今回は PyTorch を使ってスクラッチから実装してみた。
データとテンソルトレインの縮約
各
を掘り下げる。右辺を明示的に書くと
となる。式 (1) を思い出すと全体として、“脚” に考えられる全パターンの組み合わせが実行されることになる。
数値計算上の懸念
テンソルトレインの各ノードを
となる。
次に
の部分だが、組み合わせ数は torch.float32
の範囲で扱われる必要が出てくる。
また、縮約計算中に
FIG. 3. に詳細が書かれているが、今回は少々さぼって両端から 2 特徴量ずつ関連するテンソルを縮約していくことにした。
テンソルの初期化
実際、上記にように計算がデリケートなため、ランダム初期化を用いてしまうと
ランダム初期化でもいけるかもしれないが、いけるパターンを探すのがどんどん難しくなり、28x28 の MNIST では断念した。7x7 までリサイズした MNIST だとある程度は学習できるパターンが見つかったが、今回は断念した。
実装
以上を踏まえて実装したい。テストは Google Colab 上で T4 を用いて行った。
必要なモジュールの import
from __future__ import annotations
import math
import os
import sys
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.optim as optim
import torchinfo
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from torch import nn
from torch.utils.data import DataLoader
MNIST のダウンロード
root = os.path.join(os.getenv('HOME'), '.torch')
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda")
default_pix_dims = 2 # dimension of feature space
default_bond_dims = 10 # common dimension of virtual indices
img_size = 28
n_feature = img_size * img_size
transform = transforms.Compose([
transforms.ToTensor(),
])
trainset = torchvision.datasets.QMNIST(
root=root,
train=True,
download=True,
transform=transform
)
testset = torchvision.datasets.QMNIST(
root=root,
train=False,
download=True,
transform=transform
)
データローダーの作成
trainloader = torch.utils.data.DataLoader(
trainset,
batch_size = 32,
shuffle=True,
)
testloader = torch.utils.data.DataLoader(
testset,
batch_size = 32,
shuffle=False,
)
モデルの実装
def make_tt(
n_feature: int, pix_dims: int, bond_dims: int, n_class: int,
classifier_idx: int | None = None, std: float = 1e-3,
) -> list[torch.Tensor]:
tt_cores = []
for i in range(n_feature):
if i == 0:
dims = (pix_dims, bond_dims)
core = torch.zeros(dims)
core[:, 0] = 1
core += torch.normal(mean=0.0, std=std, size=core.shape)
elif i == n_feature - 1:
dims = (bond_dims, pix_dims)
core = torch.zeros(dims)
core[0, :] = 1
core += torch.normal(mean=0.0, std=std, size=core.shape)
else:
dims = (bond_dims, pix_dims, bond_dims)
core = torch.tensor(
np.array(pix_dims * [np.eye(bond_dims)],
dtype=np.float32)
).permute(1, 0, 2)
core += torch.normal(mean=0.0, std=std, size=core.shape)
tt_cores.append(core)
if classifier_idx is not None:
dims = (bond_dims, n_class, bond_dims)
core = torch.tensor(
np.array(n_class * [np.eye(bond_dims)], dtype=np.float32)
).permute(1, 0, 2)
core += torch.normal(mean=0.0, std=std, size=core.shape)
tt_cores.insert(classifier_idx, core)
return tt_cores
class FeatureEmbeddingLayer(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x: torchTensor) -> tuple[torchTensor, ...]:
x = torch.flatten(x, start_dim=1)
x = torch.stack([1-x, x], axis=1).permute((2, 0, 1))
x = tuple(t.squeeze() for t in x.split(1))
# n_feature tensors whose shape is (n_batch, pix_dims)
return x
class WeightLayer(nn.Module):
def __init__(self, n_feature: int, pix_dims: int, bond_dims: int, n_class: int):
super().__init__()
self.n_feature = n_feature
classifier_idx = self.classifier_loc(n_feature)
tt_cores = make_tt(
n_feature, pix_dims, bond_dims, n_class, classifier_idx=classifier_idx
)
self.n_cores = len(tt_cores)
for i, core in enumerate(tt_cores):
param_core = nn.parameter.Parameter(core)
setattr(self, f"tt_core{i}", param_core)
def forward(self, x: tuple[torchTensor, ...], n_sub_features: int = 2):
classifier_idx = self.classifier_loc(self.n_feature)
assert(
n_sub_features * 2 < self.n_feature,
f"{n_sub_features*2=} must be < {self.n_feature=}"
)
n_left_right_block, n_remaining_fea = \
self.left_sub_feature_num(n_sub_features)
start = 0
prev_t = None
for i in range(n_left_right_block):
end = start + n_sub_features
equation = self._make_equation(
self.n_feature, start_fea=start, end_fea=end
)
t = torch.einsum(equation, *x[start:end], *self.tt_cores[start:end])
if prev_t is None:
prev_t = t
else:
prev_t = torch.einsum("Bb,Bbc->Bc", prev_t, t)
start += n_sub_features
left_t = prev_t
start = n_sub_features * (n_left_right_block * 2 - 1) + n_remaining_fea
prev_t = None
for i in range(n_left_right_block):
end = start + n_sub_features
equation = self._make_equation(
self.n_feature, start_fea=start, end_fea=end
)
# +1 for tt_cores means consideration of shift for the classifier site
t = torch.einsum(
equation, *x[start:end], *self.tt_cores[start + 1:end + 1]
)
if prev_t is None:
prev_t = t
else:
prev_t = torch.einsum("Bbc,Bc->Bb", t, prev_t)
start -= n_sub_features
right_t = prev_t
start = n_left_right_block * n_sub_features
end = start + n_remaining_fea
equation = self._make_equation(self.n_feature, start_fea=start, end_fea=end)
# +1 for tt_cores means consideration of including the classifier site
classifier_t = torch.einsum(
equation, *x[start:end], *self.tt_cores[start:end + 1]
)
output = torch.einsum("ab,abAc,ac->aA", left_t, classifier_t, right_t)
return output
def forward_full(self, x: tuple[torchTensor, ...]):
equation = self._make_equation(self.n_feature)
output = torch.einsum(equation, *x, *self.tt_cores)
return output
def left_sub_feature_num(self, n_sub_features) -> tuple[int, int]:
remaining = self.n_feature
i = 0
while remaining > n_sub_features * 2:
remaining = remaining - n_sub_features * 2
i += 1
return i, remaining
@property
def tt_cores(self):
return [getattr(self, f"tt_core{i}") for i in range(self.n_cores)]
@classmethod
def classifier_loc(cls, n_feature):
return n_feature // 2
@classmethod
def _make_equation(
cls, n_feature: int, start_fea: int = 0, end_fea: int | None = None,
batch_c: str = "a", class_c: str | None = "A"
):
"""make an equation for range(start_i, end_i)
"""
if start_fea < 0:
raise ValueError(f'{start_fea=} must be >= 0.')
if end_fea is None:
end_fea = n_feature
if end_fea > n_feature:
raise ValueError(f'{end_fea=} must be <= {n_feature=}.')
if start_fea >= end_fea:
raise ValueError(f'{start_fea=} must be less than {end_fea=}.')
if end_fea - start_fea > ord("Z") - ord("B") + 1:
raise ValueError(
f'{end_fea=} - {start_fea=} must be less than {ord("Z")-ord("B")+1}.'
)
classifier_idx = cls.classifier_loc(n_feature)
includes_classifier = start_fea <= classifier_idx <= end_fea - 1
fea_i = ord("B") # U+0042-
vir_i = ord("b") # U+0062-
fea_idx: list[str] = []
vir_idx: list[str] = [] # ["Bb", "bCc", "cD"]
pre_vir_idx: list[tuple[str, ...]] = []
for i in range(start_fea, end_fea):
fea_idx.append(chr(fea_i))
if i == 0:
vir_i -= 1 # preserve first vir_i
pre_vir_idx.append((chr(fea_i), chr(vir_i + 1)))
elif i == n_feature - 1:
pre_vir_idx.append((chr(vir_i), chr(fea_i)))
else:
pre_vir_idx.append((chr(vir_i), chr(fea_i), chr(vir_i + 1)))
fea_i += 1
vir_i += 1
classifier_loc = classifier_idx - start_fea
for i, idx in enumerate(pre_vir_idx):
if includes_classifier:
if i == classifier_loc:
if i == 0:
classifier_idx = f"{idx[0]}{class_c}{chr(ord(idx[0]) + 1)}"
vir_idx.append(classifier_idx)
else:
last_vir_idx = vir_idx[-1][-1]
classifier_idx = \
f"{last_vir_idx}{class_c}{chr(ord(last_vir_idx) + 1)}"
vir_idx.append(classifier_idx)
if i >= classifier_loc:
idx = [c if ord(c) < ord("a") else chr(ord(c) + 1)
for c in list(idx)]
vir_idx.append("".join(idx))
fea_idx_s = ",".join([f"{batch_c}{c}"for c in fea_idx])
out_index = batch_c
vir_idx_s = ",".join(vir_idx)
if ord(vir_idx_s[0]) >= ord("a"):
out_index += vir_idx_s[0]
if class_c in vir_idx_s:
out_index += class_c
if ord(vir_idx_s[-1]) >= ord("a"):
out_index += vir_idx_s[-1]
equation = f'{fea_idx_s + "," + vir_idx_s}->{out_index}'
return equation
class Net(nn.Module):
def __init__(self, n_feature: int, pix_dims: int, bond_dims: int, n_class: int):
super().__init__()
self.fea_layer = FeatureEmbeddingLayer()
self.wgt_layer = WeightLayer(
n_feature=n_feature, pix_dims=pix_dims, bond_dims=bond_dims,
n_class=n_class
)
self.softmax = nn.LogSoftmax(dim=1)
def forward(self, x):
x = self.fea_layer(x)
x = self.wgt_layer(x)
x = self.softmax(x)
return x
訓練と検証ループの実装
def train(net, device, train_loader, optimizer, epoch, log_interval):
losses = []
nll_loss = nn.NLLLoss()
net.train()
running_loss = 0
n_samples = 0
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = net(data)
loss = nll_loss(output, target)
loss.backward()
optimizer.step()
running_loss += loss.item()
n_samples += len(data)
if batch_idx % log_interval == 0:
losses.append(running_loss / n_samples)
running_loss = 0
n_samples = 0
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader), loss.item() / len(data)))
return losses
def test(net, device, test_loader):
nll_loss = nn.NLLLoss()
net.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = net(data)
test_loss += nll_loss(output, target).item() # sum up batch loss
pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
test_loss, correct, len(test_loader.dataset),
100. * correct / len(test_loader.dataset)))
モデルの作成と訓練と検証
%%time
net = Net(n_feature=n_feature, pix_dims=default_pix_dims,
bond_dims=default_bond_dims, n_class=10).to(device)
optimizer = optim.Adam(net.parameters())
log_interval = 50
epochs = 1
losses = []
for epoch in range(1, epochs + 1):
sublosses = train(net, device, trainloader, optimizer, epoch, log_interval)
losses += sublosses
test(net, device, testloader)
Train Epoch: 1 [0/60000 (0%)] Loss: 0.071956
Train Epoch: 1 [1600/60000 (3%)] Loss: 0.071391
Train Epoch: 1 [3200/60000 (5%)] Loss: 0.071370
Train Epoch: 1 [4800/60000 (8%)] Loss: 0.033660
Train Epoch: 1 [6400/60000 (11%)] Loss: 0.045731
Train Epoch: 1 [8000/60000 (13%)] Loss: 0.028915
Train Epoch: 1 [9600/60000 (16%)] Loss: 0.029875
Train Epoch: 1 [11200/60000 (19%)] Loss: 0.032353
Train Epoch: 1 [12800/60000 (21%)] Loss: 0.018047
Train Epoch: 1 [14400/60000 (24%)] Loss: 0.023287
Train Epoch: 1 [16000/60000 (27%)] Loss: 0.022523
Train Epoch: 1 [17600/60000 (29%)] Loss: 0.026113
Train Epoch: 1 [19200/60000 (32%)] Loss: 0.015155
Train Epoch: 1 [20800/60000 (35%)] Loss: 0.017743
Train Epoch: 1 [22400/60000 (37%)] Loss: 0.023616
Train Epoch: 1 [24000/60000 (40%)] Loss: 0.020549
Train Epoch: 1 [25600/60000 (43%)] Loss: 0.021021
Train Epoch: 1 [27200/60000 (45%)] Loss: 0.019376
Train Epoch: 1 [28800/60000 (48%)] Loss: 0.020880
Train Epoch: 1 [30400/60000 (51%)] Loss: 0.030461
Train Epoch: 1 [32000/60000 (53%)] Loss: 0.037976
Train Epoch: 1 [33600/60000 (56%)] Loss: 0.013896
Train Epoch: 1 [35200/60000 (59%)] Loss: 0.006066
Train Epoch: 1 [36800/60000 (61%)] Loss: 0.009239
Train Epoch: 1 [38400/60000 (64%)] Loss: 0.017888
Train Epoch: 1 [40000/60000 (67%)] Loss: 0.003709
Train Epoch: 1 [41600/60000 (69%)] Loss: 0.031147
Train Epoch: 1 [43200/60000 (72%)] Loss: 0.023640
Train Epoch: 1 [44800/60000 (75%)] Loss: 0.005690
Train Epoch: 1 [46400/60000 (77%)] Loss: 0.012589
Train Epoch: 1 [48000/60000 (80%)] Loss: 0.007181
Train Epoch: 1 [49600/60000 (83%)] Loss: 0.014580
Train Epoch: 1 [51200/60000 (85%)] Loss: 0.007948
Train Epoch: 1 [52800/60000 (88%)] Loss: 0.012705
Train Epoch: 1 [54400/60000 (91%)] Loss: 0.020037
Train Epoch: 1 [56000/60000 (93%)] Loss: 0.013667
Train Epoch: 1 [57600/60000 (96%)] Loss: 0.023198
Train Epoch: 1 [59200/60000 (99%)] Loss: 0.015644Test set: Average loss: 0.0131, Accuracy: 54402/60000 (90.67%)
CPU times: user 47min 25s, sys: 5.34 s, total: 47min 31s
Wall time: 47min 37s
損失の推移
plt.plot(losses)
plt.show()
驚きの高次元
まとめ
とりあえず、何とか論文が主張するような精度が出る実装ができて良かった。
ところで論文によると
The purpose of this note is to be a resource for machine learning practitioners who wish to learn how tensor networks can be applied to classification problems.
ということなのだが、machine learning practitioners の立場としてはちょっときつかった・・・。結構簡単に使えるライブラリと初期化手法、そしてパフォーマンス等々が用意されていないと、自分で実装するのは大変だなと思った。勿論 TensorNetwork があるにはあるが、ただの machine learning practitioners が使いこなすには テンソルネットワークの概念を含めて ハードルが高いな・・・と感じるわけである。
どちらかと言うと 行列積状態について考える (5) — ニューラルネットワークのモデル圧縮 のように、既に訓練が完了しているニューラルネットワークをテンソルネットワークに変換して微調整を行うといった用途のほうが扱いも簡単だと感じていて、スクラッチでテンソルネットワークを訓練するのはきついなと感じるのだがどうなのだろうか。
参考文献
[E] TensorNetwork for Machine Learning, arXiv:1906.06329, Stavros Efthymiou, Jack Hidary, Stefan Leichenauer
[M] MPS_classifier@TensorNetwork, GitHub, Google LLC
[S] Supervised Learning with Quantum-Inspired Tensor Networks, arXiv:1605.05775, E. Miles Stoudenmire, David J. Schwab
[N] 西野友年, テンソルネットワーク入門, 講談社, 2023
Discussion
Zennに掲載されている記事を拝見し、非常に興味深く読ませていただきました。記事内のコードを実行してみましたが、残念ながら記事に記載されている結果を再現することができませんでした。
具体的には、トレーニング中のLossが0.000000のまま変化せず、学習が進んでいないように見受けられます。
私の環境で結果を再現できない原因について、いくつか質問させていただきたいです:
お忙しいところ恐縮ですが、ご回答いただけますと幸いです。どうぞよろしくお願いいたします。
@rihitosakurai
ご質問ありがとうございます。詳細は忘れかけていますが、以下の部分だったと思います。
参考までに調査段階では私も以下のようなグラフを得たことがありました。
今のところうまくいく初期化は
が一番良さそうという以上には掘り下げていません。PyTorch だからという理由かもしれず、TensorFlow や JAX だとまた異なるのかもしれませんが、これも調査できていません。
最終的には unofficial 実装として https://github.com/derwind/machine_learning/tree/742310fe4c5fae214aa4bc9490a6289ca4006f72/arxiv/1906.06329 にコードを置いていますのでご興味がありましたら。確か Google Colaboratory の NVIDIA T4 で訓練を回したように思います。
make_tt
関数でuse_google_initialization=True
の時の処理が重みテンソルの初期化のはずです。ありがとうございます。確かにgoogle Colaboratory の NVIDIA T4 で同様の結果が再現できました。
また手元の環境においてもパッケージのバージョンをgoogle Colaboratoryと同じにしたところ、上記の問題は解決できました。