
Published on 2023/11/20

# Creating a multivariate normal distribution replicated across a batch in PyTorch

This post is a continuation of an earlier article.

```python
import torch
import torch.nn as nn
import torchvision as tv
import torchvision.transforms as T
from torch.utils.data import DataLoader
```

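```python
tr = T.Compose([T.ToTensor()])
# Assumed dataset: MNIST (1x28x28 grayscale), matching the shapes printed below
dataset = tv.datasets.MNIST(root="./data", train=True, download=True, transform=tr)
loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=0)
```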

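After a fair amount of trial and error, the working solution is:

`An = A.unsqueeze(0).repeat(batchnum, 1, 1)`

`unsqueeze(0)` adds a new leading dimension, and `repeat` then copies the matrix along that dimension once per batch element. The remaining dimensions get a repeat count of 1, so they stay unchanged.

See the referenced page for details.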
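Here is a small sketch of the same pattern on a 2×2 matrix. The matrix values and `batchnum = 3` are arbitrary choices for illustration. It also shows two standard PyTorch alternatives. The first is `expand`, which returns the same values as a view without copying data. The second is `torch.matmul` broadcasting, which applies a single 2-D matrix to a batch of vectors directly.

```python
import torch

torch.manual_seed(0)

# Arbitrary small example: one 2x2 matrix, batch of 3
A = torch.tensor([[1.0, 2.0],
                  [3.0, 4.0]])
batchnum = 3

# unsqueeze(0): (2, 2) -> (1, 2, 2), a new leading batch dimension
# repeat(batchnum, 1, 1): copy batchnum times along dim 0, keep dims 1 and 2 as-is
An = A.unsqueeze(0).repeat(batchnum, 1, 1)
print(An.shape)  # torch.Size([3, 2, 2])

# Every slice along the batch dimension equals the original matrix
print(all(torch.equal(An[i], A) for i in range(batchnum)))  # True

# expand returns the same values as a view that shares A's storage (no copy)
Ae = A.unsqueeze(0).expand(batchnum, -1, -1)
print(torch.equal(An, Ae))  # True

# torch.matmul broadcasts a 2-D matrix over a batch of column vectors directly
z = torch.randn(batchnum, 2, 1)
print(torch.allclose(An @ z, A @ z))  # True
```

With the MNIST shapes used in this post, `repeat` materializes 64 × 784 × 784 float32 values, roughly 157 MB. `expand` avoids that copy because it only creates a view. `A @ z` gives the same result without building `An` explicitly.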

```python
for img, class_lbl in loader:
    batchnum = img.shape[0]
    print("batch size", batchnum)
    print("input img", img.shape)
    # Flatten each 1x28x28 image into a 784-dimensional vector
    img = torch.flatten(img, start_dim=1)
    print("input img (flattened)", img.shape)
    mean = torch.mean(img, dim=0)
    print("mean", mean.shape)

    # Covariance matrix cov and a square-root factor A with A @ A.T == cov
    cov = torch.cov(img.T)
    # cov is symmetric, so eigh returns real eigenvalues/eigenvectors;
    # clamp tiny negative eigenvalues from rounding so sqrt does not produce NaN
    evalue, evec = torch.linalg.eigh(cov)
    A = evec @ torch.diag(torch.sqrt(evalue.clamp(min=0)))

    print("mean (repeated)", mean.repeat(batchnum, 1).shape)
    print("A", A.shape)

    # Replicate A once per batch element: (784, 784) -> (batchnum, 784, 784)
    An = A.unsqueeze(0).repeat(batchnum, 1, 1)
    print("An", An.shape)

    # Sample x = A @ z + mean with z ~ N(0, I), an independent z for each batch element
    # https://blog.recyclebin.jp/archives/4077
    z = torch.randn(batchnum, mean.shape[0], 1)
    a = (An @ z).squeeze(-1) + mean.repeat(batchnum, 1)
    # a holds batchnum samples from the multivariate normal N(mean, cov)
    print("MultivariateNormal sample", a.shape)
    break
```
```
batch size 64
input img torch.Size([64, 1, 28, 28])
input img (flattened) torch.Size([64, 784])
mean torch.Size([784])
mean (repeated) torch.Size([64, 784])
A torch.Size([784, 784])
An torch.Size([64, 784, 784])
MultivariateNormal sample torch.Size([64, 784])
```
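The sampling step rests on one identity. Take the eigendecomposition $\Sigma = V \Lambda V^\top$ and set $A = V \Lambda^{1/2}$. Then $A A^\top = \Sigma$, so for $z \sim \mathcal{N}(0, I)$ the vector $x = A z + \mu$ has mean $\mu$ and covariance $A A^\top = \Sigma$.

The sketch below checks this on a small, hand-picked 3-dimensional covariance. The values of `cov`, `mean`, `batchnum` and the sample count are illustrative assumptions, not MNIST statistics. It also compares the manual batched approach with `torch.distributions.MultivariateNormal`. A `loc` of shape `(batchnum, d)` gives that distribution a `batch_shape` of `(batchnum,)`.

One caveat applies to the MNIST case. `MultivariateNormal` factorizes the covariance with Cholesky, so it needs a positive-definite matrix. A 784×784 sample covariance estimated from only 64 images has rank at most 63, and such a singular matrix would typically be rejected. The eigendecomposition route with clamped eigenvalues avoids that problem.

```python
import torch
from torch.distributions import MultivariateNormal

torch.manual_seed(0)
batchnum, d = 4, 3

# Hand-picked positive-definite covariance and mean (illustrative values only)
cov = torch.tensor([[2.0, 0.5, 0.0],
                    [0.5, 1.0, 0.3],
                    [0.0, 0.3, 0.5]])
mean = torch.tensor([1.0, -1.0, 0.5])

# Square-root factor as in the post: A = V diag(sqrt(lambda)), so A @ A.T == cov
evalue, evec = torch.linalg.eigh(cov)
A = evec @ torch.diag(torch.sqrt(evalue.clamp(min=0)))

# Manual batched sampling: x = A z + mean, one z ~ N(0, I) per batch element
An = A.unsqueeze(0).repeat(batchnum, 1, 1)
z = torch.randn(batchnum, d, 1)
x_manual = (An @ z).squeeze(-1) + mean

# Library version: a (batchnum, d) loc broadcasts against one (d, d) covariance
dist = MultivariateNormal(loc=mean.expand(batchnum, d), covariance_matrix=cov)
x_lib = dist.sample()

print(dist.batch_shape, dist.event_shape)  # torch.Size([4]) torch.Size([3])
print(x_manual.shape, x_lib.shape)         # torch.Size([4, 3]) torch.Size([4, 3])

# Both should follow N(mean, cov): compare empirical covariances over many draws
n = 200_000
emp_manual = torch.cov((torch.randn(n, d) @ A.T + mean).T)
emp_lib = torch.cov(MultivariateNormal(mean, covariance_matrix=cov).sample((n,)).T)
print(torch.allclose(emp_manual, cov, atol=0.05),
      torch.allclose(emp_lib, cov, atol=0.05))  # True True
```

The two routes use different square roots of $\Sigma$ (eigendecomposition vs. Cholesky), so individual draws differ, but both sample from the same distribution.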
For reference, here is how `Tensor.repeat` behaves on a 1-D tensor:

```python
import torch

A = torch.tensor([2, 3, 5])
print(A.shape)
print(A.repeat(5).shape)
print(A.repeat(3, 5).shape)
print(A.repeat(3))
print(A.repeat(3, 1))
B = A.repeat(3, 1, 2).size()
print(B)
```
```
torch.Size([3])
torch.Size([15])
torch.Size([3, 15])
tensor([2, 3, 5, 2, 3, 5, 2, 3, 5])
tensor([[2, 3, 5],
        [2, 3, 5],
        [2, 3, 5]])
torch.Size([3, 1, 6])
```
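All of these shapes follow one rule. When `repeat` receives more sizes than the tensor has dimensions, the tensor is treated as having leading singleton dimensions, and each dimension is then multiplied by its repeat count. For example, shape `(3,)` with `repeat(3, 1, 2)` becomes `(1, 1, 3)` and then `(3, 1, 6)`.

The function `repeat_shape` below is a hypothetical helper written only to spell out that rule, checked against `Tensor.repeat` itself. Because of this padding, `A.repeat(batchnum, 1, 1)` on a 2-D matrix gives the same result as `A.unsqueeze(0).repeat(batchnum, 1, 1)`. For contrast, the sketch also shows `repeat_interleave`, which repeats each element in place rather than tiling the whole tensor.

```python
import torch

def repeat_shape(shape, sizes):
    # Hypothetical helper mirroring Tensor.repeat's shape rule
    # (assumes len(sizes) >= len(shape), as repeat itself requires):
    # left-pad shape with 1s up to len(sizes), then multiply dim by dim
    padded = (1,) * (len(sizes) - len(shape)) + tuple(shape)
    return tuple(dim * s for dim, s in zip(padded, sizes))

A = torch.tensor([2, 3, 5])
for sizes in [(5,), (3, 5), (3, 1), (3, 1, 2)]:
    print(sizes, "->", repeat_shape(A.shape, sizes), tuple(A.repeat(*sizes).shape))

# Thanks to the left-padding, a 2-D matrix can be repeated into a batch
# without unsqueeze: (2, 2) is treated as (1, 2, 2) -> (4, 2, 2)
M = torch.eye(2)
print(torch.equal(M.repeat(4, 1, 1), M.unsqueeze(0).repeat(4, 1, 1)))  # True

# repeat tiles the whole tensor; repeat_interleave repeats each element in place
print(A.repeat(3))             # tensor([2, 3, 5, 2, 3, 5, 2, 3, 5])
print(A.repeat_interleave(3))  # tensor([2, 2, 2, 3, 3, 3, 5, 5, 5])
```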