Open1
TensorFlow2の書き方
tensorflowで深層学習アーキテクチャを構築するときはtf.kerasを使うと直感的で良い。
しかし、複雑なアーキテクチャを作るときはtf.kerasだけだと対処できない問題も出てくる。
そういうときに「カスタムレイヤー」を使う。
構造
- __init__ , 入力に依存しないすべての初期化を行う
- build, 入力の shape を知った上で、残りの初期化を行う
- call, フォワード計算を行う
サンプル
class MyModel(Model):
# コンストラクタの作成
def __init__(self):
super(MyModel, self).__init__()
self.conv1 = Conv2D(32, 3, activation='relu')
self.flatten = Flatten()
self.d1 = Dense(128, activation='relu')
self.d2 = Dense(10, activation='softmax')
# forward処理
def call(self, x):
x = self.conv1(x)
x = self.flatten(x)
x = self.d1(x)
return self.d2(x)
# モデルのインスタンスを作成
model = MyModel()
参照