torchvisionのImageFolderで動的にクラスを指定してデータセットを作成する
要約
is_valid_fileオプションとallow_emptyオプションを組み合わせる。
内容
ImageFolder
torchvisionで提供される、画像データを読み込むのに便利なクラス。画像データが存在するルートフォルダのパスを与えればデータセットを生成してくれるほか、クラスごとにサブフォルダを分けておけば自動でクラスラベルを付与してくれる。
クラス増分学習を実行する場合など、クラスラベルを指定してデータセットを構成したい場合がある。いちいちデータの追加・削除を行うのは面倒なので、読み込み時に動的にクラスを指定したい。
is_valid_file / allow_emptyオプション
この場合、is_valid_fileオプションを使う。こんなオプション名だが、bool変数ではなくcallableなオブジェクトを渡すことで、画像データのパスを参照してそのデータをデータセットに取り入れるかどうかを判定できる。本来は画像データの拡張子などをチェックする用途のオプションだと思われるが、クラスの指定にも応用できる。
クラスごとにサブフォルダに分けていれば、データのパスを参照してクラスを把握できるため、目的のクラスかどうかをis_valid_fileオプションを介して判定してやればよい。callableなオブジェクトであればなんでもよいが、一例としてClassFilterクラスを以下に示す。
from pathlib import Path
class ClassFilter:
def __init__(self, OK_class_list=None):
self.OK_class_list = OK_class_list
def __call__(self, path):
p = Path(path)
if self.OK_class_list is None:
return True
dir_name = str(p.parent.name)
if dir_name in self.OK_class_list:
return True
else:
return False
OK_class_listは取り入れるクラス名のリストである。ClassFilterはpathの親フォルダの名前を参照し、OK_class_listに含まれていればTrueを、そうでなければFalseを返す。
上記のClassFilterをis_valid_fileオプションに渡せば特定のクラスのみを読み込めるが、ImageFolderはデフォルトでは空のサブフォルダを許さないためエラーが出る。エラーを回避するためには同じくImageFolderのallow_emptyオブションをTrueにする。これで完成。
注意
なお、特定のクラスのみを読み込んだ場合でも、クラスラベルの値はそのままである。つまり、3クラスのデータを全て読み込んだ場合、各クラスのクラスラベルは0,1,2と割り振られるが、1番目と3番目のクラスのみを読み込んだ場合のクラスラベルは0,2となる。線形分類器の出力次元数を設定する場合などに注意。