💾

1ペタバイトのデータセットで機械学習する / WebDataset入門

深層学習をする上で、最も大切なマシンスペックを聞かれたら何と答えますか?
GPUのTensor性能、VRAM、GPUの数、CPU性能、メモリ、… 問題によって正解は異なりますね。

しかし、特に大規模なデータセットで機械学習する場合では、しばしばネットワーク帯域ストレージシステムのディスクI/Oによって制限されます。この記事ではそのような課題に対して、学習側でどのようにデータを扱うかを見ていきたいと思います。

1. この記事は?

こんにちは、TURING MLチームです。TURINGはEnd-to-Endな深層学習モデルでLv5完全自動運転車の開発を目指す会社です。

私たちは自動運転モデルを動かすため、可視域のカメラセンサによる画像で学習し、カメラ映像のみから車体の操作や経路選択、安全性の判断を行わせています。(実際の車を動かす事例はこちらの記事をご覧ください。)

そのため、機械学習のために大量の画像データが必要になってきます。TURINGでは2022年に500時間、2023年に50,000時間の公道上の走行データをカメラによって収集する計画を立てています。50,000時間、というとピンと来ないかと思いますが、仮に平均時速25kmだとすると、合計で125万kmになります。これは日本の道路総延長(=128万km) に匹敵する距離です。(もちろん、日本の全ての道路を走行するわけではなく、都市部を中心に同じ道をさまざまな角度・条件で撮影していくことになります。)

重要なことは、データが大量の動画という形で取得されるという点です。動画はエンコードされて圧縮されている状態でも~1GB/時間、複数カメラで機械学習用のテンソルに成形すると数十GB/時間程度にもなります。そのため、50,000時間の走行でおよそ1ペタバイト程度のデータとなります。

TURINGはこのような大規模な機械学習に向け、ストレージシステムの検討やI/Oパフォーマンスの測定、そして100基単位でのGPUによる並列分散学習のための技術検証を進めています。この記事では、開発段階の小規模なデータセットからペタバイトスケールの機械学習まで対応できるデータローダーの仕組みの一例として、PyTorchで学習するケースを紹介していきたいと思います。

2. 学習データをロードする5つのシナリオ

機械学習でデータセットを計算サーバに転送するには5つの方法があります。

(1) オンメモリ
(2) ローカルディスク
(3) Web/ファイルサーバー
(4) クラウドストレージ
(5) 高速分散ストレージシステム

全てのデータがメモリ上に収まるサイズのデータセットであれば、多くのケースで特別なデータハンドリングは必要ありません。一度データを読み出せば高速で処理が可能です。一方、画像分類等の大きな訓練用データセットは、しばしばメモリサイズを超過します。そのようなケースでは、ディスク上に配置されたファイルを逐次読み込む必要があります。

インターネット上に公開されているデータセットは、前もって全てダウンロードすればローカルディスクと同様に扱えますが、ネットワーク帯域によっては多くの時間が要求されます

また、Amazon S3など、クラウドストレージ上に保存されているデータを読み出す場合、ファイルシステムとしてマウントしたり、公式のS3 IO datapipesを使ってデータパイプラインを作成します。さらに大規模な環境では、LustreやGPFSなど高速な並列分散ファイルシステムが採用されることがありますが、このような場合でもやはりネットワーク帯域やファイルシステムのパフォーマンスに影響を受けます。

一方、どのようなストレージシステムを利用するにせよ、機械学習のデータセットへのアクセスは共通して下記のような特徴を持ちます。このようなデータアクセスは機械学習特有なもので、既存のファイルシステムとうまく適合しない可能性があります。

  • 均一にランダムなアクセスパターンを持つ
  • 多くの(ときには数十億の)ファイルで構成されている
  • 学習前にデータをシャッフルや前処理、データ拡張する必要がある

小さなデータセットで開発/テストをしてから、大きなデータセットにスケールアップさせていく場合、ストレージシステムを変更する状況がしばしば生じます。学習データのスケールにあわせ、システムごとにデータローダーの実装やパフォーマンスチューニングをするのは大変にコストがかかります。そのため、ファイル形式やロードのためのコードをできるだけ変更せずに対応したいわけですが、どうしたらよいでしょうか?

3. WebDatasetとは

PyTorchに対するWebDatasetライブラリは、このようなデータ読み込みの問題を解決し、(1)~(5)までの全てのシナリオに対してペタスケールまでの統一的で効率のよいアクセスを提供してくれます。

