🦜

【PyTorch】Dataset解説

2024/04/03に公開

今回はPyTorchのDatasetについて解説します。

1. Datasetとは

PyTrochにおけるDatasetは、機械学習モデルに適切な流れでデータを供給する枠組みです。
主にDataLoaderと共に使用され、これらによってデータ拡張やバッチ処理などの拡張が行いやすくなります。

2. 定義

Datasetは、PyTorchのtorch.utils.data.Datasetクラスを継承して定義します。
また以下に示すメソッドを定義するように指定されています。

必要なメソッド

  • __init__(self)
    初期実行関数です。Datasetを定義する際に必要な情報を受け取ります。

  • __len__(self)
    データ全体の数を返す関数です。例えば10000枚の画像を学習に使用する場合は10000を返します。

  • __getitem__(self, index)
    指定されたindexに対応するデータと正解ラベル(ターゲット)を返します。

3. 例

実際の使い方を見ていきましょう。
今回は簡単に、犬と猫を分類する画像認識タスクを考えます。

・分類タスクのDataset

import numpy as np
import torch
from torch.utils.data import Dataset

class Train_Dataset(Dataset):
    def __init__(self, target_list):
        # filepathのリストを作成
        self.__make_img_path_list()
        self.target_list = target_list
        
    def __len__(self):
        # データの総数を返す関数
        return len(self.img_path_list)
    def __getitem__(self, index: int):
        # indexを指定して、データと正解ラベル(ターゲット)を返す関数
        return self.__img_from_path(index)
    # filepathのリストを作成
    def __make_img_path_list(self):
        DATA_DIRECTORY = '/etc'
        self.img_path_list = []
        # numpy.arrayのファイルパスを指定してリスト化
        for path_id in range(10):
                img_path = DATA_DIRECTORY + f"/img_{path_id}.npy"
                self.img_path_list.append(img_path)
            
    #spectrogramのpathを指定して、spectrogramに前処理を実施して返す関数を定義
    def __img_from_path(self,index):
        # 指定indexのpathを取得
        img_path = self.img_path_list[index]
        # 指定indexのlabel_id(データ識別子)を取得
        target = self.target_list[index]
        # pathから画像を取得
        img = np.load(img_path)

        return img, target 

target_list = ['dog','dog','cat','dog','cat','cat','cat','dog','dog','cat']
Sample_Dataset = Train_Dataset(target_list)

以下で順番に解説していきます。

3.1 初期化(init)

__init__()関数について解説します。

画像データのデータセットを作る場合、全画像のファイルパスが含まれたリストを作り、その数をデータ総数とすると理解しやすいです。

・全画像(犬や猫の10枚の画像)のファイルパスが含まれたリスト

# ファイルパスのリスト
img_path_list = ['/etc/img_0.npy', '/etc/img_2.npy', ... ,'/etc/img_9.npy']

上記の例では、'__make_img_path_list()'関数で全画像ファイルパスのリストを作成しており、これを__init__()関数で実行しています。

また、self.target_list = target_listで正解ラベルの配列を格納しています。

3.2 len

次に__len__()関数ではデータの総数を返す必要があります。
__init__()でファイルパスのリストを作っているので、この総数を返すようにすれば完了です。

3.3 getitem

最後は__getitem__()関数で、indexに対応したデータを返す必要があります。

今回は__img_from_path()関数でその動作を定義しています。コードを確認してみましょう。

def __img_from_path(self,index):
    img_path = self.img_path_list[index]
    target = self.target_list[index]
    img = np.load(img_path)
    
    return img, target 

具体的な動作は以下のようになっています。

  1. indexに対応するファイルパスimg_pathを取得
  2. 正解ラベルの配列から対応する値'dog' or 'cat'を取得
  3. ファイルパスimg_pathから画像imgを取得
  4. return img, targetで画像と正解ラベルを返す

3.4 定義

最後に、target配列を指定してDatasetを定義すれば完了です。

・Datasetの定義

target_list = ['dog','dog','cat','dog','cat','cat','cat','dog','dog','cat']
Sample_Dataset = Train_Dataset(target_list)

これを利用することで、効率的に機械学習モデルを構築することができます。

まとめ

今回はPyTrochのDatasetについて解説しました。
次回はDataLoaderを解説する予定です。

Discussion