pytorchの多変量正規分布でハマった
概要
共分散行列を計算してPyTorchの多変量正規分布に用いようとすると、非常に小さい固有値がある場合に、計算精度の問題からかエラーが出てしまう。
最終的にlinalg.eigで複素数としての固有値問題を解いて実数部を使うことで解決した。
試行
import torch
import torch.nn as nn
import torchvision as tv
import torchvision.transforms as T
from torch.utils.data import DataLoader
from tqdm import tqdm
from torch.distributions import multivariate_normal
/usr/local/lib/python3.8/dist-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension: '/usr/local/lib/python3.8/dist-packages/torchvision/image.so: undefined symbol: _ZN3c104cuda9SetDeviceEi'If you don't plan on using image functionality from `torchvision.io`, you can ignore this warning. Otherwise, there might be something wrong with your environment. Did you have `libjpeg` or `libpng` installed before building `torchvision` from source?
warn(
torch.__version__
'2.0.1+cu117'
画像データ(MNIST)の共分散行列から多変量正規分布を作りたい
# ToTensor is the only transform applied to each image.
tr = T.Compose([
    T.ToTensor(),
])
# MNIST training split, stored under ./data and downloaded if missing.
dataset = tv.datasets.MNIST(root='data', train=True, transform=tr, download=True)
# Shuffled mini-batches of 64 images, loaded without worker processes.
loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=0)
class AverageMeter:
    """Track the running sum, count and weighted mean of a stream of values."""

    def __init__(self, name=None):
        """Store the optional *name* and zero the statistics."""
        self.name = name
        self.reset()

    def reset(self):
        """Set ``sum``, ``count`` and ``avg`` back to 0."""
        self.sum = self.count = self.avg = 0

    def update(self, val, n=1):
        """Add *val*, weighted by *n* observations, and refresh ``avg``.

        ``avg`` keeps its previous value while ``count`` is 0, so that
        ``update(val, n=0)`` on a fresh meter does not divide by zero.
        """
        self.sum += val * n
        self.count += n
        if self.count:
            self.avg = self.sum / self.count
# Progress bar over the loader; only the first batch is processed below.
loop = tqdm(loader, position=0, leave=True)
loss_ = AverageMeter()
for idx, (img, class_lbl) in enumerate(loop):
    # Uniform jitter scaled by 1e-9, added to every pixel.
    img = img + 1e-9 * torch.rand_like(img)
    orgimg = img
    batchnum = img.shape[0]
    print("batch num", batchnum, sep="\n")
    print("inpuit img", img.shape, sep="\n")
    # Keep the batch axis and flatten everything else into one row per image.
    img = img.flatten(start_dim=1)
    print(img.shape)
    mean = img.mean(dim=0)
    print("mean", mean.shape, sep="\n")
    # Without the transpose the result is batch-by-batch; with it, the
    # result is feature-by-feature (one row/column per flattened pixel).
    cov = torch.cov(img)
    covt = torch.cov(img.T)
    print("covariance", cov.shape, covt.shape, sep="\n")
    # Build the distribution from the per-pixel mean and covariance.
    dist = multivariate_normal.MultivariateNormal(loc=mean, covariance_matrix=covt)
    print(dist.sample().shape)
    break
