🧠

【PyTorch】Datasetとtransformsを自作する

2022/03/08に公開約2,900字

はじめに

今回は深層学習 (機械学習) で必ずと言って良い程登場するDatasettransformsについて自作していきます.

実際に私が使用していた自作のデータセットコードを添付します.

https://github.com/a5chin/AMED/blob/ee0a359df47ffe2c945e599c9f22c243583942b9/amed/dataset/AMEDDataset.py#L12-L36

最終的に理解できると認識して読んで頂けたらと思います.

環境

PC MacBook Pro (16-inch, 2019)
OS Monterey
CPU 2.3 GHz 8コアIntel Core i9
メモリ 16GB
Python 3.9

ライブラリ

今回の記事で用いるライブラリとバージョンをまとめますが,特に気にせず

terminal
pip install numpy torch torchvision

で問題ないかと思います.

ライブラリ バージョン
numpy 1.21.2
torch 1.10.1
torchvision 0.11.2

もしエラーを吐かれてしまい上手く動作しなかった場合には,上記のバージョンを指定してinstallしてみてください.

transforms

学習に必要なデータ数が少ないために,データオーギュメンテーション (データの水増し) を行うときに用いられることが多いです.

transforms.py
import torch
import numpy as np


class MyTransforms:
    def __init__(self) -> None:
        pass

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        # TODO: fix to original transforms
        x = torch.from_numpy(x.astype(np.float32))  # example
        return x

必須関数

  • __call__
    • 特殊メソッドの1つ
    • インスタンスに引数を与えた時に呼ばれる

Dataset

オリジナルのデータを用いる時に必ず使用します.ImageFolder等を使用することもありますが,torch.utils.data.Datasetが継承されています.

dataset.py
from torch.utils.data import Dataset
from pathlib import Path
from typing import List, Tuple


class MyDataset(Dataset):
    def __init__(self, root: str, transforms) -> None:
        super().__init__()
        self.transforms = transforms
        # TODO: fix to original data
        self.data = Path(root).glob("**/*.json")

    # ここで取り出すデータを指定している
    def __getitem__(
        self,
        index: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        data = self.data[index]["data"]
        label = self.data[index]["label"]

        # データの変形 (transforms)
        data = self.transforms(data)

        return data, label

    # この method がないと DataLoader を呼び出す際にエラーを吐かれる
    def __len__(self) -> int:
        return len(self.data)

必須関数

  • __getitem__
    • データを取り出す時に呼び出される関数
    • indexlen()で取得した長さだけ繰り返される
  • __len__
    • len()を使用する時に呼び出される関数

DataLoader

今まで作ってきたtransformsDatasetを用いてDataLoderを作成します.データを取り出す際にはfor文を利用します.

main.py
from torch.utils.data import DataLoader
from torchvision import transforms


from .dataset import MyDataset
from .transforms import MyTransforms

transforms = transforms.Compose([
    transfors.ToTensor(),
    MyTransforms()
])

dataset = MyDataset(
    # TODO: fix to your original
    root="path to dataset's root",
    transforms=transforms
)

dataloader = DataLoader(
    dataset=dataset,
    batch_size=64,
    shuffle=True
)

深層学習例

おまけとして深層学習のイメージを明示しておきます.

main.py
for data, label in dataloader:
    model.train()
    preds = model(data)
    loss = criterion(preds, label)
    loss.backward()

更に詳しい例は以下のリンクから↓

需要があればリンク先の解説も出そうと思います!

GitHubで編集を提案

Discussion

ログインするとコメントできます