📑

streamlitで機械学習Webアプリを作る

2022/03/14に公開

やりたいこと

streamlitを使って以下の機能を備えたtraderlの強化学習Webアプリを作ります。

  1. データの取得
  2. Agentの作成
  3. モデルのトレーニング
  4. 結果の表示
  5. モデルのセーブ

具体的には上の機能を持つ一つのAppクラスを作りサイドバーで選択できるようにする。

完成イメージ

image

必要なパッケージをインストールする

git clone https://github.com/komo135/trade-rl.git
cd trade-rl
pip install .

pip install streamlit

Webアプリの実行

以下のコマンドを入力してWebアプリを実行する

streamlit run https://raw.githubusercontent.com/komo135/traderl-web-app/main/github_app.py [ARGUMENTS]

streamlitとは

データ分析や機械学習のWebアプリケーションを簡単に作ることができるWebアプリケーションフレームワーク

traderlとは

https://github.com/komo135/trade-rl

強化学習を使ってFXや株式の取引を学習することができるpythonパッケージ

コード

アプリクラスとサイドバー関数の2つを作ります。

Appクラス

ホーム画面

ホーム画面に表示させるマークダウン

    def home(self):
        md = """
        # Traderl Web Application
        This web app is intuitive to [traderl](https://github.com/komo135/trade-rl).

        # How to Execute
        1. select data
            * Click on "select data" on the sidebar to choose your data.
        2. create agent
            * Click "create agent" on the sidebar and select an agent name and arguments to create an agent.
        3. training
            * Click on "training" on the sidebar to train your model.
        4. show results
            * Click "show results" on the sidebar to review the training results.
        """
        #マークダウンを画面に表示させる
        st.markdown(md)

データの選択

    def select_data(self):
        file = None
        
        #データの選択
        select = st.selectbox("", ("forex", "stock", "url or path", "file upload"))
        col1, col2 = st.columns(2)
        # クリックするとデータが読みこまれる
        load_file = st.button("load file")

        if select == "forex":
            symbol = col1.selectbox("", ("AUDJPY", "AUDUSD", "EURCHF", "EURGBP", "EURJPY", "EURUSD",
                                         "GBPJPY", "GBPUSD", "USDCAD", "USDCHF", "USDJPY", "XAUUSD"))
            timeframe = col2.selectbox("", ("m15", "m30", "h1", "h4", "d1"))
            if load_file:
                self.df = data.get_forex_data(symbol, timeframe)
        elif select == "stock":
            symbol = col1.text_input("", help="enter a stock symbol name")
            if load_file:
                self.df = data.get_stock_data(symbol)
        elif select == "url or path":
            file = col1.text_input("", help="enter url or local file path")
        elif select == "file upload":
            file = col1.file_uploader("", "csv")

        if load_file and file:
            st.write(file)
            self.df = pd.read_csv(file)

        if load_file:
            st.write("Data selected")

    def check_data(self):
        f"""
        # Select Data
        """
        #データが存在しているかを確認する
        if isinstance(self.df, pd.DataFrame):
            st.write("Data already exists")
            # データ既に存在していて初期化するかの確認のボタン
            if st.button("change data"):
                st.warning("data and agent have been initialized")
                self.df = None
                self.agent = None

        #データがない場合新しくデータを選択する
        if not isinstance(self.df, pd.DataFrame):
            self.select_data()

エージェントの作成

    #エージェントを作成する
    def create_agent(self, agent_name, args):
        agent_dict = {"dqn": dqn.DQN, "qrdqn":qrdqn.QRDQN}
        self.agent = agent_dict[agent_name](**args)

    #エージェントの選択、引数の選択
    def agent_select(self):
        # データが存在しない場合、警告を出す
        if not isinstance(self.df, pd.DataFrame):
            st.warning("data does not exist.\n"
                       "please select data")
            return None
        #エージェントの選択
        agent_name = st.selectbox("", ("dqn", "qrdqn"), help="select agent")

        """
        # select Args
        """
        # 使用可能なtensorflowモデルの選択
        col1, col2 = st.columns(2)
        network = col1.selectbox("select network", (nn.available_network))
        network_level = col2.selectbox("select network level", (f"b{i}" for i in range(8)))
        network += "_" + network_level
        self.model_name = network
        #その他の引数の選択
        col1, col2, col3, col4 = st.columns(4)
        lr = float(col1.text_input("lr", "1e-4"))
        n = int(col2.text_input("n", "3"))
        risk = float(col3.text_input("risk", "0.01"))
        pip_scale = int(col4.text_input("pip scale", "25"))
        col1, col2 = st.columns(2)
        gamma = float(col1.text_input("gamma", "0.99"))
        use_device = col2.selectbox("use device", ("cpu", "gpu", "tpu"))
        train_spread = float(col1.text_input("train_spread", "0.2"))
        spread = int(col2.text_input("spread", "10"))

        kwargs = {"df": self.df, "model_name": network, "lr": lr, "pip_scale": pip_scale, "n": n,
                  "use_device": use_device, "gamma": gamma, "train_spread": train_spread,
                  "spread": spread, "risk": risk}
       
        #ボタンをクリックするとエージェントが作成される
        if st.button("create agent"):
            self.create_agent(agent_name, kwargs)
            st.write("Agent created")

モデルのトレーニング

    #エージェントが存在しているかを確認する
    def agent_train(self)
        #存在している場合、ボタンをクリックするとモデルがトレーニングされる
        if self.agent:
            if st.button("training"):
                self.agent.train()
        #ない場合、警告を出す
        else:
            st.warning("agent does not exist.\n"
                       "please create agent")

トレーニング結果の表示

    def show_result(self):
        #エージェントが存在しているかを確認する
        if self.agent:
            self.agent.plot_result(self.agent.best_w)
        else:
            st.warning("agent does not exist.\n"
                       "please create agent")

モデルのセーブ

    def model_save(self):
        # セーブするファイル名を入力してボタンをクリックしてモデルをセーブする
        if self.agent:
            save_name = st.text_input("save name", self.model_name)
            if st.button("model save"):
                self.agent.model.save(save_name)
                st.write("Model saved.")
        else:
            st.warning("agent does not exist.\n"
                       "please create agent")

初期化

    @staticmethod
    def clear_cache():
        if st.button("initialize"):
            st.experimental_memo.clear()

## サイドバーの作成
```python
def sidebar():
    return st.sidebar.radio("", ("Home", "select data", "create agent", "training",
                                 "show results", "save model", "initialize"))

コードの実行

appをst.session_stateに保存してロードする理由

  • サイドバーにある要素を選択する度に最初化からの実行になる為、データやエージェントが保存されない
  • st.session_stateこの変数はページがロードされるまで保持される

appクラスとサイドバー関数を分ける理由

  • サイドバー関数がappクラス内にあるとappクラスst.session_stateからロードするとサイドバーが表示されなくなるから
if __name__ == "__main__":
    st.set_page_config(layout="wide", )

    if "app" in st.session_state:
        app = st.session_state["app"]
    else:
        app = App()

    select = sidebar()

    if select == "Home":
        app.home()

    if select == "select data":
        app.check_data()
    elif select == "create agent":
        app.agent_select()
    elif select == "training":
        app.agent_train()
    elif select == "save model":
        app.model_save()
    elif select == "show results":
        app.show_result()

    st.session_state["app"] = app
    if select == "initialize":
        app.clear_cache()

github

https://github.com/komo135/traderl-web-app

Discussion