Open1

pytorch dataset and dataloader

北宇治高校北宇治高校

自分のデータでDeepLearningやりたい時に、先ずはPytorchがどうやってデータを読み込みするのが一番です。

データセットについて以下Data sizeとBatch概念があります。

  • Data size(すべてのデータ)

DeepLearningのデータセットは大体大きくてメモリとかの原因で、一回で読み込まないことが多い。
少しずつ読みながらtrainingが普通。

  • Batch(一回分の読み込むデータ)

Datasetを定義方法

torch.utils.dataDataset継承し
基本的に__init__ __getitem__ __len__関数をオーバーライドをするでいい。

  • __getitem__
    データとラベル が要求されたときに返す
  • __len__
    データセットのデータ数が要求されたときに返す
import torch
from torch.utils.data import Dataset, DataLoader
from dataset_pre import load_data


class MyDataset(Dataset):
    
    def __init__(self):
        base_path = './data'
    
        data, label = load_data(base_path)
        
        self._x = data
        self._y = label
        self._len = len(data)
        
    def __getitem__(self, item):
        return self._x[item], self._y[item]
    
    def __len__(self):
        return self._len
        
dataset = MyDataset()
print(len(dataset))

first = next(iter(dataset))
print(first)

dataloader = DataLoader(dataset, batch_size=3, shuffle=True, drop_last=True, num_workers=0)

n = 0

for data_val, label_val in dataloader:
    print('x:', data_val)
    print('y:', label_val)
    n += 1

print('iteration:', n)