PytorchのDatasetを標準化する
PyTorchのデータセット(Dataset)を標準化する方法を記します.なお,本記事はある程度PyTorchの使い方に慣れている人向けに書かれています.
標準化とは
標準化とは,特徴量である数値を,平均0,標準偏差1になるように変換することです.PyTorchで実装するような深層学習に限った話ではなく,機械学習全般の一般的な話ですね.標準化自体の説明は他の記事に譲るとして,本記事ではPyTorchでの実装方法について書いていきます.
標準化についての説明は,以下の記事が参考になるかと思います.
PyTorchでの標準化のモチベーション
機械学習において標準化は,すごくポピュラーな前処理です.そのため,メジャーな機械学習ライブラリである scikit-learn では当たり前のように実装されています.Pipelineという概念を使って,簡単に実装や,機械学習モデルを作成することができます.
しかし,PyTorchでは標準の機能で手軽に実現することはできません.これはそもそもPyTorchが深層学習に特化したライブラリであることや,Pipelineのような概念が存在しないことが一因として考えられます.PyTorchでは,DatasetやDataLoaderというクラスでデータを取り出しながら学習を行いますが,標準のクラスには標準化は対応していないため自分で実装する必要があります.標準化は,深層学習に光が当てられる以前の,人手で設計されたトラディショナルな特徴量に対して行うものです.そのため,深層学習ではそこまで重要視されていないのかもしれません.
とはいえ,PyTorchでも標準化したいという需要はあると筆者は考えています(筆者自身にはありました).トラディショナルな特徴量を使って,ちょっと凝ったニューラルネットワークを学習したいことはありますからね.scikit-learnのニューラルネットワークのクラス(MLPClassifier)には,PyTorchのような柔軟性はありません.具体的には,損失関数への工夫・ドロップアウト・Batch Normalization等の工夫を入れる余地がないのです.そのため,PyTorchを使って実装したくなることがあります.[1]
実装: StandardScalerSubset
StandardScalerSubsetと名付けたクラスを実装します.以下の通りです.
class StandardScalerSubset(Subset):
def __init__(self, dataset, indices,
mean=None, std=None, eps=10**-9):
super().__init__(dataset=dataset, indices=indices)
target_tensor = torch.stack([dataset[i][0] for i in indices])
target_tensor = target_tensor.to(torch.float64)
if mean is None:
self._mean = torch.mean(target_tensor, dim=0)
else:
self._mean = mean
if std is None:
self._std = torch.std(target_tensor, dim=0, unbiased=False)
else:
self._std = std
self._eps = eps
self.std.apply_(lambda x: max(x, self.eps)) # ゼロ割対策
def __getitem__(self, idx):
dataset_list = list(self.dataset[self.indices[idx]])
input = dataset_list[0]
dataset_list[0] = (input - self.mean) / self.std
return tuple(dataset_list)
@property
def mean(self):
return self._mean
@property
def std(self):
return self._std
@property
def eps(self):
return self._eps
この記事では,Subsetのような形式で利用できるように実装しました.Datasetインスタンスと,標準化の対象を指定するインデックスのリストを引数としてインスタンスを生成します.対象となったデータ間で平均と標準偏差を算出して保持しておき,データにアクセスがあったときに標準化します.meanとstdを自分で設定すれば,平均と標準偏差それぞれで自動算出されず,指定した値を用います.詳しくはデモンストレーションをご覧ください.
以下,実装に関する補足説明です.
__init__メソッドの9行目にあるPyTorchのstd関数は,デフォルトでは不偏標準偏差を返します.ここではunbiased引数をFalseにして,標本標準偏差を算出しています.筆者としてはどちらにするべきか悩みましたが,よりわかりやすいと考えたのでこのように実装しました.不偏標準偏差のほうが良い場合は,unbiasedをTrueに書き換えてください.
__getitem__メソッド内で,datasetから取得した要素をlistにしているのは,筆者の別記事のDatasetWithIndexを考えてのことです.通常のDatasetでは,データとラベルの2つの要素が返ってきますが,DatasetWithIndexはデータ・ラベル・インデクスの3つが返ってきます.ここではより柔軟に3個以上の要素が返ってくる場合でも大丈夫なように実装しています.DatasetWithIndexクラスについては以下の記事をご覧ください.
コードのダウンロード
本記事のコードは以下のページ(Github)からダウンロードできます.なお,MITライセンスで公開しております.改変・公開等ご自由にお使いください.
以降で紹介するデモンストレーションの実行コマンドは以下の通りです.
python StandardScalerSubset.py
Python 3.8.5と,以下のライブラリで動作確認をしています.
torch 1.7.1
torchvision 0.8.2
デモンストレーション
それでは,StandardScalerSubsetの使い方を示すために,デモンストレーションを行います.
以下のコードで実行できます.
class MyDataset(torch.utils.data.Dataset):
def __init__(self):
self.data = torch.tensor([[10, 100, 1000],
[20, 50, 1500],
[30, 150, 2500],
[15, 175, 1300]])
self.labels = tuple([1, 0, 1, 1])
def __len__(self):
return len(self.labels)
def __getitem__(self, idx):
return self.data[idx], self.labels[idx]
dataset = MyDataset()
# 0,1,2を訓練用データとする.
train_sss = StandardScalerSubset(dataset, [0, 1, 2])
# テストデータは,訓練時の平均と標準偏差で標準化を行う.
test_sss = StandardScalerSubset(dataset, [3],
mean=train_sss.mean, std=train_sss.std)
print("Training data")
for i in range(len(train_sss)):
print(train_sss[i])
print()
print("Test data")
print(test_sss[0])
実行結果は以下のとおりです.
Training data
(tensor([-1.2247, 0.0000, -1.0690], dtype=torch.float64), 1)
(tensor([ 0.0000, -1.2247, -0.2673], dtype=torch.float64), 0)
(tensor([1.2247, 1.2247, 1.3363], dtype=torch.float64), 1)
Test data
(tensor([-0.6124, 1.8371, -0.5880], dtype=torch.float64), 1)
注意は以下の点です.
# テストデータは,訓練時の平均と標準偏差で標準化を行う.
test_sss = StandardScalerSubset(dataset, [3],
mean=train_sss.mean, std=train_sss.std)
テストデータの標準化は,訓練時の平均と標準偏差を使う必要があります.そのため,訓練データの平均(mean)と標準偏差(std)にアクセスして,それらをテストデータのStandardScalerSubsetを生成するときに引数として渡しています.
おわりに
PyTorchで前処理の標準化を実現する,StandardScalerSubsetクラスを実装して紹介しました.画像・自然言語・音声等以外のデータで,トラディショナルな特徴量のあるデータをニューラルネットワークで扱うときに役立つかと思います.
-
悪口っぽく書いてしまいましたが,筆者はscikit-learnも好きです.シンプルに使えていいですね. ↩︎
Discussion