Matrix Product States (5) — Neural Network Model Compression
Purpose
As a continuation of Thinking about Matrix Product States (3), I want to consider model compression for neural networks, proceeding along the lines of arXiv:1509.06569, Tensorizing Neural Networks.
Fully Connected Layers and TT-Layers
A fully connected layer in a neural network maps an input $x$ to an output $y = Wx + b$, where $W$ is a weight matrix and $b$ is a bias vector.
Below, we will look at "3.1 TT-representations for vectors and matrices" in arXiv:1509.06569.
Matrices
Let's look at the matrix $W \in \mathbb{R}^{M \times N}$. Suppose the dimensions factor as $M = \prod_{k=1}^{d} m_k$ and $N = \prod_{k=1}^{d} n_k$. With the bijections $i \leftrightarrow (i_1, \dots, i_d)$ and $j \leftrightarrow (j_1, \dots, j_d)$, we can regard $W$ as a $d$-dimensional tensor whose $k$-th mode is the index pair $(i_k, j_k)$ of size $m_k n_k$. By performing a TT-decomposition of this tensor, we get

$$
W\bigl((i_1, j_1), \dots, (i_d, j_d)\bigr) = G_1[i_1, j_1]\, G_2[i_2, j_2] \cdots G_d[i_d, j_d].
$$
This corresponds to equation (3) in arXiv:1509.06569. The paper refers to this representation as a "TT-matrix."
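To make the double indexing concrete, here is a small illustration of my own (the 6×6 matrix and the factorizations $m_1 m_2 = 2 \cdot 3$, $n_1 n_2 = 2 \cdot 3$ are arbitrary choices, not taken from the paper):

```python
import numpy as np

# Hypothetical example: view a 6x6 matrix as a 4th-order tensor with
# paired row/column indices, as in the TT-matrix construction.
# Row index i = i1 * 3 + i2 with (m1, m2) = (2, 3);
# column index j = j1 * 3 + j2 with (n1, n2) = (2, 3).
rng = np.random.default_rng(0)
w = rng.standard_normal((2 * 3, 2 * 3))

# reshape splits each index; transpose groups them as (i1, j1, i2, j2)
W = w.reshape(2, 3, 2, 3).transpose(0, 2, 1, 3)

# W[i1, j1, i2, j2] equals w[i1*3 + i2, j1*3 + j2]
i1, i2, j1, j2 = 1, 2, 0, 1
assert np.isclose(W[i1, j1, i2, j2], w[i1 * 3 + i2, j1 * 3 + j2])
```

Each mode of the resulting tensor now carries one row factor and one column factor together, which is exactly the shape the TT-matrix cores $G_k[i_k, j_k]$ act on.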
Vectors
For the vector $b \in \mathbb{R}^{M}$, we proceed in the same way: using the factorization $M = \prod_{k=1}^{d} m_k$, reshape it into a $d$-dimensional tensor $B(i_1, \dots, i_d)$ and apply the TT-decomposition to obtain its Tensor-Train representation.
TT-Layers
By converting this into a tensor equation and further TT-decomposing the matrix part into a TT-matrix, we get:

$$
Y(i_1, \dots, i_d) = \sum_{j_1, \dots, j_d} G_1[i_1, j_1] \cdots G_d[i_d, j_d]\, X(j_1, \dots, j_d) + B(i_1, \dots, i_d).
$$
The paper refers to a fully connected layer transformed in this way as a "TT-layer."
By making the dimension of each bond in the Tensor-Train smaller than the exact TT-rank, the total number of parameters in the cores can be made much smaller than the $M \times N$ entries of the original weight matrix, at the cost of only approximating the layer.
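To get a feel for the potential savings, here is a back-of-the-envelope count of my own (the layer size, mode factorization, and bond dimension are assumptions, not numbers from the paper): each TT-matrix core $G_k$ holds $r_{k-1} m_k n_k r_k$ parameters, with boundary bonds $r_0 = r_d = 1$.

```python
import numpy as np

# Assumed shapes: a dense 1024x1024 layer with both sides factored as
# 4*4*4*4*4, represented as a TT-matrix with uniform bond dimension r.
m = [4, 4, 4, 4, 4]  # row-mode sizes, prod = 1024
n = [4, 4, 4, 4, 4]  # column-mode sizes, prod = 1024
r = 8                # uniform bond dimension

dense_params = np.prod(m) * np.prod(n)
bonds = [1] + [r] * (len(m) - 1) + [1]  # boundary bonds are 1
tt_params = sum(
    bonds[k] * m[k] * n[k] * bonds[k + 1] for k in range(len(m))
)
print(f"dense: {dense_params}, TT: {tt_params}")
```

With these assumed sizes the TT-matrix needs 3328 parameters against 1048576 for the dense matrix, a reduction of more than two orders of magnitude, provided a bond dimension of 8 approximates the layer well enough.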
Implementation
Since theory alone can leave one with only a hazy understanding, I'll try implementing it in Python.
I will use the TT_SVD function I tested in Thinking about Matrix Product States (3). Previously I took the transpose .T of the last element of the Tensor-Train, but since that felt a bit awkward, I won't do it this time. I will also add automatic bond-dimension calculation and a rank check, as follows.
Importing Necessary Modules
from __future__ import annotations
from typing import Sequence
import numpy as np
Slightly Improved TT_SVD
def TT_SVD(
C: np.ndarray, r: Sequence[int] | None = None, check_r: bool = False
) -> list[np.ndarray]:
"""TT_SVD algorithm
Args:
C (np.ndarray): n-dimensional input tensor
r (Sequence[int]): a list of bond dimensions.
If `r` is None, `r` will be automatically calculated
check_r (bool): check if `r` is valid
Returns:
list[np.ndarray]: a list of core tensors of TT-decomposition
"""
dims = C.shape
n = len(dims) # n-dimensional tensor
if r is None or check_r:
# Theorem 2.1
r_ = []
for sep in range(1, n):
row_dim = np.prod(dims[:sep])
col_dim = np.prod(dims[sep:])
rank = np.linalg.matrix_rank(C.reshape(row_dim, col_dim))
r_.append(rank)
if r is None:
r = r_
if len(r) != n - 1:
raise ValueError(f"{len(r)=} must be {n - 1}.")
if check_r:
for i, (r1, r2) in enumerate(zip(r, r_)):
if r1 > r2:
raise ValueError(f"{i}th dim {r1} must not be larger than {r2}.")
# Algorithm 1
tt_cores = []
for i in range(n - 1):
if i == 0:
ri_1 = 1
else:
ri_1 = r[i - 1]
ri = r[i]
C = C.reshape(ri_1 * dims[i], np.prod(dims[i + 1:]))
U, S, Vh = np.linalg.svd(C, full_matrices=False)
# approximation
U = U[:, :ri]
S = S[:ri]
Vh = Vh[:ri, :]
tt_cores.append(U.reshape(ri_1, dims[i], ri))
C = np.diag(S) @ Vh
tt_cores.append(C)
tt_cores[0] = tt_cores[0].reshape(dims[0], r[0])
return tt_cores
Definition of Fully Connected Layer and Mini-batch
In machine learning, it is standard for the input $x$ to be a mini-batch whose first axis is the batch dimension, so the layer computes $y = x W^\top + b$. Let's define a small fully connected layer and a mini-batch accordingly.
rng = np.random.default_rng(12345)
batch_dim = 16
# a fully connected layer
w = rng.standard_normal((3*4, 5*6))
b = rng.standard_normal(3*4)
# mini batch
x = rng.standard_normal((batch_dim, 5*6))
Ultimately, we want to obtain the following result using a TT-layer.
answer = x @ w.T + b
print(f"num of params={np.prod(w.shape) + np.prod(b.shape)}")
num of params=372
Warm-up
Let's gradually transition to tensor calculations.
Writing standard matrix calculations with contraction
val1 = np.einsum("Nj,ij->Ni", x, w) + b # use N as the index symbol for batch dim
print(f"{np.allclose(answer, val1)=}")
np.allclose(answer, val1)=True
Since this is the most straightforward matrix calculation, it naturally matches.
Reshaping into tensors for calculation
Let's add a bit more of a tensor flavor.
W = w.reshape(3, 4, 5, 6)
X = x.reshape(batch_dim, 5, 6)
B = b.reshape(3, 4)
print(f"num of params={np.prod(W.shape) + np.prod(B.shape)}")
val2 = np.einsum("Nkl,ijkl->Nij", X, W) + B
val2 = val2.reshape(batch_dim, -1)
print(f"{np.allclose(answer, val2)=}")
num of params=372
np.allclose(answer, val2)=True
At this point, the number of parameters has not changed, and the calculation result matches the original.
TT-Decomposing the weight matrix W
Next, we decompose W using TT-decomposition.
tt_W = TT_SVD(W, [3, 12, 6])
print("tt_W:", [v.shape for v in tt_W])
print(f"num of params={np.sum([np.prod(v.shape) for v in tt_W]) + np.prod(B.shape)}")
W1 = np.einsum("ic,cjd,dke,el->ijkl", *tt_W)
print(f"{np.allclose(W, W1)=}")
tt_W: [(3, 3), (3, 4, 12), (12, 5, 6), (6, 6)]
num of params=561
np.allclose(W, W1)=True
The tensor W and the Tensor-Train tt_W represent essentially the same data, but the number of parameters has increased: when the TT reproduces the original values exactly, the bond indices between the cores add parameters rather than remove them.
Continuing, we execute the calculation as a TT-layer.
val3 = np.einsum("Nkl,ic,cjd,dke,el->Nij", X, *tt_W) + B
val3 = val3.reshape(batch_dim, -1)
print(f"{np.allclose(answer, val3)=}")
print(f"max diff={np.round(np.max(np.abs(answer - val3)), 5)}")
np.allclose(answer, val3)=True
max diff=0.0
The results match the original x @ w.T + b calculation.
Low-rank Approximation
As it stands, we've only increased the number of parameters without changing the calculation results, so there's no real benefit. However, by reducing the bond dimensions below the exact TT-rank, we can trade a small amount of accuracy for a smaller parameter count.
In this case, we've converted the weight matrix into a 4th-order tensor before TT-decomposition, so there are 3 bonds, and the TT-rank required to reconstruct the original tensor is (3, 12, 6). Let's gradually decrease the middle TT-rank from 12.
for dim in range(12, 5, -1):
    tt_W1 = TT_SVD(W, [3, dim, 6])
    print("tt_W1:", [v.shape for v in tt_W1])
    print(f"num of params={np.sum([np.prod(v.shape) for v in tt_W1]) + np.prod(B.shape)}")
    val4 = np.einsum("Nkl,ic,cjd,dke,el->Nij", X, *tt_W1) + B
    val4 = val4.reshape(batch_dim, -1)
    print(f"{np.allclose(answer, val4)=} for {dim=}")
    print(f"max diff={np.round(np.max(np.abs(answer - val4)), 5)} for {dim=}")
    print()
tt_W1: [(3, 3), (3, 4, 12), (12, 5, 6), (6, 6)]
num of params=561
np.allclose(answer, val4)=True for dim=12
max diff=0.0 for dim=12

tt_W1: [(3, 3), (3, 4, 11), (11, 5, 6), (6, 6)]
num of params=519
np.allclose(answer, val4)=False for dim=11
max diff=2.46626 for dim=11

tt_W1: [(3, 3), (3, 4, 10), (10, 5, 6), (6, 6)]
num of params=477
np.allclose(answer, val4)=False for dim=10
max diff=3.92577 for dim=10

tt_W1: [(3, 3), (3, 4, 9), (9, 5, 6), (6, 6)]
num of params=435
np.allclose(answer, val4)=False for dim=9
max diff=3.57042 for dim=9

tt_W1: [(3, 3), (3, 4, 8), (8, 5, 6), (6, 6)]
num of params=393
np.allclose(answer, val4)=False for dim=8
max diff=5.37306 for dim=8

tt_W1: [(3, 3), (3, 4, 7), (7, 5, 6), (6, 6)]
num of params=351
np.allclose(answer, val4)=False for dim=7
max diff=6.59231 for dim=7

tt_W1: [(3, 3), (3, 4, 6), (6, 5, 6), (6, 6)]
num of params=309
np.allclose(answer, val4)=False for dim=6
max diff=6.45532 for dim=6
In the last two cases, the number of parameters is smaller than the original 372. In exchange, the calculation error has increased. It is a trade-off between how much accuracy you seek and how much you want to reduce the parameter count.
Bonus (MPS Representation of Quantum States)
Back when I was writing Thinking about Matrix Product States (2), I was dealing with quantum states, but I realized I had shifted to neural networks before I knew it.
Since I'm at it, let's apply the TT_SVD function to quantum states as well and look at the MPS representation.
First, let's make preparations.
ket_ZERO = np.array([1, 0], dtype=float)
ket_ONE = np.array([0, 1], dtype=float)
Trying with \ket{000}
First, prepare the state vector.
state_000 = np.kron(np.kron(ket_ZERO, ket_ZERO), ket_ZERO)
state_000
array([1., 0., 0., 0., 0., 0., 0., 0.])
Next, let's try the TT-decomposition.
mps_state_000 = TT_SVD(state_000.reshape(2, 2, 2))
print([v.shape for v in mps_state_000])
mps_state_000
[(2, 1), (1, 2, 1), (1, 2)]
[array([[1.],
[0.]]),
array([[[1.],
[0.]]]),
array([[1., 0.]])]
Although the appearance is slightly different, I think it corresponds to the result obtained in Thinking about Matrix Product States (2).
The original state vector can be reconstructed with the following contraction calculation.
np.einsum("ia,ajb,bk->ijk", *mps_state_000).flatten()
array([1., 0., 0., 0., 0., 0., 0., 0.])
Trying with \frac{1}{\sqrt{2}}(\ket{000} + \ket{111})
Similarly, prepare the state vector.
state_111 = np.kron(np.kron(ket_ONE, ket_ONE), ket_ONE)
state_ghz = (state_000 + state_111) / np.sqrt(2)
state_ghz
array([0.70710678, 0. , 0. , 0. , 0. , 0. , 0. , 0.70710678])
Next, perform the TT-decomposition.
mps_state_ghz = TT_SVD(state_ghz.reshape(2, 2, 2))
print([v.shape for v in mps_state_ghz])
mps_state_ghz
[(2, 2), (2, 2, 2), (2, 2)]
[array([[1., 0.],
[0., 1.]]),
array([[[ 1., 0.],
[ 0., 0.]],
[[ 0., 0.],
[ 0., -1.]]]),
array([[ 0.70710678, 0. ],
[ 0. , -0.70710678]])]
I think this also corresponds to the result obtained in Thinking about Matrix Product States (2).
The original state vector can be reconstructed with the following contraction calculation.
np.einsum("ia,ajb,bk->ijk", *mps_state_ghz).flatten()
array([0.70710678, 0. , 0. , 0. , 0. , 0. , 0. , 0.70710678])
Summary
Following arXiv:1509.06569, Tensorizing Neural Networks, we converted the fully connected layer of a neural network into a "TT-layer" through TT-decomposition and performed the forward-propagation calculation using contractions, confirming that the results match the standard linear-algebra calculation.
We also confirmed that model compression can be achieved by reducing bond dimensions, although this comes at the cost of increased error.
Furthermore, we confirmed that applying this TT-decomposition to quantum state vectors yields the MPS representations we have previously seen.
References
[O] I. V. Oseledets, Tensor-Train Decomposition, SIAM J. Sci. Comput., 33(5), 2295–2317.
[S] Ulrich Schollwoeck, The density-matrix renormalization group in the age of matrix product states, arXiv:1008.3477.
[NPOV] Alexander Novikov, Dmitry Podoprikhin, Anton Osokin, Dmitry Vetrov, Tensorizing Neural Networks, arXiv:1509.06569.