Open2
TensorFlow v2.15.0 での破壊的変更への対応 (You can only use it as input to a Keras layer or a Keras operation)
TensorFlow v2.15.0 から、Keras の Functional モデルに対して tf.xxx
や tf.math.xxx
などの TensorFlow のプリミティブなオペレーションを含めると下記のエラーが発生してグラフの構築に失敗する。
ValueError: A KerasTensor cannot be used as input to a TensorFlow function.
A KerasTensor is a symbolic placeholder for a shape and dtype,
used when constructing Keras Functional models or Keras Functions.
You can only use it as input to a Keras layer or a Keras operation
(from the namespaces `keras.layers` and `keras.operations`).
You are likely doing something like:
x = Input(...)
tf_fn(x) # Invalid.
What you should do instead is wrap `tf_fn` in a layer:
class MyLayer(Layer):
def call(self, x):
return tf_fn(x)
x = MyLayer()(x)
したがって、下記のようにすべての tf.xxx
を tf.keras.layers.Layer
を継承して作成したクラスの call
メソッドを経由してグラフを構築する必要が有る。
- TensorFlow v2.14.0 までは動作する
output = \ tf.reshape( tensor=input_tensor, shape=new_shape, name=graph_node.name, )
- TensorFlow v2.15.0 以降からは下記のように変更
class KLayer(tf.keras.layers.Layer): def __init__(self): super(KLayer, self).__init__() def call(self, tf_fn, x, **kwargs): return tf_fn(x, **kwargs) output = \ KLayer()( tf_fn=tf.reshape, x=input_tensor, shape=new_shape, name=graph_node.name, )
なお、strided_slice
のようなオペレーションは、MixedTensor形式を受け付けないためKerasがエラーを吐く。引数を **kwargs
とするとエラーになる。回避策は、すべての引数を個別パラメータで列挙する。つまり、上記の Reshape
のパターンも同じで、引数に定数と変数が混在しているパラメータを **kwargs
で渡した場合にはKerasがエラーになる。つまり、地獄。
# strided_slice
# tf_layers_dict[graph_node_output.name]['tf_node'] = \
# tf.strided_slice(
# input_=input_tensor,
# begin=begin_,
# end=end_,
# strides=strides_,
# begin_mask=begin_mask_,
# end_mask=end_mask_,
# name=graph_node.name,
# )
class Kstrided_slice(tf.keras.layers.Layer):
"""Keras wrapper class for tf.xxx
"""
def __init__(self):
super(Kstrided_slice, self).__init__()
def call(
self,
input_,
begin,
end,
strides,
begin_mask,
end_mask,
name,
):
return \
tf.strided_slice(
input_=input_,
begin=begin,
end=end,
strides=strides,
begin_mask=begin_mask,
end_mask=end_mask,
name=name,
)
tf_layers_dict[graph_node_output.name]['tf_node'] = \
Kstrided_slice()(
tf_fn=tf.strided_slice,
input_=input_tensor,
begin=tf.convert_to_tensor(begin_),
end=tf.convert_to_tensor(end_),
strides=tf.convert_to_tensor(strides_),
begin_mask=tf.convert_to_tensor(begin_mask_),
end_mask=tf.convert_to_tensor(end_mask_),
name=graph_node.name,
)