🔖

RAG機能付きチャットボットを作ろう-4_チャット履歴保存

2024/11/27に公開

TL;DR

前回の記事で、OpenAIのクライアントの作成、プロンプトの作成、そして回答の取得までを行いました。本稿では

  • 情報の保存:streamlitのsession_stateを使って、プロンプトや回答を保存します。

実装イメージ

streamlitにおける情報の保存

情報の保存

streamlitでプロンプトの入力および回答の出力を行いました。しかしこのままでは、プロンプトを入力しても、回答を出力しても、次のプロンプトを入力すると、前のプロンプトや回答が消えてしまいます。
そこで、streamlitのsession_stateを使って、プロンプトや回答を保存します。

session_stateの使い方

streamlitではsession_stateという機能を使って様々な情報を保存できます。
str、intなどだけでなく、classや関数なども保存できます。dictのようにst.session_state['key']のような形で保存した情報を取得できます。
ただし、最終的にjson形式で保存することを考えると、保存する情報はstrやintなどの基本的な型にしておく、もしくは保存する範囲を限定するのがおすすめです。

実装コード

主な変更点は

  • st.session_stateの初期化
  • st.session_stateのデータを表示する関数の作成
  • プロンプトと回答をst.session_stateに保存する
  • st.rerun()による表示の更新

コード

import streamlit as st
from openai import OpenAI
from dotenv import load_dotenv
import os

load_dotenv(".env")

# OpenAIのAPIクライアントを初期化
# APIキーを環境変数から取得
client = OpenAI(
    api_key=os.environ['OPENAI_API_KEY']
    )

# プロンプトを入力すると、チャットボットが返答を返す関数を定義
# 入力はOpenAIのAPIクライアントとプロンプト
def default_chat(client, prompt):
    response = client.chat.completions.create(
        model="gpt-4o-mini", # 好きなモデルを選択
        messages=[
            {"role": "system", "content": "You are AI assistant."},
            {"role": "user", "content": prompt}
        ]
    )
    return response.choices[0].message.content

# streamlitのsession_stateにチャット履歴を保存する
# もしチャット履歴がなければ、空のリストを作成
if 'chat_history' not in st.session_state:
    st.session_state.chat_history = []

# チャット履歴を表示する関数
def display_chat_history():
    for chat in st.session_state.chat_history:
        if chat["role"] == "user":
            st.write("ユーザー: " + chat["content"])
        else:
            st.write("チャットボット: " + chat["content"])


st.title('RAG機能付きチャットボットを作ろう')
st.write('streamlitを使ったUIの作成')

# チャット履歴を表示
display_chat_history()

prompt = st.text_area('プロンプト入力欄', )

button1, button2 = st.columns(2)
if button1.button('チャット'):
    chat_response = default_chat(client, prompt)
    # チャット履歴に追加
    # ユーザーの入力を追加、roleはuser
    st.session_state.chat_history.append({"role": "user", "content": prompt})
    # チャットボットの返答を追加、roleはsystem
    st.session_state.chat_history.append({"role": "system", "content": chat_response})
    st.rerun()
if button2.button('RAG'):
    st.write('RAGボタンがクリックされました')

上記をmain.pyとして保存し、streamlitを起動します。

streamlit run main.py

以下のような画面が表示されれば成功です。

alt text

解説

st.session_stateの初期化

以下のコードで、st.session_statechat_historyというキーが存在しない場合、空のリストを作成しています。

if 'chat_history' not in st.session_state:
    st.session_state.chat_history = []

チャット履歴の表示

以下のコードで、chat_historyに保存された情報を表示しています。
st.session_stateをdictと同じように使っています。

def display_chat_history():
    for chat in st.session_state.chat_history:
        if chat["role"] == "user":
            st.write("ユーザー: " + chat["content"])
        else:
            st.write("チャットボット: " + chat["content"])

プロンプトと回答の保存

以下のコードで、プロンプトと回答をst.streamlit.chat_historyに保存しています。
また、表示の更新のためにst.rerun()を行っています。

if button1.button('チャット'):
    chat_response = default_chat(client, prompt)
    # チャット履歴に追加
    # ユーザーの入力を追加、roleはuser
    st.session_state.chat_history.append({"role": "user", "content": prompt})
    # チャットボットの返答を追加、roleはsystem
    st.session_state.chat_history.append({"role": "system", "content": chat_response})
    st.rerun()

リンク

Discussion