😊

Keras と Streamlit で深層学習モデルの構築をいい感じにする

3 min read

Streamlit は、データの可視化に特化した Python のフロントエンドフレームワークです。機能を絞っている分、JavaScript などの知識が不要で、お手軽にUIを作成できます。

他のフロントエンドフレームワークと同様、 Streamlit にも自動リロード(ファイルの更新を監視して、自動的にリロードしてくれる機能)が実装されています。普通、下図のように、モデルの構築に Streamlit を使って UI を構築し、可視化するような使い方が多いと思うのですが、Streamlit で可視化しながらモデルを構築していくと結構たのしかったので紹介します。

Keras モデルの Streamlit による可視化

さて、具体的な方法ですが、Keras に用意されている Graphviz を使ってモデルを可視化する機能 model_to_dot を利用します。以下のように、Keras モデルを引数に与えると、 pydot.dot オブジェクトが帰ってくるので、文字列として保持しておきます。

graph = tf.keras.utils.model_to_dot(model)
dot = str(graph)

一方、 Streamlit には、Dot を可視化する機能があります。

st.graphviz_chart(dot)

これを組み合わせることで、 Streamlit 上に Keras モデルのグラフを表示できます。

Subclassing API への対応

Keras では、以下の3つの方法でモデルを定義することができます。

API 説明
Sequential API レイヤーのリストを与えることで、モデルを定義する
Functional API 入出力のテンソルを指定することで、モデルを定義する
Subclassing API モデルクラスを継承して、独自のモデルを定義する

実は model_to_dot 関数は、 Sequential API もしくは Functional API で定義されたモデルにしか対応しておらず、 Subclassing API で定義されたモデルは可視化できません。

そこで、以下のように、Subclassing API で定義したモデルの入出力を指定して、新たに Functional API でモデルを定義してあげます。

# Subclassing API で定義されたモデル
model = CustomModel()

# 入力テンソルを用意する
input = tf.keras.Input((224, 224, 3))
output = model.call(input)
# Functional API で可視化用のモデルを用意
viz_model = tf.keras.Model(input, output)


# モデルの可視化
graph = tf.keras.utils.model_to_dot(
    viz_model,
    show_shapes=True, # 各レイヤーの入出力の表示有無
    expand_nested=False, # ネストされたグラフを展開するか否かを指定
    dpi=72  # デフォルトだと見切れてしまうので少し小さくする
)
st.graphviz_chart(str(graph))

ちなみに、 演算は全て Keras レイヤーを使って記述する必要があります。

# NG(model_to_dot では可視化できない)
class CustomModel(tf.keras.Model):
    
    ...

    def call(input):
        x = ...
        x = x + input
        return x

# OK(model_to_dot で可視化できる)
class CustomModel(tf.keras.Model):
    
    ...

    def call(input):
        x = ...
        x = tf.keras.layers.Add()([x, input])
        return x

グラフの自動更新

Streamlit にはファイル更新をチェックして自動的にプログラムを再実行してくれる機能が備わっています。以下のように、起動時にコマンドライン引数で指定するか、UIのメニューから 「Run On Save」 を選択します。

$ streamlit run --server.runOnSave true <ファイル名>

コード全体

デモ用に作成したコードはこちらです。

https://github.com/ohtaman/keras-streamlit-demo
import tensorflow as tf
import streamlit as st


class CustomModel(tf.keras.Model):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.conv1 = tf.keras.models.Sequential((
            tf.keras.layers.Conv2D(3, 3, padding='same'),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.ReLU()
        ))
        self.conv2 = tf.keras.models.Sequential((
            tf.keras.layers.Conv2D(3, 3, padding='same'),
            tf.keras.layers.BatchNormalization(),
        ))
        self.activation = tf.keras.layers.Softmax()
        self.add = tf.keras.layers.Add()
    
    def call(self, inputs):
        x = self.conv1(inputs)
        x = self.conv2(x)
        x = self.add((x, inputs))
        x = self.activation(x)
        return x


def main():
    model = CustomModel()

    input = tf.keras.Input((224, 224, 3))
    output = model.call(input)
    graph = tf.keras.utils.model_to_dot(
        tf.keras.Model(input, output),
        show_shapes=True,
        expand_nested=False,
        dpi=72
    )
    st.graphviz_chart(str(graph))


if __name__ == '__main__':
    main()