0%| | 0/938 [00:00<?, ?it/s]
batch num
64
inpuit img
torch.Size([64, 1, 28, 28])
torch.Size([64, 784])
mean
torch.Size([784])
covariance
torch.Size([64, 64])
torch.Size([784, 784])
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
/tmp/ipykernel_7598/3639241380.py in <module>
36 print(covt.shape)
37 # covt=torch.eye(covt.shape[0])
---> 38 dist = multivariate_normal.MultivariateNormal(loc=mean, covariance_matrix=covt)
39 print(dist.sample().shape)
40 break
/usr/local/lib/python3.8/dist-packages/torch/distributions/multivariate_normal.py in __init__(self, loc, covariance_matrix, precision_matrix, scale_tril, validate_args)
148
149 event_shape = self.loc.shape[-1:]
--> 150 super().__init__(batch_shape, event_shape, validate_args=validate_args)
151
152 if scale_tril is not None:
/usr/local/lib/python3.8/dist-packages/torch/distributions/distribution.py in __init__(self, batch_shape, event_shape, validate_args)
60 valid = constraint.check(value)
61 if not valid.all():
---> 62 raise ValueError(
63 f"Expected parameter {param} "
64 f"({type(value).__name__} of shape {tuple(value.shape)}) "
ValueError: Expected parameter covariance_matrix (Tensor of shape (784, 784)) of distribution MultivariateNormal(loc: torch.Size([784]), covariance_matrix: torch.Size([784, 784])) to satisfy the constraint PositiveDefinite(), but found invalid values:
tensor([[ 1.0667e-19, 2.0079e-20, -7.3494e-21, ..., 3.0668e-21,
7.8143e-21, -1.8431e-20],
[ 2.0079e-20, 8.3648e-20, -6.6314e-22, ..., -4.2122e-21,
-2.8440e-22, 1.3152e-22],
[-7.3494e-21, -6.6314e-22, 7.5718e-20, ..., -9.0270e-21,
3.2225e-21, 5.1655e-21],
...,
[ 3.0668e-21, -4.2122e-21, -9.0270e-21, ..., 9.0798e-20,
-4.8592e-21, -3.2140e-22],
[ 7.8143e-21, -2.8440e-22, 3.2225e-21, ..., -4.8592e-21,
7.8224e-20, -1.3207e-20],
[-1.8431e-20, 1.3152e-22, 5.1655e-21, ..., -3.2140e-22,
-1.3207e-20, 9.1553e-20]])
エラーが出てしまった
MultivariateNormalについて2
# Two-dimensional normal with zero mean and identity covariance.
m = multivariate_normal.MultivariateNormal(
    loc=torch.zeros(2),
    covariance_matrix=torch.eye(2),
)
m.sample()
tensor([-0.1473, -3.0313])
When I use torch.cov to compute the covariance matrix of a batch of vector (in which the batch_size may be less than the vector length,) and then use the statistics to construct the MultivariateNormal distribution, it would raise a ValueError which indicates the covariance_matrix is not a positive-definite matrix.
import torch
from torch.distributions.multivariate_normal import MultivariateNormal
# Two samples of a four-dimensional variable (N=2 < D=4).
data = torch.rand(2,4)
data_mean = data.mean(0)
# Transposed so that the four columns are treated as the variables.
data_cov = torch.cov(data.T)
# Reported to raise ValueError: data_cov fails the PositiveDefinite constraint.
m = MultivariateNormal(data_mean, data_cov)
When the data is N×D with N<D, this happens because the covariance matrix is only positive-semidefinite (it is singular). I found that NumPy's multivariate normal distribution only requires a positive-semidefinite covariance, whereas PyTorch needs a positive-definite covariance_matrix.
To achieve my goal, I need to transform the covariance_matrix to the numpy format for random sampling and then transform the sampled results back into PyTorch tensor.
from torch.distributions.multivariate_normal import MultivariateNormal
# Two samples of a four-dimensional variable (N=2 < D=4).
data = torch.rand(2,4)
data_mean = data.mean(0)
data_cov = torch.cov(data.T)
print(data_cov)
# Print the eigen-decomposition to inspect the signs of the eigenvalues.
print(torch.linalg.eig(data_cov))
# Raises ValueError in the recorded run (PositiveDefinite constraint).
m = MultivariateNormal(data_mean, data_cov)
print(m)
tensor([[ 0.1483, -0.0308, 0.0772, -0.1051],
[-0.0308, 0.0064, -0.0160, 0.0218],
[ 0.0772, -0.0160, 0.0402, -0.0547],
[-0.1051, 0.0218, -0.0547, 0.0744]])
torch.return_types.linalg_eig(
eigenvalues=tensor([ 2.6933e-01+0.j, 0.0000e+00+0.j, -5.0075e-10+0.j, 2.1579e-09+0.j]),
eigenvectors=tensor([[ 0.7421+0.j, -0.6703+0.j, 0.6571+0.j, 0.2764+0.j],
[-0.1540+0.j, -0.1705+0.j, -0.0198+0.j, -0.1052+0.j],
[ 0.3863+0.j, 0.4277+0.j, -0.4820+0.j, 0.5187+0.j],
[-0.5257+0.j, -0.5820+0.j, 0.5792+0.j, 0.8022+0.j]]))
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
/tmp/ipykernel_146999/553319903.py in <module>
6 print(data_cov)
7 print(torch.linalg.eig(data_cov))
----> 8 m = MultivariateNormal(data_mean, data_cov)
9 print(m)
/usr/local/lib/python3.8/dist-packages/torch/distributions/multivariate_normal.py in __init__(self, loc, covariance_matrix, precision_matrix, scale_tril, validate_args)
148
149 event_shape = self.loc.shape[-1:]
--> 150 super().__init__(batch_shape, event_shape, validate_args=validate_args)
151
152 if scale_tril is not None:
/usr/local/lib/python3.8/dist-packages/torch/distributions/distribution.py in __init__(self, batch_shape, event_shape, validate_args)
60 valid = constraint.check(value)
61 if not valid.all():
---> 62 raise ValueError(
63 f"Expected parameter {param} "
64 f"({type(value).__name__} of shape {tuple(value.shape)}) "
ValueError: Expected parameter covariance_matrix (Tensor of shape (4, 4)) of distribution MultivariateNormal(loc: torch.Size([4]), covariance_matrix: torch.Size([4, 4])) to satisfy the constraint PositiveDefinite(), but found invalid values:
tensor([[ 0.1483, -0.0308, 0.0772, -0.1051],
[-0.0308, 0.0064, -0.0160, 0.0218],
[ 0.0772, -0.0160, 0.0402, -0.0547],
[-0.1051, 0.0218, -0.0547, 0.0744]])
これは失敗してしまう。サンプル数が次元数より少ない(N<D)ため共分散行列がランク落ちしており、丸め誤差によって非常に小さいが負の固有値が現れるためである。
Numpyでlinalg.eigを使う
import numpy as np

# Same experiment in NumPy: two samples of a four-dimensional variable.
data = np.random.random((2, 4))
data_mean = data.mean(axis=0)
# rowvar=False treats the columns as variables, like np.cov(data.T).
data_cov = np.cov(data, rowvar=False)
print(data_cov)
print(np.linalg.eig(data_cov))
# Draw one sample from the fitted distribution.
m = np.random.multivariate_normal(data_mean, data_cov)
print(m)
[[ 0.02759016 0.01754446 -0.01297558 -0.02720648]
[ 0.01754446 0.01115645 -0.00825111 -0.01730049]
[-0.01297558 -0.00825111 0.00610238 0.01279513]
[-0.02720648 -0.01730049 0.01279513 0.02682814]]
(array([ 0.00000000e+00, 7.16771316e-02, 2.29609551e-18, -9.74180119e-19]), array([[-0.78426857, -0.62042148, -0.29677755, 0.28141888],
[ 0.31210066, -0.39452332, -0.54408298, 0.16195134],
[-0.23082416, 0.29178251, -0.72206044, -0.64096091],
[-0.48397954, 0.61179368, -0.30745003, 0.69551728]]))
[0.35824293 0.05329781 0.77522495 0.73602538]
numpy is OK
Covariance Matrix(共分散行列)の性質
共分散行列は理論上は半正定値だが、PyTorchのMultivariateNormalに渡すには正定値である必要がある
-
ValueError: Expected parameter covariance_matrix
add jitter to cov matrix
a possible work around might be to initialize the MultivariateNormal with
scale_tril = torch.linalg.cholesky(...) instead of covariance_matrix=cov
torch.linalg.cholesky()を使ってみる
# Two samples of a four-dimensional variable (N=2 < D=4).
data = torch.rand(2,4)
data_mean = data.mean(0)
data_cov = torch.cov(data.T)
print(data_cov)
print(torch.linalg.eig(data_cov))
# Factorise the covariance directly. In the recorded run this raises
# because data_cov is not positive-definite.
L = torch.linalg.cholesky(data_cov)
# The second positional parameter of MultivariateNormal is
# covariance_matrix, so the Cholesky factor has to be passed by keyword.
m = MultivariateNormal(data_mean, scale_tril=L)
print(L)
print(m)
tensor([[ 0.0024, -0.0246, 0.0160, -0.0086],
[-0.0246, 0.2564, -0.1668, 0.0895],
[ 0.0160, -0.1668, 0.1085, -0.0582],
[-0.0086, 0.0895, -0.0582, 0.0312]])
torch.return_types.linalg_eig(
eigenvalues=tensor([ 0.0000e+00+0.j, 3.9844e-01+0.j, 3.4086e-09+0.j, -1.4140e-09+0.j]),
eigenvectors=tensor([[-0.9970+0.j, 0.0769+0.j, 0.0741+0.j, 0.7752+0.j],
[-0.0619+0.j, -0.8021+0.j, -0.3808+0.j, 0.3649+0.j],
[ 0.0402+0.j, 0.5218+0.j, -0.1050+0.j, 0.5043+0.j],
[-0.0216+0.j, -0.2800+0.j, 0.9157+0.j, 0.1074+0.j]]))
---------------------------------------------------------------------------
_LinAlgError Traceback (most recent call last)
/tmp/ipykernel_146999/2151643132.py in <module>
4 print(data_cov)
5 print(torch.linalg.eig(data_cov))
----> 6 L = torch.linalg.cholesky(data_cov)
7 m = MultivariateNormal(data_mean, L)
8 print(L)
_LinAlgError: linalg.cholesky: The factorization could not be completed because the input is not positive-definite (the leading minor of order 2 is not positive-definite).
結局Covarianceは正定値でないといけないはずだ
なので固有値を正にする
# Two samples of a four-dimensional variable (N=2 < D=4).
data = torch.rand(2,4)
data_mean = data.mean(0)
data_cov = torch.cov(data.T)
print(data_cov)
# General eigen-decomposition; the results are complex-valued tensors.
eigenvalues,eigenvectors=torch.linalg.eig(data_cov)
print(eigenvalues)
print(eigenvectors)
# Keep only the real parts (the imaginary parts are zero in the recorded output).
eigenvectors=eigenvectors.real
# Clamp every eigenvalue from below at 1e-8 so that none is zero or negative.
eigenvalues=torch.maximum(torch.ones_like(eigenvalues.real)*1e-8,eigenvalues.real)
print(eigenvalues)
# Rebuild the covariance from the clamped eigenvalues.
# NOTE(review): this reconstruction presumably relies on the eigenvectors
# being orthonormal; torch.linalg.eigh may be the safer choice for a
# symmetric input — TODO confirm.
data_cov=eigenvectors@torch.diag(eigenvalues)@eigenvectors.T
print(data_cov)
# Cholesky succeeds on the rebuilt matrix in the recorded run; L is only printed.
L = torch.linalg.cholesky(data_cov)
# The distribution is built from the rebuilt covariance, not from L.
m = MultivariateNormal(data_mean, data_cov)
print(L)
m.sample()
tensor([[ 0.0004, -0.0007, -0.0020, -0.0031],
[-0.0007, 0.0011, 0.0033, 0.0050],
[-0.0020, 0.0033, 0.0093, 0.0142],
[-0.0031, 0.0050, 0.0142, 0.0216]])
tensor([ 0.0000e+00+0.j, 3.2509e-02+0.j, 1.1340e-10+0.j, -5.3808e-10+0.j])
tensor([[-0.9932+0.j, 0.1164+0.j, -0.7516+0.j, 0.0519+0.j],
[-0.0220+0.j, -0.1876+0.j, -0.5651+0.j, -0.0923+0.j],
[-0.0629+0.j, -0.5361+0.j, -0.2733+0.j, -0.8173+0.j],
[-0.0955+0.j, -0.8147+0.j, 0.2026+0.j, 0.5665+0.j]])
tensor([1.0000e-08, 3.2509e-02, 1.0000e-08, 1.0000e-08])
tensor([[ 0.0004, -0.0007, -0.0020, -0.0031],
[-0.0007, 0.0011, 0.0033, 0.0050],
[-0.0020, 0.0033, 0.0093, 0.0142],
[-0.0031, 0.0050, 0.0142, 0.0216]])
tensor([[ 2.0995e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],
[-3.3827e-02, 2.4065e-04, 0.0000e+00, 0.0000e+00],
[-9.6663e-02, 5.8782e-04, 1.0742e-04, 0.0000e+00],
[-1.4690e-01, 8.4981e-04, 9.4683e-05, 1.7149e-04]])
tensor([0.1913, 0.6402, 0.1010, 0.2775])
成功したが、画像だとまだ失敗する
再挑戦
# Fresh shuffled loader; only the first batch is inspected.
loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=0)
for img, class_lbl in loader:
    batchnum = img.shape[0]
    print("batch num %d" % batchnum)
    print("inpuit img ", img.shape)
    # Keep the batch axis and flatten everything else into one row per image.
    img = img.flatten(start_dim=1)
    print("inpuit img ", img.shape)
    mean = img.mean(dim=0)
    print("mean ", mean.shape)
    # Feature-by-feature covariance over the batch.
    cov = torch.cov(img.T)
    print("cov ", cov.shape)
    # Symmetry check: the difference from the transpose is printed.
    print(cov - cov.T)
    # Fails in the recorded run: the matrix is not positive-definite.
    L = torch.linalg.cholesky(cov.T)
    break
batch num 64
inpuit img torch.Size([64, 1, 28, 28])
inpuit img torch.Size([64, 784])
mean torch.Size([784])
cov torch.Size([784, 784])
tensor([[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]])
---------------------------------------------------------------------------
_LinAlgError Traceback (most recent call last)
/tmp/ipykernel_7598/1244819073.py in <module>
11 print("cov ",cov.shape)
12 print(cov-cov.T)
---> 13 L=torch.linalg.cholesky(cov.T)
14 break
_LinAlgError: linalg.cholesky: The factorization could not be completed because the input is not positive-definite (the leading minor of order 1 is not positive-definite).
import numpy as np
# NumPy version of the same check on the first batch of the loader.
for img, class_lbl in loader:
    batchnum = img.shape[0]
    print("batch num %d" % batchnum)
    # Select channel 0 and convert the batch to a NumPy array.
    img = np.array(img[:, 0])
    print("inpuit img ", img.shape)
    # Mean image, taken before flattening.
    mean = img.mean(axis=0)
    print("mean ", mean.shape)
    # One row per image.
    img = img.reshape(img.shape[0], img.shape[1] * img.shape[2])
    print("inpuit img ", img.shape)
    cov = np.cov(img.T)
    print("cov ", cov.shape)
    # Symmetry check on this batch's covariance. This replaces a stray
    # eigen-decomposition of an unrelated tensor left over from an earlier
    # cell, whose result was never used.
    print(cov - cov.T)
    # Fails in the recorded run: the matrix is not positive-definite.
    L = np.linalg.cholesky(cov)
    break
batch num 64
inpuit img (64, 28, 28)
mean (28, 28)
inpuit img (64, 784)
cov (784, 784)
[[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
...
[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]]
---------------------------------------------------------------------------
LinAlgError Traceback (most recent call last)
/tmp/ipykernel_7598/1401271810.py in <module>
12 print("cov ",cov.shape)
13 print(cov-cov.T)
---> 14 L=np.linalg.cholesky(cov)
15 break
16
~/.local/lib/python3.8/site-packages/numpy/core/overrides.py in cholesky(*args, **kwargs)
~/.local/lib/python3.8/site-packages/numpy/linalg/linalg.py in cholesky(a)
761 t, result_t = _commonType(a)
762 signature = 'D->D' if isComplexType(t) else 'd->d'
--> 763 r = gufunc(a, signature=signature, extobj=extobj)
764 return wrap(r.astype(result_t, copy=False))
765
~/.local/lib/python3.8/site-packages/numpy/linalg/linalg.py in _raise_linalgerror_nonposdef(err, flag)
89
90 def _raise_linalgerror_nonposdef(err, flag):
---> 91 raise LinAlgError("Matrix is not positive definite")
92
93 def _raise_linalgerror_eigenvalues_nonconvergence(err, flag):
LinAlgError: Matrix is not positive definite
コレスキー分解に失敗
対角化して共分散行列の平方根(A Aᵀ が共分散行列になる行列 A)を求める
参考
# Lower bound applied to the eigenvalues before taking their square root.
eps = 1e-4
for img, class_lbl in loop:
    batchnum = img.shape[0]
    print("batch num %d" % batchnum)
    print("inpuit img ", img.shape)
    # Keep the batch axis and flatten everything else into one row per image.
    img = torch.flatten(img, start_dim=1)
    print("inpuit img ", img.shape)
    mean = torch.mean(img, dim=0)
    print("mean ", mean.shape)
    cov = torch.cov(img.T)
    # General eigen-decomposition; only the real parts are used.
    evalue, evec = torch.linalg.eig(cov)
    evec = evec.real
    # Clamp at eps: an eigenvalue that comes out slightly negative would
    # turn into NaN under sqrt and poison every sample.
    evalue = torch.clamp(evalue.real, min=eps)
    # Factor A such that A @ A.T approximates the (regularised) covariance.
    # NOTE(review): this presumably relies on the eigenvectors being
    # orthonormal; torch.linalg.eigh may be safer for a symmetric input.
    A = evec @ torch.diag(torch.sqrt(evalue))
    # MultivariateNormal sample: transform standard-normal noise. Uniform
    # noise from rand_like would not give a Gaussian sample and would shift
    # the mean.
    a = A @ torch.randn_like(mean) + mean
    print("MultivariateNormal sample ", a.shape)
    break
batch num 64
inpuit img torch.Size([64, 1, 28, 28])
inpuit img torch.Size([64, 784])
mean torch.Size([784])
MultivariateNormal sample torch.Size([784])
MNISTにたいしても成功
その他
Multi GPUでの使用
Discussion