🎃

PyTorchのDatasetからインデックスも取得する

4 min read

PyTorchでの深層学習のコードで,Datasetからデータとラベルだけでなく,インデックスも取得する方法を記します.なお,本記事はある程度PyTorchの使い方に慣れている人向けに書かれています.

DatasetWithIndex クラスの実装

筆者が調べた限りでは,標準のPyTorchの機能だけでは実現できなかったので,以下のようなDatasetのラッパーを実装します.

DatasetWithIndex.py
class DatasetWithIndex:
    def __init__(self, dataset):
        self.dataset = dataset

    def __getitem__(self, index):
        data, label = self.dataset[index]
        return data, label, index

    def __len__(self):
        return len(self.dataset)

    @property
    def classes(self):
        return self.dataset.classes

コードのダウンロード

本記事のコードは以下のページ(Github)からダウンロードできます.なお,MITライセンスで公開しております.改変・公開等ご自由にお使いください.

https://github.com/HidetoshiKawaguchi/tech-blog-codes/tree/main/20210619_pytorch-dataset-with-index

以降で紹介するデモンストレーションの実行コマンドは以下の通りです.デモンストレーションのためにカレントディレクトリにMNISTがダウンロードされるのでご注意ください.

python DatasetWithIndex.py

Python 3.8.5と,以下のライブラリで動作確認をしています.

torch             1.7.1
torchvision       0.8.2

デモンストレーション1

使い方は,以下のように,任意のDatasetクラスのインスタンスを引数としてDatasetWithIndexクラスのインスタンスを生成します.(7行目)
あとは普通のDatasetと同じように使えます.

DatasetWithOriginIndex.py
from torch.utils.data import DataLoader
from torchvision import transforms as tt
from torchvision.datasets import MNIST

dataset = MNIST(root='./', train=True, download=True,
                transform=tt.Compose([tt.ToTensor()]))
dataset_with_index = DatasetWithIndex(dataset) # ★データセットをラップしている
data_loader = DataLoader(dataset_with_index, batch_size=4, shuffle=True)

# デモンストレーション1
## 一部データを取得し,あとで取得したインデックスで同じデータにアクセスできるか調べる.
input_list, label_list, index_list = [], [], []
for i, data in enumerate(data_loader):
    inputs, labels, indices = data
    input_list.extend(inputs)
    label_list.extend(labels)
    index_list.extend(indices)
    if i >= 3:
        break
for input, label, index in zip(input_list, label_list, index_list):
    data = dataset_with_index[index]
    # indexの辻褄があっているかを確認
    assert (input == data[0]).all()
    assert data[1] == label
    print("label1 = {}, label2= {}".format(data[1], label))
print("len(dataset_with_index) = {}".format(len(dataset_with_index)))
print("dataset_with_index.classes = {}".format(dataset_with_index.classes))

実行結果は以下の様になっています.1つ目のループで,MNISTの画像, ラベル, インデックスを取得しています.2つ目のループで,1つ目のループで取得したインデックスを使って,Datasetにアクセスし,同じ画像とラベルを取得できるかをチェックしています(コード下から4,5行目のassert).

label1 = 8, label2 = 8
label1 = 9, label2 = 9
label1 = 5, label2 = 5
label1 = 4, label2 = 4
label1 = 6, label2 = 6
label1 = 6, label2 = 6
label1 = 9, label2 = 9
label1 = 2, label2 = 2
label1 = 2, label2 = 2
label1 = 8, label2 = 8
label1 = 1, label2 = 1
label1 = 1, label2 = 1
label1 = 6, label2 = 6
label1 = 6, label2 = 6
label1 = 8, label2 = 8
label1 = 3, label2 = 3
len(dataset_with_index) = 60000
dataset_with_index.classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four', '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']

見やすさのためにラベルだけ表示しています.また,表示結果は乱数により毎回異なります.

デモンストレーション2: Subsetとの組み合わせ

Subsetと組み合わせることも出来ます.
ラップする順番によって,Subset上のインデックスを取得するか,オリジナルのインデックスを取得するかが違います.

DatasetWithOriginIndex.py
# デモンストレーション2: Subset
from torch.utils.data import Subset

## Subset上のインデックスを取得する
## SubsetをDatasetWithIndexでラップする
subset1 = Subset(dataset, indices=[2, 1, 3, 5, 4])
subset_with_index = DatasetWithIndex(subset1)
print('index on a subset = {}'.format(subset_with_index[0][2]))

## 元のデータセットのインデックスを取得する
## DatasetWithIndexをSubsetでラップする
subset_with_raw_index = Subset(dataset_with_index, [2, 1, 3, 5, 4])
print('index on a raw dataset = {}'.format(subset_with_raw_index[0][2]))

実行結果は以下のようになります.いずれも同じデータ,ラベルにアクセスしていますが,返ってきているインデックスが異なっています.少々ややこしいかもしれません.

index on a subset = 0
label = 4
index on a raw dataset = 2
label = 4

おわりに

データとラベルだけでなく,インデックスも取得可能なDatasetのラッパーであるDatasetWithIndexクラスを実装して紹介しました.
オリジナルのアルゴリズムを実装したいときに使えるかと思います.

Discussion

ログインするとコメントできます