WebDataset is an ideal solution for training on petascale datasets kept on high performance distributed data stores like AIStore, AWS/S3, and Google Cloud.

WebDataset also is very useful for such smaller datasets, and it can easily be used for developing and testing on small datasets and then scaling up to large datasets by simply using more shards.

WebDatasetはBigDat2019で発表されたHigh Performance I/O For Large Scale Deep Learningで示された大規模な深層学習のためのデータセット機構とその実装です。WebDatasetは任意のストレージシステムにデータを数十~数百MBごとにシャーディング(分割)して配置し、シーケンシャルに読み込むことででストリーミングによるアクセスを可能にしています。

WebDatasetはPythonで書かれた外部依存性のない独立したライブラリとして開発されており、将来的にPyTorchのサブパッケージとして取り込まれるための提案がなされています(RCF 38419)。

同様のライブラリとしてTensorFlowのTFRecordがありますが、WebDatasetではPOSIX tarによるファイルベースを採用しており、シリアライズされた形式に変換する必要がないという特徴があります。

4. 基本的な使い方

4-1. インストール

$ pip install webdataset

または

$ pip install git+https://github.com/tmbdev/webdataset.git

4-2. データセット

ここでは文書画像のデータセットであるPubLayNetを用いて説明してきたと思います。PubLayNetは、ドキュメント画像の大規模なデータセットで、そのレイアウトには、境界ボックスと多角形のセグメンテーションの両方で注釈が付けられています。

Publaynetのサンプル画像
Publaynetの画像データ

まずデータセットの一部として、290MB程度のシャードファイルを適当な場所(ここでは/tmp)にダウンロードしておきます。

$ curl -L "http://storage.googleapis.com/nvdata-publaynet/publaynet-train-000000.tar" -o "/tmp/publaynet_000000.tar"

実態は普通のtarファイルで、中身は画像ファイル(png)メタデータ(json) のセットが985組アーカイブされています。

$ tar -tf /tmp/publaynet_000000.tar | head
PMC4991227_00003.json
PMC4991227_00003.png
PMC4537884_00002.json
PMC4537884_00002.png
PMC4323233_00003.json
PMC4323233_00003.png
PMC5429906_00004.json
PMC5429906_00004.png
PMC5592712_00002.json
PMC5592712_00002.png

$ tar -tf /tmp/publaynet_000000.tar | wc -l
1970

このように、WebDatasetでは特殊なデータ形式に変更することなくPOSIX tar形式でアーカイブされたファイルを読み出すことができます。データセットは標準のtarコマンドでも容易に作成することができます。

4-3. ローカルディスクからの読み出し

import torch
import webdataset as wds

url = "/tmp/publaynet_000000.tar"
# ローカルファイルパスをセットします.
dataset = wds.WebDataset(url)
# データパイプラインの定義(デコード、タプル化).
dataset = dataset.decode("rgb").to_tuple("png", "json")

webdataset.WebDatasetはPyTorch標準のIterableDatasetと同じインターフェースが実装されています。IterableDatasetなのでイテレータでデータを取り出すことができます。

print(isinstance(dataset, torch.utils.data.IterableDataset))  # True
# データを取得します.
image, json = next(iter(dataset))
print(image.shape, image.dtype, type(json))  # (794, 610, 3) float32 <class 'dict'>

データセットの前処理として、サンプリングされたデータに対し、任意の関数をmap()を使うことで適用できます。ここではjsonからcategory_idをラベルとして設定しておきます。

def preprocess(sample):
    image, json = sample
    try:
        label = json["annotations"][0]["category_id"]
    except Exception:
        label = 0
    return image, label
    
dataset = dataset.map(preprocess)

さらにデータ拡張として、画像データから、さらにランダムに256x256のイメージサイズに切り出す処理を入れてみます。compose()と任意のジェネレータを用いることで、データセット全体に処理することができます。

from random import randrange

def get_patches(source):
    for sample in source:
        image, label = sample
        # サンプリングされた画像のheight/widthを取得します.
        h, w = image.shape[:2]
        for _ in range(16):
            y, x = randrange(h - 256), randrange(w - 256)
            patch = image[y : y + 256, x : x + 256]
            yield (patch, label)

dataset = dataset.compose(get_patches)
dataset = dataset.shuffle(10000)  # バッファーサイズ=10000でシャッフルします.

