⚙️

TensorFlowのDataset.mapとDataset.shuffleでランダムな要素順を対応させる

2022/12/18に公開

はじめに

以前にTensorFlowのData APIでデータを効率的に流し込めると知り、Datasetを使い始めました。
ところがDataset.mapでどハマりし、今回4ヶ月越しに原因解明できたので、記事を書くことにしました。

問題のコード

モデルに入力するデータとラベルとして学習するための出力用データがタプルになったデータセットを作成します。入力と出力はそれぞれ足すと10になる整数です。

def func():
    """和が10になる値のタプルを返す関数"""
    n = random.randint(0, 10)
    return n, 10 - n

def map_func(data):
    """mapの中で任意のPythonを書くためのラッパー"""
    return tf.py_function(func=func, inp=[], Tout=[tf.int32, tf.int32])

def create_dataset():
    """入力用データセットと出力用データセットの作成"""
    dataset = tf.data.Dataset.range(5)
    dataset = dataset.shuffle(buffer_size=5)
    dataset = dataset.map(map_func)
    return dataset

上記は簡略化したコードなので、本来は必要ないtf.py_functionを使用しています。
inpfuncに渡す引数で、Toutは戻り値の型です。

Dataset.mapでは外部ライブラリなどを使おうとするとエラーが出たりするのですが、tf.py_functionを使うことで実行速度と引き換えに自由にPythonを書くことができるようになります。実際のコードではここで外部ライブラリを使用していました。

それでは、作成したデータセットの中身を出力してみましょう。

dataset = create_dataset()

# タプルのデータセットを入力と出力に分割
x = dataset.map(lambda x, y: x)
y = dataset.map(lambda x, y: y)

for x, y in zip(x, y):
    print(x, y)

データセットを分割してzipに入れるのは転置のためで、やっていることはzip(*hoge)と大体同じです。つまり、[データ数(5), データセット数(2)]の2次元配列を[データセット数(2), データ数(5)]に直しているだけです。

実行すると、このようになります。

# tf.Tensor(2, shape=(), dtype=int32) tf.Tensor(6, shape=(), dtype=int32)
# tf.Tensor(0, shape=(), dtype=int32) tf.Tensor(9, shape=(), dtype=int32)
# tf.Tensor(1, shape=(), dtype=int32) tf.Tensor(7, shape=(), dtype=int32)
# tf.Tensor(4, shape=(), dtype=int32) tf.Tensor(9, shape=(), dtype=int32)
# tf.Tensor(3, shape=(), dtype=int32) tf.Tensor(10, shape=(), dtype=int32)

あれ?入力と出力の対応関係が崩れてる…?
しかも出力に9が2つあるし。2 + 60 + 9も、足して10にはならないぞ、、

結果的には、dataset.map × randomが組み合わさったことでハマっていました。そこにdataset.shuffleも加わったことでカオスな状態になり、解決までかなりの時間を要してしまいました。

原因 1

直接的な原因はこれです。2つのdatasetを作成したことで、ランダム値を取得する際のseedが変わってしまったことが原因でした。

x = dataset.map(lambda x, y: x)  # 1つ目のデータセット
y = dataset.map(lambda x, y: y)  # 2つ目のデータセット

for x, y in zip(x, y):
    print(x, y)

解決方法 1

データセットを1つにまとめることで解決できます。

for data in dataset:
    print(data)

ただ今回は、実装の都合上この解決策は適用できませんでした(上記以外でも複数回データセットを呼び出す必要があった)。

なのでここからは、datasetを分けたい人のための解決方法を書いていきます。

解決方法 2

残念ながらここがスマートではないのですが、データごとにseedを固定することで解決できます。

def func(seed=None):
    random.seed(seed.numpy())  # seedを設定
    n = random.randint(0, 5)
    return n, 10 - n

def map_func(data):
    # funcのseedとしてdata(整数)を渡す
    return tf.py_function(func=func, inp=[data], Tout=[tf.int32, tf.int32])

上記では各データに固有のseed(tf.data.Dataset.rangeで生成された整数)を渡すことで、呼び出しごとに実行結果が変わるのを防いでいます。

対処療法的でややダサいですが、汎用的ではあります。

原因 2

上記の修正で解決したかと思いきや、実はまだ別の問題が残っています。
修正したコードの実行結果はこのようになります。

# tf.Tensor(1, shape=(), dtype=int32) tf.Tensor(10, shape=(), dtype=int32)
# tf.Tensor(1, shape=(), dtype=int32) tf.Tensor(9, shape=(), dtype=int32)
# tf.Tensor(0, shape=(), dtype=int32) tf.Tensor(7, shape=(), dtype=int32)
# tf.Tensor(1, shape=(), dtype=int32) tf.Tensor(9, shape=(), dtype=int32)
# tf.Tensor(3, shape=(), dtype=int32) tf.Tensor(9, shape=(), dtype=int32)

入力と出力の和は10になりませんが、順番を入れ替えることで10になっています。
完全にランダムになる問題は解決しましたが、対応関係が崩れている問題はまだ未解決のままですね。

これは、datasetの生成ごとに要素順がシャッフルされることが原因です。

解決方法 2(続き)

dataset.shufflereshuffle_each_iteration=Falseを渡すことで解決できます。

def create_dataset():
    dataset = tf.data.Dataset.range(5)
    # 呼び出しのたびにシャッフルしない(最初だけシャッフルする)
    dataset = dataset.shuffle(buffer_size=5, reshuffle_each_iteration=False)
    dataset = dataset.map(map_func)
    return dataset

これで1回目の呼び出しでのみ要素順がシャッフルされ、2回目以降は同じ要素順が保たれます。

それでは、もう1度実行してみましょう!

# tf.Tensor(1, shape=(), dtype=int32) tf.Tensor(9, shape=(), dtype=int32)
# tf.Tensor(1, shape=(), dtype=int32) tf.Tensor(9, shape=(), dtype=int32)
# tf.Tensor(0, shape=(), dtype=int32) tf.Tensor(10, shape=(), dtype=int32)
# tf.Tensor(1, shape=(), dtype=int32) tf.Tensor(9, shape=(), dtype=int32)
# tf.Tensor(3, shape=(), dtype=int32) tf.Tensor(7, shape=(), dtype=int32)

入力と出力の和が10になりました 🎉🎉🎉
ちゃんとシャッフルもされていますね。

おわりに

今回はDataset.map × Dataset.shuffle × randomと3つの要素が絡み合った問題だったので、解決まで本当に時間がかかりました(というか実装ではもっと複雑だった)。

実際の入出力データは数値ではなく画像で、今回の数値は画像のトリミング位置にあたります。見ているのは画像なので、そもそも原因1が発生していると気付くまでにかなりの時間がかかってしまいました。

もはやなぜ今さら解決できたのか分かりませんが、急に思い付いたので解決できました。

参考

GitHubで編集を提案

Discussion