🪸

【PyTorch】DataLoader解説

2024/04/04に公開

今回はPyTorchのDataLoaderを解説します。

1. DataLoaderとは

DataLoaderはPyTorchの機械学習モデルにデータを供給する枠組みです。
バッチ処理やデータのシャッフル、並列でのロード等の機能を持ちます。

2. 定義

以下のように定義されます。

import torch

torch.utils.data.DataLoader(dataset, batch_size, num_workers, shuffle, drop_last)

Datasetを用意できれば、Dataloaderは呼び出すだけです。
それぞれの引数について解説します。

・DataLoaderの引数

  • dataset
    PyTorchのDatasetオブジェクト
  • batch_size
    バッチサイズ。一度に供給されるデータの数
  • num_workers
    データをロード及びモデルへ供給するサブプロセスの数。多いほどデータ転送が高速になるがメモリ使用量が増加する。デフォルトは0でその場合メインプロセスのみでデータのロードと供給を行う。CPU性能に合わせて調整。
  • shuffle
    Trueでデータをシャッフルしてモデルに供給
  • drop_last
    Trueで最後のバッチのデータ数がバッチサイズに満たない場合、そのバッチを使用しない
DataLoaderの引数(全て)

dataset: PyTorchのDatasetオブジェクト
batch_size: バッチサイズ。一度に供給されるデータの数
shuffle: データをシャッフルしてモデルに供給するかどうか
sampler: データセットからデータをサンプリングするための戦略を定義。shuffle=Trueの場合は使用不可。
batch_sampler: バッチごとのサンプリング戦略を定義。これが指定されている場合、batch_size、shuffle、sampler、drop_lastは無視される。
num_workers: データローディングに使用するサブプロセスの数。デフォルトは0。
collate_fn: データをバッチにまとめる際に追加する前処理。
pin_memory: データをCUDAメモリにロードするかどうか。デフォルトはFalse。
drop_last: Trueで最後のバッチのデータ数がバッチサイズに満たない場合、そのバッチを切り捨てる
timeout: データローディングのタイムアウト値(秒)。デフォルトは0でタイムアウト無効。
worker_init_fn: 各サブプロセスを開始する際に呼び出される関数。
prefetch_factor: 各プロセスが、データロード処理中に次に処理するデータを前もって読み込む数。デフォルトは2で、2バッチ分のデータを各プロセスは前もって読み込む。
persistent_workers: 各プロセスのイテレーション毎にプロセスを終了するかどうか。データロードの速度は向上するが、学習メモリ使用量やメモリガベージ増加などの要因を孕むためデフォルトはFalse。

3. 使い方

・DataLoader

import torch

# DataLoaderの定義
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, num_workers=4, shuffle=True, drop_last=True)

・例
DataLoaderは、Datasetで指定したデータを特定のバッチサイズでモデルに供給します。

# モデルの学習
for epoch in range(4):
    model.train() # モデルを学習モードに変更
    for batch in train_loader: # DataLoaderからバッチサイズのデータを取得
        batch = to_device(batch, device) # データを計算を行うデバイス(GPU)に移動
        x, t = batch # バッチサイズの学習用データと教師データを取得
            
        optimizer.zero_grad() # モデルの勾配をリセット
        with amp.autocast(): # 混合浮動小数点を使用
            y = model(x) # モデルの順伝播
            loss = loss_func(y, t) # 損失を計算
        scaler.scale(loss).backward() # 混合浮動小数点による勾配の消失を防ぐために、GradScalerを使用して逆伝播。勾配を計算
        scaler.step(optimizer) # パラメータを更新
        scaler.update() # GradScalerを更新
        scheduler.step() # 学習率を更新。optimizerを紐づけられている
        train_loss += loss.item() # 全エポックの合計損失を計算
        
    train_loss /= len(train_loader) # 全エポックの平均損失を計算

DataLoaderの役割はデータと教師データをバッチサイズで供給することです。
DataLoaderはPyTorchにおけるモデル学習のパイプラインの中で、データの供給に関する部分を一手に担ってくれており、これによりモデルの学習を簡潔なコードで記述することができます

学習では他にも、勾配や損失の計算、パラメータの更新や学習率の調整などを行う必要があります。
そのプロセスの最初にDataLoaderは使用されているのです。

まとめ

今回はDataLoaderについて解説しました。
Datasetのデータをモデルに供給する役割を持ち、主にモデル学習時にパイプラインの一部として組み込まれます。

Discussion