ファイルの変更を検知できるキャッシュ機構つきDatasetクラスの実装
概要
PyTorchを利用して深層学習モデルの訓練を行う上で、torch.utils.data.Dataset
クラスは欠かせない。
特に、データセット特有の前処理が存在する場合、それを事前に行っておくか、torch.utils.data.Dataset
を継承した自前データセットクラスの__init__()
で行うことが多い(自分の経験上)。
ここで、前処理を毎回行っても時間の無駄なので、キャッシュできるところ保存しておいて、次の学習に使いまわしたいという気持ちが発生する。
ところが、雑にキャッシュ処理を実装すると、例えばデータセットの内容が変更されてしまっているのにそれに気づかずに保存しておいた前のバージョンのデータセットを使用してしまったり、前処理自体の変更(前処理スクリプトの変更) に気づかなかったりということがありうる。
そこで、md5を用いてデータセットの中身から計算したハッシュ値と、前処理スクリプト自体のハッシュ値を計算し、ハッシュ値が同じ時のみ、保存しておいたキャッシュを利用するというスクリプトを書いた。
md5はハッシュ関数の一つである。
暗号学的ハッシュ関数としての強度は不足しているが、今回のように中身の変更を検知するような用途であれば問題なく利用できると考えられる。
あと、Pythonからでも簡単に利用することができて便利。
ファイルの変更日時を用いてファイルの変更を検知する方法も考えられたが、エディタでファイルの保存をしてしまうとファイルの変更日時が変わってしまい、中身は変わっていないのにキャッシュのハッシュ値が変わってしまう、という問題があった。
エディタで諸々のファイルを開くことはよくあるので、ファイルの保存をしないように意識しなければいけないのは面倒だと思い、md5でファイルの同一性を確認することにした。
実装
早速以下に実装を示す。
import torch
import hashlib
from pathlib import Path
from typing import Union
class MyDataset(torch.utils.data.Dataset):
def __init__(
self,
data_path: Union[Path, str],
):
with Path(data_path).open("rb") as f: # calc md5 hash of the dataset
data_hash = hashlib.md5(f.read()).hexdigest()
with Path(__file__).open("rb") as f: # calc md5 hash of this script itself
script_hash = hashlib.md5(f.read()).hexdigest()
md5 = hashlib.md5(f"{data_hash}-{script_hash}".encode("utf-8")).hexdigest()
cache_dir = Path.home() / ".cache/my-dataset/dataset-name"
cache_path = cache_dir / f"{md5}.pt"
if cache_path.exists():
self.dataset = torch.load(cache_path)
else:
with Path(data_path).open() as f:
# preprocess()は前処理用の適当な関数
self.dataset = preprocess(f.readlines())
cache_path.parent.mkdir(exist_ok=True, parents=True)
torch.save(self.dataset, cache_path)
def __len__(self):
return len(self.dataset)
def __getitem__(self, key: Union[int, slice]):
return torch.LongTensor(self.dataset[key])
流れを説明すると
- 生のデータセット自体のファイルの、中身のバイト列を読み込み、md5ハッシュ値を計算する
- このスクリプト自体のファイルの、中身のバイト列を読み込み、md5ハッシュ値を計算する
- md5ハッシュ値を結合してもう一回md5ハッシュ値を計算する
- md5ハッシュ値をファイル名に持つキャッシュが事前に指定したキャッシュ用ディレクトリに存在していればそれを読み込み、存在しなければ前処理を行なってできたデータセットオブジェクトを保存する
キャッシュ保存用のディレクトリとして~/.cache
を利用している。これは、huggingface/transformersなども同様なので、とりあえずこうしておいた。自分の好きなディレクトリに置くようにして欲しい。
Pythonスクリプトから、そのPythonスクリプト自体のパスを取得するのは__file__
変数へのアクセスのみで実現できる。
他に前処理スクリプトがある場合はハッシュの計算に含める必要があるが、とりあえず全部このファイルにまとめるとミスらなくて便利だと思う。
想定するデータはテキストとそのID列(なのでtorch.LongTensor
を使っている)だが、適当に書き換えればどうにかなるはず。
キャッシュ(というかデータ自体)の保存はtorch.save
が便利なのでこれを利用した。
自前でopen()
してpickle
するなどしてもいいが、この方法だとワンライナーで書けて便利。
huggingface/datasetsとか使うとここらへんよしなにやってくれる気もするが、とりあえず自前でやれて損はないのでやってみた。
参考になれば幸い。
余談 1
「エディタでファイルの保存をしてしまうとファイルの変更日時が変わってしまい、中身は変わっていないのにキャッシュのハッシュ値が変わってしまう」と書いた。
ただ、「前処理スクリプトを強い気持ちで開かない」 or 「学習は流しっぱなしにするのでそもそも開かない」という人は、ファイルの変更日時をもとにキャッシュIDを計算してもよいと思う。
その場合、ファイル全体を読み込まなくても良いので、時間・メモリ効率ともに非常に良くなる。
一応その場合の処理も書いたので気になったらチェックしてみて欲しい。
ファイルの変更日時取得には、便利なのでpathlib.Path
クラスを使っている。
import torch
import hashlib
from pathlib import Path
from typing import Union
class MyDataset(torch.utils.data.Dataset):
def __init__(
self,
data_path: Union[Path, str],
):
data_hash = Path(data_path).stat().st_mtime_ns
script_hash = Path(__file__).stat().st_mtime_ns
md5 = hashlib.md5(f"{data_hash}-{script_hash}".encode("utf-8")).hexdigest()
cache_dir = Path.home() / ".cache/my-dataset/dataset-name"
cache_path = cache_dir / f"{md5}.pt"
if cache_path.exists():
self.dataset = torch.load(cache_path)
else:
with Path(data_path).open() as f:
self.dataset = preprocess(f.readlines())
cache_path.parent.mkdir(exist_ok=True, parents=True)
torch.save(self.dataset, cache_path)
def __len__(self):
return len(self.dataset)
def __getitem__(self, key: Union[int, slice]):
return torch.LongTensor(self.dataset[key])
余談 2
実は、現在自分がファイルのキャッシュ用ハッシュ値計算に使用しているものは、md5とファイルの変更日時の両方を用いたものである。
データセットの方はファイルを変更することもそんなにないかと思い、データセットの方のみファイルの変更日時で、前処理スクリプトはmd5で同一性を確認している。
データセットを読み込む時間が短縮できてよさ。
import torch
import hashlib
from pathlib import Path
from typing import Union
class MyDataset(torch.utils.data.Dataset):
def __init__(
self,
data_path: Union[Path, str],
):
data_hash = Path(data_path).stat().st_mtime_ns
with Path(__file__).open("rb") as f:
script_hash = hashlib.md5(f.read()).hexdigest()
md5 = hashlib.md5(f"{data_hash}-{script_hash}".encode("utf-8")).hexdigest()
cache_dir = Path.home() / ".cache/my-dataset/dataset-name"
cache_path = cache_dir / f"{md5}.pt"
if cache_path.exists():
self.dataset = torch.load(cache_path)
else:
with Path(data_path).open() as f:
self.dataset = preprocess(f.readlines())
cache_path.parent.mkdir(exist_ok=True, parents=True)
torch.save(self.dataset, cache_path)
def __len__(self):
return len(self.dataset)
def __getitem__(self, key: Union[int, slice]):
return torch.LongTensor(self.dataset[key])
Discussion