PyTorchのDatasetからインデックスも取得する
PyTorchでの深層学習のコードで,Datasetからデータとラベルだけでなく,インデックスも取得する方法を記します.なお,本記事はある程度PyTorchの使い方に慣れている人向けに書かれています.
DatasetWithIndex クラスの実装
筆者が調べた限りでは,標準のPyTorchの機能だけでは実現できなかったので,以下のようなDatasetのラッパーを実装します.
class DatasetWithIndex:
def __init__(self, dataset):
self.dataset = dataset
def __getitem__(self, index):
data, label = self.dataset[index]
return data, label, index
def __len__(self):
return len(self.dataset)
@property
def classes(self):
return self.dataset.classes
コードのダウンロード
本記事のコードは以下のページ(Github)からダウンロードできます.なお,MITライセンスで公開しております.改変・公開等ご自由にお使いください.
以降で紹介するデモンストレーションの実行コマンドは以下の通りです.デモンストレーションのためにカレントディレクトリにMNISTがダウンロードされるのでご注意ください.
python DatasetWithIndex.py
Python 3.8.5と,以下のライブラリで動作確認をしています.
torch 1.7.1
torchvision 0.8.2
デモンストレーション1
使い方は,以下のように,任意のDatasetクラスのインスタンスを引数としてDatasetWithIndexクラスのインスタンスを生成します.(7行目)
あとは普通のDatasetと同じように使えます.
from torch.utils.data import DataLoader
from torchvision import transforms as tt
from torchvision.datasets import MNIST
dataset = MNIST(root='./', train=True, download=True,
transform=tt.Compose([tt.ToTensor()]))
dataset_with_index = DatasetWithIndex(dataset) # ★データセットをラップしている
data_loader = DataLoader(dataset_with_index, batch_size=4, shuffle=True)
# デモンストレーション1
## 一部データを取得し,あとで取得したインデックスで同じデータにアクセスできるか調べる.
input_list, label_list, index_list = [], [], []
for i, data in enumerate(data_loader):
inputs, labels, indices = data
input_list.extend(inputs)
label_list.extend(labels)
index_list.extend(indices)
if i >= 3:
break
for input, label, index in zip(input_list, label_list, index_list):
data = dataset_with_index[index]
# indexの辻褄があっているかを確認
assert (input == data[0]).all()
assert data[1] == label
print("label1 = {}, label2= {}".format(data[1], label))
print("len(dataset_with_index) = {}".format(len(dataset_with_index)))
print("dataset_with_index.classes = {}".format(dataset_with_index.classes))
実行結果は以下の様になっています.1つ目のループで,MNISTの画像, ラベル, インデックスを取得しています.2つ目のループで,1つ目のループで取得したインデックスを使って,Datasetにアクセスし,同じ画像とラベルを取得できるかをチェックしています(コード下から4,5行目のassert).
label1 = 8, label2 = 8
label1 = 9, label2 = 9
label1 = 5, label2 = 5
label1 = 4, label2 = 4
label1 = 6, label2 = 6
label1 = 6, label2 = 6
label1 = 9, label2 = 9
label1 = 2, label2 = 2
label1 = 2, label2 = 2
label1 = 8, label2 = 8
label1 = 1, label2 = 1
label1 = 1, label2 = 1
label1 = 6, label2 = 6
label1 = 6, label2 = 6
label1 = 8, label2 = 8
label1 = 3, label2 = 3
len(dataset_with_index) = 60000
dataset_with_index.classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four', '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']
見やすさのためにラベルだけ表示しています.また,表示結果は乱数により毎回異なります.
デモンストレーション2: Subsetとの組み合わせ
Subsetと組み合わせることも出来ます.
ラップする順番によって,Subset上のインデックスを取得するか,オリジナルのインデックスを取得するかが違います.
# デモンストレーション2: Subset
from torch.utils.data import Subset
## Subset上のインデックスを取得する
## SubsetをDatasetWithIndexでラップする
subset1 = Subset(dataset, indices=[2, 1, 3, 5, 4])
subset_with_index = DatasetWithIndex(subset1)
print('index on a subset = {}'.format(subset_with_index[0][2]))
## 元のデータセットのインデックスを取得する
## DatasetWithIndexをSubsetでラップする
subset_with_raw_index = Subset(dataset_with_index, [2, 1, 3, 5, 4])
print('index on a raw dataset = {}'.format(subset_with_raw_index[0][2]))
実行結果は以下のようになります.いずれも同じデータ,ラベルにアクセスしていますが,返ってきているインデックスが異なっています.少々ややこしいかもしれません.
index on a subset = 0
label = 4
index on a raw dataset = 2
label = 4
おわりに
データとラベルだけでなく,インデックスも取得可能なDatasetのラッパーであるDatasetWithIndexクラスを実装して紹介しました.
オリジナルのアルゴリズムを実装したいときに使えるかと思います.
Discussion