Open2

TensorFlow v2.15.0 での破壊的変更への対応 (You can only use it as input to a Keras layer or a Keras operation)

PINTOPINTO

TensorFlow v2.15.0 から、Keras の Functional モデルに対して tf.xxxtf.math.xxx などの TensorFlow のプリミティブなオペレーションを含めると下記のエラーが発生してグラフの構築に失敗する。

ValueError: A KerasTensor cannot be used as input to a TensorFlow function.
A KerasTensor is a symbolic placeholder for a shape and dtype,
used when constructing Keras Functional models or Keras Functions.
You can only use it as input to a Keras layer or a Keras operation
(from the namespaces `keras.layers` and `keras.operations`).
You are likely doing something like:

x = Input(...)
tf_fn(x)  # Invalid.

What you should do instead is wrap `tf_fn` in a layer:

class MyLayer(Layer):
    def call(self, x):
        return tf_fn(x)

x = MyLayer()(x)

したがって、下記のようにすべての tf.xxxtf.keras.layers.Layer を継承して作成したクラスの call メソッドを経由してグラフを構築する必要が有る。

  • TensorFlow v2.14.0 までは動作する
    output = \
        tf.reshape(
            tensor=input_tensor,
            shape=new_shape,
            name=graph_node.name,
        )
    
  • TensorFlow v2.15.0 以降からは下記のように変更
    class KLayer(tf.keras.layers.Layer):
        def __init__(self):
            super(KLayer, self).__init__()
    
        def call(self, tf_fn, x, **kwargs):
            return tf_fn(x, **kwargs)
    
    output = \
        KLayer()(
            tf_fn=tf.reshape,
            x=input_tensor,
            shape=new_shape,
            name=graph_node.name,
        )
    
PINTOPINTO

なお、strided_slice のようなオペレーションは、MixedTensor形式を受け付けないためKerasがエラーを吐く。引数を **kwargs とするとエラーになる。回避策は、すべての引数を個別パラメータで列挙する。つまり、上記の Reshape のパターンも同じで、引数に定数と変数が混在しているパラメータを **kwargs で渡した場合にはKerasがエラーになる。つまり、地獄。

        # strided_slice
        # tf_layers_dict[graph_node_output.name]['tf_node'] = \
        #     tf.strided_slice(
        #         input_=input_tensor,
        #         begin=begin_,
        #         end=end_,
        #         strides=strides_,
        #         begin_mask=begin_mask_,
        #         end_mask=end_mask_,
        #         name=graph_node.name,
        #     )

        class Kstrided_slice(tf.keras.layers.Layer):
            """Keras wrapper class for tf.xxx
            """
            def __init__(self):
                super(Kstrided_slice, self).__init__()

            def call(
                self,
                input_,
                begin,
                end,
                strides,
                begin_mask,
                end_mask,
                name,
            ):
                return \
                    tf.strided_slice(
                        input_=input_,
                        begin=begin,
                        end=end,
                        strides=strides,
                        begin_mask=begin_mask,
                        end_mask=end_mask,
                        name=name,
                    )

        tf_layers_dict[graph_node_output.name]['tf_node'] = \
            Kstrided_slice()(
                tf_fn=tf.strided_slice,
                input_=input_tensor,
                begin=tf.convert_to_tensor(begin_),
                end=tf.convert_to_tensor(end_),
                strides=tf.convert_to_tensor(strides_),
                begin_mask=tf.convert_to_tensor(begin_mask_),
                end_mask=tf.convert_to_tensor(end_mask_),
                name=graph_node.name,
            )