データセットの準備としてはこれで終わりです。
最後にPyTorch標準のDataLoaderに渡して完了です。

dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, num_workers=4)

images, labels = next(iter(dataloader))
print(images.shape, labels.shape)  # torch.Size([64, 256, 256, 3]) torch.Size([64])

また、複数のファイルを読み出すには最初のURLをpublaynet_{000000...000009}.tarのように指定するだけでOKです。

import torch
import webdataset as wds

url = "/tmp/publaynet_{000000..000009}.tar"
dataset = wds.WebDataset(url)
# 以下同じ

4-4. Webサーバ

先ほどはtarファイルをローカルにダウンロードしてデータセットとしました。WebDatasetではWebサーバのURLを直接設定することが可能です。urlhttp://storage.googleapis.com/nvdata-publaynet/publaynet-train-{000000..000009}.tarに設定します。

import torch
import webdataset as wds

url = "http://storage.googleapis.com/nvdata-publaynet/publaynet-train-{000000..000009}.tar"
dataset = wds.WebDataset(url)
dataset = dataset.decode("rgb").to_tuple("png", "json")
dataset = dataset.map(preprocess)
dataset = dataset.compose(get_patches)
dataset = dataset.shuffle(10000)

dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, num_workers=4)

images, labels = next(iter(dataloader))
print(images.shape, labels.shape)  # torch.Size([64, 256, 256, 3]) torch.Size([64])

ローカルストレージの場合とURLを変えるだけで、他のコードを変更することなくWebサーバから連続的に訓練用のバッチを取得することができます。WebDatasetでは任意の前処理やデータ拡張を定義してデータパイプラインとしてストリーミングできるため、学習前に個別にファイルをダウンロードする必要はありません。

4-5. Amazon S3

最後にAmazon S3に置かれたファイルを取得するケースをみていきます。S3から学習データセットを利用する場合、3つの選択肢があります。

  • オブジェクトをバイトストリームとして直接ロードする
  • S3バケットをファイルシステムとしてマウントする
  • SageMakerパイプモードを(解析して)利用する

SageMakerパイプモードはAmazon SageMakerが提供するS3に保存されているデータをやり取りするための専用APIです。パイプモードを使用すると、データは専用のLinux FIFOパイプを介して高速にストリーミングされます。データバイナリを解析する必要があるため、ここでは前の二つの方法を見ていきたいと思います。

オブジェクトのストリーム[1]

Amazon S3の場合、ローカルストレージ/Webサーバと異なり、直接URLを指定することができません。そこでPythonのBoto3ライブラリを用い、S3のオブジェクトを直接(ストレージを介さずに)メモリのバイトストリームとして取り込みます。

import io
import re

import boto3

client = boto3.client(
    "s3",
    aws_access_key_id=os.environ["AWS_ACCESS_KEY_ID"],
    aws_secret_access_key=os.environ["AWS_SECRET_ACCESS_KEY"]
)

def get_stream(path):
    stream = io.BytesIO()
    _, bucket, key, _ = re.split("s3://(.*?)/(.*)$", path)
    client.download_fileobj(bucket, key, byte_io)
    stream.seek(0)
    return stream

さらにこれをWebDataset.tariterators.url_openerにオーバーライドすることでwebdataset.WebDatasetでS3のURLを直接取り込むことが可能になります。

def url_opener(data, handler=reraise_exeption, **kwd):
    for sample in data:
        url = sample["url"]
        try:
            stream = get_stream(url)
            sample.update(stream=stream)
            yield sample
        except Exception as e:
            e.args = e.args + (url,)
            if handler(e):
                continue
            else:
                break

# url_openerをオーバーライドします.
wds.tariterators.url_opener = url_opener

urls = [f"s3://<path of dataset>/publaynet_{i:06d}.tar" for i in range(10)]
dataset = wds.WebDataset(urls)

ファイルシステムのマウント

Amazon S3はS3FsgoofysなどFUSEベースのライブラリでマウントすることができます。またここでは扱いませんが、WebDataset以外でも、公式のS3 Pytorchプラグインをを使用することで、IterableDatasetのようにS3からデータをストリームすることが可能です。

# 注: はじめにstart methodをspawnにする必要があります.
torch.multiprocessing.set_start_method("spawn")

import s3fs

