😊

pytorchでバッチ分だけ複製された多変量正規分布を作る

2023/11/20に公開

pytorchでバッチ分だけ複製された多変量正規分布を作る

https://zenn.dev/xiangze/articles/16d94225988287
の続き

import torch
import torch.nn as nn
import torchvision as tv
import torchvision.transforms as T
from torch.utils.data import DataLoader
/usr/local/lib/python3.8/dist-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension: '/usr/local/lib/python3.8/dist-packages/torchvision/image.so: undefined symbol: _ZN3c104cuda9SetDeviceEi'If you don't plan on using image functionality from `torchvision.io`, you can ignore this warning. Otherwise, there might be something wrong with your environment. Did you have `libjpeg` or `libpng` installed before building `torchvision` from source?
  warn(

画像データ(MNIST)の各ピクセルの共分散行列を作って、それをバッチ分複製する

# MNIST training split; every image is converted to a tensor by ToTensor.
tr = T.Compose([T.ToTensor()])
dataset = tv.datasets.MNIST(root='data', train=True, download=True, transform=tr)
# Shuffled mini-batches of 64 images, loaded without worker subprocesses.
loader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    num_workers=0,
)

かなり試行錯誤したが正解は以下のようになる

# Add a leading batch axis of size 1, then tile only along that axis.
An = A.unsqueeze(0)                 # (d, d)    -> (1, d, d)
An = An.repeat(batchnum, 1, 1)      # (1, d, d) -> (batchnum, d, d)

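unsqueeze(0)で先頭に大きさ1の次元を追加し、その次元方向にだけbatchnum回repeatする(残りの2次元は繰り返し回数1なので元のまま)。データのコピーが不要なら A.unsqueeze(0).expand(batchnum, -1, -1) でも同じ形のビューが得られる。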

https://stackoverflow.com/questions/57896357/how-to-repeat-tensor-in-a-specific-new-dimension-in-pytorch

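多変量正規分布の作り方は

https://blog.recyclebin.jp/archives/4077

を参照。要点は、共分散行列を $\Sigma = AA^\top$ と分解し(ここでは固有値分解 $\Sigma = V\Lambda V^\top$ から $A = V\Lambda^{1/2}$)、サンプルごとに独立な標準正規乱数 $z \sim \mathcal{N}(0, I)$ を用いて $x = Az + \mu$ とすれば $x \sim \mathcal{N}(\mu, \Sigma)$ となること。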

# Take one mini-batch, build a factor A of the per-pixel covariance
# (cov = A @ A.T), replicate A once per batch element and draw one
# multivariate-normal sample per element: x = A @ z + mean, z ~ N(0, I).
for img, class_lbl in loader:
    batchnum = img.shape[0]
    print("batch size ", batchnum)
    print("inpuit img ", img.shape)
    # (B, 1, 28, 28) -> (B, 784): one row per image, one column per pixel.
    img = torch.flatten(img, start_dim=1)
    print("inpuit img(reshape) ", img.shape)
    mean = torch.mean(img, dim=0)
    print("mean ", mean.shape)

    # Covariance cov and its factor A = V diag(sqrt(lambda)), so A @ A.T == cov.
    # torch.cov treats rows as variables, hence the transpose.
    cov = torch.cov(img.T)
    # cov is symmetric: eigh returns real eigenvalues and orthonormal
    # eigenvectors (eig returns complex values whose eigenvectors are not
    # guaranteed orthonormal for repeated eigenvalues).
    evalue, evec = torch.linalg.eigh(cov)
    # With far fewer samples than pixels, cov is rank-deficient and round-off
    # can yield tiny negative eigenvalues; sqrt of those would be NaN.
    A = evec @ torch.diag(torch.sqrt(evalue.clamp(min=0)))

    print(mean.repeat(batchnum, 1).shape)
    print("A", A.shape)

    # A replicated once per batch element: (d, d) -> (batchnum, d, d).
    # A.unsqueeze(0).expand(batchnum, -1, -1) gives the same shape as a view
    # without copying the data.
    An = A.unsqueeze(0).repeat(batchnum, 1, 1)
    print("An", An.shape)

    # x = A @ z + mean with z ~ N(0, I) is distributed as N(mean, cov).
    # https://blog.recyclebin.jp/archives/4077
    # z must be standard normal (randn, not the uniform rand) and drawn
    # independently for every batch element, otherwise all samples coincide.
    z = torch.randn(batchnum, mean.shape[0], 1, dtype=mean.dtype, device=mean.device)
    # (batchnum, d, d) @ (batchnum, d, 1) -> (batchnum, d); mean broadcasts.
    a = (An @ z).squeeze(-1) + mean
    print("An@rand", a.shape)
    # MultivariateNormal dist.
    print("MultivariateNormal sample ", a.shape)
    break
batch size  64
inpuit img  torch.Size([64, 1, 28, 28])
inpuit img(reshape)  torch.Size([64, 784])
mean  torch.Size([784])
torch.Size([64, 784])
A torch.Size([784, 784])
An torch.Size([64, 784, 784])
An@rand torch.Size([64, 784])
MultivariateNormal sample  torch.Size([64, 784])
A = torch.tensor([2, 3, 5])            # 1-D tensor of size (3,)
print(A.size())
# A single factor tiles along the only existing dim: (3,) -> (15,)
print(A.repeat(5).size())
# More factors than dims: extra leading dims are prepended: (3,) -> (3, 15)
print(A.repeat(3, 5).size())
# Values: a flat tiling of length 9 ...
print(A.repeat(3))
# ... versus three copies stacked as rows
print(A.repeat(3, 1))
B = A.repeat(3, 1, 2).shape            # (3,) -> (3, 1, 6)
print(B)
torch.Size([3])
torch.Size([15])
torch.Size([3, 15])
tensor([2, 3, 5, 2, 3, 5, 2, 3, 5])
tensor([[2, 3, 5],
        [2, 3, 5],
        [2, 3, 5]])
torch.Size([3, 1, 6])

Discussion