Creating a batch-replicated multivariate normal distribution in PyTorch

This is a continuation of the earlier post of the same title ("pytorchでバッチ分だけ複製された多変量正規分布を作る").
import torch
import torch.nn as nn
import torchvision as tv
import torchvision.transforms as T
from torch.utils.data import DataLoader
/usr/local/lib/python3.8/dist-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension: '/usr/local/lib/python3.8/dist-packages/torchvision/image.so: undefined symbol: _ZN3c104cuda9SetDeviceEi'If you don't plan on using image functionality from `torchvision.io`, you can ignore this warning. Otherwise, there might be something wrong with your environment. Did you have `libjpeg` or `libpng` installed before building `torchvision` from source?
warn(
Build the covariance matrix over the pixels of the image data (MNIST) and replicate it batch-size times.
tr = T.Compose([T.ToTensor()])
dataset = tv.datasets.MNIST('data', True, transform = tr, download = True)
loader = DataLoader(dataset, batch_size = 64, shuffle = True, num_workers = 0)
After quite a bit of trial and error, the working answer turned out to be:
An = A.unsqueeze(0).repeat(batchnum, 1, 1)
unsqueeze(0) adds a new leading dimension, and repeat then tiles the matrix batchnum times along it.
For how to construct the multivariate normal distribution itself, see the linked reference.
for img, class_lbl in loader:
    batchnum = img.shape[0]
    print("batch size ", batchnum)
    print("input img ", img.shape)
    img = torch.flatten(img, start_dim=1)
    print("input img(reshape) ", img.shape)
    mean = torch.mean(img, dim=0)
    print("mean ", mean.shape)
    # covariance matrix cov and a factor A with A @ A.T == cov
    cov = torch.cov(img.T)
    evalue, evec = torch.linalg.eig(cov)
    evec = evec.real
    A = evec @ torch.diag(torch.sqrt(evalue.real))  # sqrt assumes non-negative eigenvalues
    print(mean.repeat(batchnum, 1).shape)
    print("A", A.shape)
    # A replicated batchnum times along a new leading dimension
    An = A.unsqueeze(0).repeat(batchnum, 1, 1)
    print("An", An.shape)
    # sample = A @ z + mean with z ~ N(0, I)
    # https://blog.recyclebin.jp/archives/4077
    a = An @ torch.randn_like(mean) + mean.repeat(batchnum, 1)  # one z shared by the whole batch
    print("An@rand", a.shape)
    # MultivariateNormal dist.
    print("MultivariateNormal sample ", a.shape)
    break
batch size 64
input img torch.Size([64, 1, 28, 28])
input img(reshape) torch.Size([64, 784])
mean torch.Size([784])
torch.Size([64, 784])
A torch.Size([784, 784])
An torch.Size([64, 784, 784])
An@rand torch.Size([64, 784])
MultivariateNormal sample torch.Size([64, 784])
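The loop above draws the sample by hand as A @ z + mean. As a complementary sketch (not from the original post, and using a small synthetic covariance because the raw MNIST pixel covariance is effectively singular thanks to the constant border pixels), the same unsqueeze + repeat pattern can feed torch.distributions.MultivariateNormal directly:
import torch
from torch.distributions import MultivariateNormal

batchnum, dim = 64, 4  # toy dimension instead of 784 to keep the demo light

# synthetic symmetric positive definite covariance (the jitter keeps Cholesky happy)
L = torch.randn(dim, dim)
cov = L @ L.T + 1e-3 * torch.eye(dim)
mean = torch.zeros(dim)

# replicate mean and covariance across the batch: same unsqueeze + repeat idea
mean_b = mean.unsqueeze(0).repeat(batchnum, 1)   # [64, 4]
cov_b = cov.unsqueeze(0).repeat(batchnum, 1, 1)  # [64, 4, 4]

dist = MultivariateNormal(loc=mean_b, covariance_matrix=cov_b)
sample = dist.sample()                     # [64, 4]: one draw per batch element
print(sample.shape)
print(dist.batch_shape, dist.event_shape)  # torch.Size([64]) torch.Size([4])
Passing scale_tril (a Cholesky factor) instead of covariance_matrix skips the internal factorization, which matters when the distribution is rebuilt every batch.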
As a standalone check of how repeat tiles a tensor:
A = torch.tensor([2, 3, 5])
print(A.shape)
print(A.repeat(5).shape)
print(A.repeat(3, 5).shape)
print(A.repeat(3))
print(A.repeat(3, 1))
B = A.repeat(3, 1, 2).size()
print(B)
torch.Size([3])
torch.Size([15])
torch.Size([3, 15])
tensor([2, 3, 5, 2, 3, 5, 2, 3, 5])
tensor([[2, 3, 5],
[2, 3, 5],
[2, 3, 5]])
torch.Size([3, 1, 6])
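A side note that is not in the original post: repeat physically copies the data, so the An tensor earlier in the post stores batchnum full copies of the 784×784 matrix. When the replicated tensor is only read, unsqueeze + expand gives a broadcast view without copying; a minimal sketch:
import torch

A = torch.tensor([2, 3, 5])

rep = A.repeat(3, 1)                 # shape [3, 3], newly allocated copy
exp = A.unsqueeze(0).expand(3, -1)   # shape [3, 3], view over the original storage

print(rep.shape, exp.shape)            # torch.Size([3, 3]) torch.Size([3, 3])
print(rep.data_ptr() == A.data_ptr())  # False: repeat copied the data
print(exp.data_ptr() == A.data_ptr())  # True: expand did not copy
Whether expand is usable depends on the downstream op; anything that writes in place still needs the real copies that repeat makes.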