😀

PyTorchのDataLoaderでBrokenPipeErrorを回避する方法メモ

2021/10/16に公開

PyTorchのDataLoaderは,マルチプロセスでデータロードが出来る仕組みが備わっている.Windowsでそれを利用しようとしたところ,こちらと同じようなエラーを吐いて動作しなかった.そこで色々調べて解決を行ったため,その方法をメモする.

  • 動作環境
    • OS: Windows10 Home(バージョン 2004)
    • Python: 3.7.3
    • PyTorch: 1.3.1

DataLoaderのマルチプロセス

公式のDocsから引用すると

A DataLoader uses single-process data loading by default.

Within a Python process, the Global Interpreter Lock (GIL) prevents true fully parallelizing Python code across threads. To avoid blocking computation code with data loading, PyTorch provides an easy switch to perform multi-process data loading by simply setting the argument num_workers to a positive integer.

とのこと.非常にざっくり言うと,DataLoaderクラスのnum_workersという変数の値を1以上にすれば,データの読み込みを並列化できる,ということである.

BrokenPipeError

そこで,num_workersを1以上の値に設定して動かしたところ,

BrokenPipeError: [Errno 32] Broken pipe

とエラーを吐いて動作しなかった.
PytorchのDatasetをDataLoaderで並列に読み込みたいときのエラー(Windows)を参考に,Datasetの定義を別のファイルで行っても同様のエラーが発生した.

解決法

こちらを参考にすると,どうやらWindowsでマルチプロセスを行う場合,if __name__ == "__main__"内でマルチプロセスを行う関数を実行しなければならないとのこと.

修正前

train.py
from torch.utils.data import DataLoader
from dataloader import MyDataset #作成したdataset

def train():
    dataset = MyDataset()
    train_loader = DataLoader(dataset, num_workers=2, shuffle=True,
                              batch_size=4,
                              pin_memory=True,
                              drop_last=True)

    for batch in train_loader:
        #do some process...

if __name__ == "__main__":
    train()

修正後

train.py
from torch.utils.data import DataLoader
from dataloader import MyDataset #作成したdataset

def train(train_loader):
    for batch in train_loader:
        #do some process...

if __name__ == "__main__":
    #dataset, DataLoaderを移動
    dataset = MyDataset()
    train_loader = DataLoader(dataset, num_workers=2, shuffle=True,
                              batch_size=4,
                              pin_memory=True,
                              drop_last=True)

    train(train_loader)

DataLoaderの場合,インスタンスの生成をif __name__ == "__main__"内で行っていれば,データの読み込み自体は別の関数内で実行してもマルチプロセスは動作した.

まとめ

Windows環境でDataLoaderの並列化するためのメモを記した.深層学習周りでは,Windowsだと動作しないまたは少し工夫をしないと行けないといった作業が多く大変である.なので,Windows周りで発生するエラーは定期的に記事にしてまとめていきたい所存

GitHubで編集を提案

Discussion