【PyTorch】DataLoader解説
今回はPyTorchのDataLoaderを解説します。
1. DataLoaderとは
DataLoaderはPyTorchの機械学習モデルにデータを供給する枠組みです。
バッチ処理やデータのシャッフル、並列でのロード等の機能を持ちます。
2. 定義
以下のように定義されます。
import torch
torch.utils.data.DataLoader(dataset, batch_size, num_workers, shuffle, drop_last)
Datasetを用意できれば、Dataloaderは呼び出すだけです。
それぞれの引数について解説します。
・DataLoaderの引数
- dataset
PyTorchのDatasetオブジェクト - batch_size
バッチサイズ。一度に供給されるデータの数 - num_workers
データをロード及びモデルへ供給するサブプロセスの数。多いほどデータ転送が高速になるがメモリ使用量が増加する。デフォルトは0でその場合メインプロセスのみでデータのロードと供給を行う。CPU性能に合わせて調整。 - shuffle
Trueでデータをシャッフルしてモデルに供給 - drop_last
Trueで最後のバッチのデータ数がバッチサイズに満たない場合、そのバッチを使用しない
DataLoaderの引数(全て)
dataset: PyTorchのDatasetオブジェクト
batch_size: バッチサイズ。一度に供給されるデータの数
shuffle: データをシャッフルしてモデルに供給するかどうか
sampler: データセットからデータをサンプリングするための戦略を定義。shuffle=Trueの場合は使用不可。
batch_sampler: バッチごとのサンプリング戦略を定義。これが指定されている場合、batch_size、shuffle、sampler、drop_lastは無視される。
num_workers: データローディングに使用するサブプロセスの数。デフォルトは0。
collate_fn: データをバッチにまとめる際に追加する前処理。
pin_memory: データをCUDAメモリにロードするかどうか。デフォルトはFalse。
drop_last: Trueで最後のバッチのデータ数がバッチサイズに満たない場合、そのバッチを切り捨てる
timeout: データローディングのタイムアウト値(秒)。デフォルトは0でタイムアウト無効。
worker_init_fn: 各サブプロセスを開始する際に呼び出される関数。
prefetch_factor: 各プロセスが、データロード処理中に次に処理するデータを前もって読み込む数。デフォルトは2で、2バッチ分のデータを各プロセスは前もって読み込む。
persistent_workers: 各プロセスのイテレーション毎にプロセスを終了するかどうか。データロードの速度は向上するが、学習メモリ使用量やメモリガベージ増加などの要因を孕むためデフォルトはFalse。
3. 使い方
・DataLoader
import torch
# DataLoaderの定義
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, num_workers=4, shuffle=True, drop_last=True)
・例
DataLoaderは、Datasetで指定したデータを特定のバッチサイズでモデルに供給します。
# モデルの学習
for epoch in range(4):
model.train() # モデルを学習モードに変更
for batch in train_loader: # DataLoaderからバッチサイズのデータを取得
batch = to_device(batch, device) # データを計算を行うデバイス(GPU)に移動
x, t = batch # バッチサイズの学習用データと教師データを取得
optimizer.zero_grad() # モデルの勾配をリセット
with amp.autocast(): # 混合浮動小数点を使用
y = model(x) # モデルの順伝播
loss = loss_func(y, t) # 損失を計算
scaler.scale(loss).backward() # 混合浮動小数点による勾配の消失を防ぐために、GradScalerを使用して逆伝播。勾配を計算
scaler.step(optimizer) # パラメータを更新
scaler.update() # GradScalerを更新
scheduler.step() # 学習率を更新。optimizerを紐づけられている
train_loss += loss.item() # 全エポックの合計損失を計算
train_loss /= len(train_loader) # 全エポックの平均損失を計算
DataLoaderの役割はデータと教師データをバッチサイズで供給することです。
DataLoaderはPyTorchにおけるモデル学習のパイプラインの中で、データの供給に関する部分を一手に担ってくれており、これによりモデルの学習を簡潔なコードで記述することができます
学習では他にも、勾配や損失の計算、パラメータの更新や学習率の調整などを行う必要があります。
そのプロセスの最初にDataLoaderは使用されているのです。
まとめ
今回はDataLoaderについて解説しました。
Datasetのデータをモデルに供給する役割を持ち、主にモデル学習時にパイプラインの一部として組み込まれます。
Discussion