Open1
pytorch dataset and dataloader
自分のデータでDeepLearningやりたい時に、先ずはPytorchがどうやってデータを読み込みするのが一番です。
データセットについて以下Data sizeとBatch概念があります。
- Data size(すべてのデータ)
DeepLearningのデータセットは大体大きくてメモリとかの原因で、一回で読み込まないことが多い。
少しずつ読みながらtrainingが普通。
- Batch(一回分の読み込むデータ)
Datasetを定義方法
torch.utils.data
のDataset
継承し
基本的に__init__
__getitem__
__len__
関数をオーバーライドをするでいい。
-
__getitem__
データとラベル が要求されたときに返す -
__len__
データセットのデータ数が要求されたときに返す
import torch
from torch.utils.data import Dataset, DataLoader
from dataset_pre import load_data
class MyDataset(Dataset):
def __init__(self):
base_path = './data'
data, label = load_data(base_path)
self._x = data
self._y = label
self._len = len(data)
def __getitem__(self, item):
return self._x[item], self._y[item]
def __len__(self):
return self._len
dataset = MyDataset()
print(len(dataset))
first = next(iter(dataset))
print(first)
dataloader = DataLoader(dataset, batch_size=3, shuffle=True, drop_last=True, num_workers=0)
n = 0
for data_val, label_val in dataloader:
print('x:', data_val)
print('y:', label_val)
n += 1
print('iteration:', n)