Open2

WebDataset について調べる

PlatPlat

deepghs/danbooru2024-sfw を例に使ってみる(PNGでデカいので若干注意。tar ファイルをひとつづつ落としてくるのだが、それがだいぶ遅い... 可能なら webp 版を使うといい)

https://huggingface.co/datasets/deepghs/danbooru2024-sfw

import webdataset as wds
from huggingface_hub import get_token
from torch.utils.data import DataLoader

hf_token = get_token()
# 画像のディレクトリ指定
url = "https://huggingface.co/datasets/deepghs/danbooru2024-sfw/resolve/main/images/{{0000..1023}}.tar"
# トークンつけて取得するように
url = f"pipe:curl -s -L {url} -H 'Authorization:Bearer {hf_token}'"

buffer_size = 1024
ds = (
    wds.WebDataset(
        url,
        shardshuffle=True, # シャッフルする
        seed=42,
    )
    .shuffle(buffer_size) # シャッフル時のバッファサイズ
    .decode("pil")
    # "data", "jpg", "webp" のフィールドをまとめて画像としてデコード
    .to_tuple("__key__", "png;jpg;webp") # (id, image) のタプルに
)

id, image = next(iter(ds)) # ds は IterableDataset なのでこれで最初のデータを取得
print(id)
image.show() # デスクトップ環境なら画像がプレビューされる