fs = s3fs.S3FileSystem(
    key=os.environ["AWS_ACCESS_KEY_ID"],
    secret=os.environ["AWS_SECRET_ACCESS_KEY"]
)

def url_opener(data, handler=reraise_exeption, **kwd):
    for sample in data:
        url = sample["url"]
        try:
            stream = fs.open(url.replace("s3://", ""), mode="rb")
	        sample.update(stream=stream)
        except Exception as e:
            e.args = e.args + (url,)
	    if handler(e):
	        continue
	    else:
	        break

# url_openerをオーバーライドします.
wds.tarietators.url_opener = url_opener

urls = [f"s3://<path of dataset>/publaynet_{i:06d}.tar" for i in range(10)]
dataset = wds.WebDataset(urls)

4-6. データセットの作成

WebDatasetで扱うのは単なるtarファイルなので、通常はtarコマンドを使用するだけで作成可能です

$ tar --sort=name -cf dataset.tar dataset/

またはGoで実装されたtarpをインストールし、tarp createコマンドでファイルからtarアーカイブファイルを生成するのが特に簡単になります。

# tarpのインストール
$ go get -v github.com/tmbdev/tarp/tarp

また、既存のデータセットに対して、webdataset.TarWriterを用いることでPython上で作成することも可能です。また、自分でデコーダーを定義することで任意のファイル形式にも対応させることができます。

sink = wds.TarWriter("dest.tar")
for index, (input, output) in dataset:
    sink.write({
        "__key__": "sample%06d" % index,
        "input.png": input,
        "output.cls": output,
    })
sink.close()

5. パフォーマンス

最後に、WebDatasetでデータを取得する場合のパフォーマンスを計測してみたいと思います。データセットはPubLaynetを用いて、1エポックあたりの平均読み込み速度を計測します。対象としたのは10GB相当のシャードファイルで約34500個分のイメージデータセットです。

  • ローカルディスクにpng/jsonとして展開したものをPyTorch標準のDatasetで読み込み
  • ローカルディスクのtarファイルをWebDatasetで読み込み
  • Webサーバーからダウンロードし、PyTorch標準のDatasetで読み込み
  • WebサーバーからWebDatasetでストリーミングして読み込み
  • Amazon S3からWebDatasetでバイトストリームで読み込み
  • Amazon S3からWebDatasetでFUSEマウントで読み込み
データ取得方法 1epochあたりのロード時間
ローカルディスク + Dataset 237 sec
ローカルディスク + WebDataset 279 sec
-- --
(Webサーバーからの直接ダウンロード時間) (921 sec)
Webサーバーダウンロード + Dataset 1155 sec
Webサーバー + WebDataset 925 sec
-- --
(Amazon S3からの直接ダウンロード時間) (1053 sec)
Amazon S3 + WebDataset (バイトストリーム) 1049 sec
Amazon S3 + WebDataset (FUSEマウント) 2363 sec

ローカルディスクで展開済のデータではPyTorch標準のDatasetが上回っていますが、Webサーバー・Aamzon S3からストリーミングする場合ではファイルを直接ダウンロードするのと同等の時間で(つまり学習中にネットワーク帯域をフルに使って)データセットをロードすることができるという結果になりました。シンプルな実装でさまざまなデータローディングに対応できる、スケールアウトも容易な点は大きなメリットかと思います。

6. おわりに、そして超大規模学習に向けて

今回はペタバイトスケールの機械学習にも対応可能なデータローダーの検証としてWebDatasetの使い方やパフォーマンスを紹介しました。あくまで機械学習の実装視点からのものですので、実際の大規模学習ではストレージシステムネットワーク、並列分散で計算させるGPUクラスタなど、さまざまな要素が必要になってきます。

そして、TURINGではこのような大規模な機械学習モデルをつくるエンジニアを募集しています。

- 真に大規模な深層学習モデルの設計・実験・評価
- ペタバイトスケールのデータストレージシステム/クラウドインフラ構築・管理
- 走行データ収集のためのiOS/Androidアプリ開発・運用

詳しくはTURINGの採用ページをご覧ください。一緒に完全な車をつくりませんか? ご不明な点があればTURING MLチーム 担当者 (Yu Yamaguchi)までお気軽にDMしてください。

https://www.wantedly.com/projects/1024347
https://www.wantedly.com/companies/turing-motors/projects

脚注
  1. Training in PyTorch from Amazon S3 ↩︎

Tech Blog - Turing

Discussion