TensorflowのTFRecord & tf.data.Datasetを使うときのメモ
TFRecordからDatasetを作って使用する場合、exampleのパースにはtf.io.parse_example
を使うのがよい。
parse_exampleを使うことでbatch単位でパース処理がかけるようになり、実行速度が上がる。
dataset = dataset.batch().map(parse_function)
よくtutorialや記事で見かけるtf.io.parse_single_exmaple
はbatchされたデータに処理をかけれないので遅い。
dataset単位でかけれるtf.data.experimental.parse_example_dataset
もあるがparse_exampleのほうが推奨?
tfrecordのフォーマットで二次元arrayの特徴を扱う場合
tf.train.Int64List
と tf.train.FloatList
と tf.train.BytesList
のvalueには一次元配列しか渡せないので多次元配列は一度バイト列にserializeする必要がある。
書き込み
# xはnumpy配列
serialized_x = tf.io.serialize_tensor(x)
tf.train.BytesList(value=[serialized_x.numpy()])
読み込み
多次元配列はtf.io.parse_tensor
でデシリアライズする。
out_typeは適切に設定する必要がある。
parsed_element = tf.io.parse_single_example(example_proto, feature_description)
parsed_element['image'] = tf.io.parse_tensor(parsed_element['image'], out_type=tf.uint8)
多次元配列をデシリアライズするtf.io.parse_tensor
はtf.io.parse_exmaple
とセットで使えないぽい。
parse_tensorがバッチデータに対応していないため。
なのでparse_tensorを使う場合はparse_single_exampleでやる必要がある。
対応策として多次元配列をflatten&shapeを保存しておき、読み込みの際にreshapeする方法があるかも
ここらへんはなにかいい方法が他にあるんだろうか
tf.io.parse_tensor only accepts a single tensor and not a batch of tensors.
padded_batchについて
特定のtfのバージョンからbatch()でも自動でpaddingするようになっているが、入力形式によっては自分でpadding処理を書く必要がある。
以下サンプルコード
padded_ds = ds.padded_batch(
BATCH_SIZE,
padded_shapes={
'label': [], # Scalar elements, no padding.
'sequence_feature': [None], # Vector elements, padded to longest.
'seq_of_seqs_feature': [None, None], # Matrix elements, padded to longest
})
outputが複数ある場合は以下のようにpadded_shapesを({}, {})
みたいな感じにする
shapes = ({'input1': [], 'input2': [None], 'input3': []}, [])
dataset = dataset.map(lambda ex: labeler(ex))
dataset = dataset.shuffle(1000).repeat(2).padded_batch(batch_size,
padded_shapes=shapes)