Open5

TensorflowのTFRecord & tf.data.Datasetを使うときのメモ

ymickyymicky

TFRecordからDatasetを作って使用する場合、exampleのパースにはtf.io.parse_exampleを使うのがよい。
parse_exampleを使うことでbatch単位でパース処理がかけるようになり、実行速度が上がる。

dataset = dataset.batch().map(parse_function)
ymickyymicky

tfrecordのフォーマットで二次元arrayの特徴を扱う場合

tf.train.Int64Listtf.train.FloatListtf.train.BytesListのvalueには一次元配列しか渡せないので多次元配列は一度バイト列にserializeする必要がある。

書き込み

# xはnumpy配列
serialized_x = tf.io.serialize_tensor(x)
tf.train.BytesList(value=[serialized_x.numpy()])

読み込み
多次元配列はtf.io.parse_tensorでデシリアライズする。
out_typeは適切に設定する必要がある。

parsed_element = tf.io.parse_single_example(example_proto, feature_description)
parsed_element['image'] = tf.io.parse_tensor(parsed_element['image'], out_type=tf.uint8)

参考:Python: TFRecord フォーマットについて

ymickyymicky

多次元配列をデシリアライズするtf.io.parse_tensortf.io.parse_exmapleとセットで使えないぽい。
parse_tensorがバッチデータに対応していないため。
なのでparse_tensorを使う場合はparse_single_exampleでやる必要がある。
対応策として多次元配列をflatten&shapeを保存しておき、読み込みの際にreshapeする方法があるかも
ここらへんはなにかいい方法が他にあるんだろうか

https://github.com/tensorflow/tensorflow/issues/43706

tf.io.parse_tensor only accepts a single tensor and not a batch of tensors.

ymickyymicky

padded_batchについて

特定のtfのバージョンからbatch()でも自動でpaddingするようになっているが、入力形式によっては自分でpadding処理を書く必要がある。
以下サンプルコード

padded_ds = ds.padded_batch(
    BATCH_SIZE,
    padded_shapes={
        'label': [],                          # Scalar elements, no padding.
        'sequence_feature': [None],           # Vector elements, padded to longest.
        'seq_of_seqs_feature': [None, None],  # Matrix elements, padded to longest
    })      

https://stackoverflow.com/questions/49840100/tf-data-dataset-padded-batch-pad-differently-each-feature

outputが複数ある場合は以下のようにpadded_shapesを({}, {}) みたいな感じにする

shapes = ({'input1': [], 'input2': [None], 'input3': []}, [])
dataset = dataset.map(lambda ex: labeler(ex))
dataset = dataset.shuffle(1000).repeat(2).padded_batch(batch_size,
                                                       padded_shapes=shapes)

https://stackoverflow.com/questions/58372267/how-to-read-data-of-multiple-input-model-using-tf-data-textlinedataset