🧠
【PyTorch】Datasetとtransformsを自作する
はじめに
今回は深層学習 (機械学習) で必ずと言って良い程登場するDatasetとtransformsについて自作していきます.
実際に私が使用していた自作のデータセットコードを添付します.
最終的に理解できると認識して読んで頂けたらと思います.
環境
PC | MacBook Pro (16-inch, 2019) |
---|---|
OS | Monterey |
CPU | 2.3 GHz 8コアIntel Core i9 |
メモリ | 16GB |
Python | 3.9 |
ライブラリ
今回の記事で用いるライブラリとバージョンをまとめますが,特に気にせず
terminal
pip install numpy torch torchvision
で問題ないかと思います.
ライブラリ | バージョン |
---|---|
numpy | 1.21.2 |
torch | 1.10.1 |
torchvision | 0.11.2 |
もしエラーを吐かれてしまい上手く動作しなかった場合には,上記のバージョンを指定してinstall
してみてください.
transforms
学習に必要なデータ数が少ないために,データオーギュメンテーション (データの水増し) を行うときに用いられることが多いです.
transforms.py
import torch
import numpy as np
class MyTransforms:
def __init__(self) -> None:
pass
def __call__(self, x: torch.Tensor) -> torch.Tensor:
# TODO: fix to original transforms
x = torch.from_numpy(x.astype(np.float32)) # example
return x
必須関数
-
__call__
- 特殊メソッドの1つ
- インスタンスに引数を与えた時に呼ばれる
Dataset
オリジナルのデータを用いる時に必ず使用します.ImageFolder
等を使用することもありますが,torch.utils.data.Dataset
が継承されています.
dataset.py
from torch.utils.data import Dataset
from pathlib import Path
from typing import List, Tuple
class MyDataset(Dataset):
def __init__(self, root: str, transforms) -> None:
super().__init__()
self.transforms = transforms
# TODO: fix to original data
self.data = Path(root).glob("**/*.json")
# ここで取り出すデータを指定している
def __getitem__(
self,
index: int
) -> Tuple[torch.Tensor, torch.Tensor]:
data = self.data[index]["data"]
label = self.data[index]["label"]
# データの変形 (transforms)
data = self.transforms(data)
return data, label
# この method がないと DataLoader を呼び出す際にエラーを吐かれる
def __len__(self) -> int:
return len(self.data)
必須関数
-
__getitem__
- データを取り出す時に呼び出される関数
-
index
がlen()
で取得した長さだけ繰り返される
-
__len__
-
len()
を使用する時に呼び出される関数
-
DataLoader
今まで作ってきたtransforms
とDataset
を用いてDataLoder
を作成します.データを取り出す際にはfor
文を利用します.
main.py
from torch.utils.data import DataLoader
from torchvision import transforms
from .dataset import MyDataset
from .transforms import MyTransforms
transforms = transforms.Compose([
transfors.ToTensor(),
MyTransforms()
])
dataset = MyDataset(
# TODO: fix to your original
root="path to dataset's root",
transforms=transforms
)
dataloader = DataLoader(
dataset=dataset,
batch_size=64,
shuffle=True
)
深層学習例
おまけとして深層学習のイメージを明示しておきます.
main.py
for data, label in dataloader:
model.train()
preds = model(data)
loss = criterion(preds, label)
loss.backward()
更に詳しい例は以下のリンクから↓
需要があればリンク先の解説も出そうと思います!
Discussion