【PyTorch】Dataset解説
今回はPyTorchのDatasetについて解説します。
1. Datasetとは
PyTrochにおけるDatasetは、機械学習モデルに適切な流れでデータを供給する枠組みです。
主にDataLoaderと共に使用され、これらによってデータ拡張やバッチ処理などの拡張が行いやすくなります。
2. 定義
Datasetは、PyTorchのtorch.utils.data.Dataset
クラスを継承して定義します。
また以下に示すメソッドを定義するように指定されています。
必要なメソッド
-
__init__(self)
初期実行関数です。Datasetを定義する際に必要な情報を受け取ります。 -
__len__(self)
データ全体の数を返す関数です。例えば10000枚の画像を学習に使用する場合は10000を返します。 -
__getitem__(self, index)
指定されたindexに対応するデータと正解ラベル(ターゲット)を返します。
3. 例
実際の使い方を見ていきましょう。
今回は簡単に、犬と猫を分類する画像認識タスクを考えます。
・分類タスクのDataset
import numpy as np
import torch
from torch.utils.data import Dataset
class Train_Dataset(Dataset):
def __init__(self, target_list):
# filepathのリストを作成
self.__make_img_path_list()
self.target_list = target_list
def __len__(self):
# データの総数を返す関数
return len(self.img_path_list)
def __getitem__(self, index: int):
# indexを指定して、データと正解ラベル(ターゲット)を返す関数
return self.__img_from_path(index)
# filepathのリストを作成
def __make_img_path_list(self):
DATA_DIRECTORY = '/etc'
self.img_path_list = []
# numpy.arrayのファイルパスを指定してリスト化
for path_id in range(10):
img_path = DATA_DIRECTORY + f"/img_{path_id}.npy"
self.img_path_list.append(img_path)
#spectrogramのpathを指定して、spectrogramに前処理を実施して返す関数を定義
def __img_from_path(self,index):
# 指定indexのpathを取得
img_path = self.img_path_list[index]
# 指定indexのlabel_id(データ識別子)を取得
target = self.target_list[index]
# pathから画像を取得
img = np.load(img_path)
return img, target
target_list = ['dog','dog','cat','dog','cat','cat','cat','dog','dog','cat']
Sample_Dataset = Train_Dataset(target_list)
以下で順番に解説していきます。
3.1 初期化(init)
__init__()関数について解説します。
画像データのデータセットを作る場合、全画像のファイルパスが含まれたリストを作り、その数をデータ総数とすると理解しやすいです。
・全画像(犬や猫の10枚の画像)のファイルパスが含まれたリスト
# ファイルパスのリスト
img_path_list = ['/etc/img_0.npy', '/etc/img_2.npy', ... ,'/etc/img_9.npy']
上記の例では、'__make_img_path_list()'関数で全画像ファイルパスのリストを作成しており、これを__init__()関数で実行しています。
また、self.target_list = target_list
で正解ラベルの配列を格納しています。
3.2 len
次に__len__()関数ではデータの総数を返す必要があります。
__init__()でファイルパスのリストを作っているので、この総数を返すようにすれば完了です。
3.3 getitem
最後は__getitem__()関数で、indexに対応したデータを返す必要があります。
今回は__img_from_path()
関数でその動作を定義しています。コードを確認してみましょう。
def __img_from_path(self,index):
img_path = self.img_path_list[index]
target = self.target_list[index]
img = np.load(img_path)
return img, target
具体的な動作は以下のようになっています。
- indexに対応するファイルパス
img_path
を取得 - 正解ラベルの配列から対応する値
'dog'
or'cat'
を取得 - ファイルパス
img_path
から画像img
を取得 -
return img, target
で画像と正解ラベルを返す
3.4 定義
最後に、target配列を指定してDatasetを定義すれば完了です。
・Datasetの定義
target_list = ['dog','dog','cat','dog','cat','cat','cat','dog','dog','cat']
Sample_Dataset = Train_Dataset(target_list)
これを利用することで、効率的に機械学習モデルを構築することができます。
まとめ
今回はPyTrochのDatasetについて解説しました。
次回はDataLoaderを解説する予定です。
Discussion