🔖

RAG機能付きチャットボットを作ろう-8_回答をStructured Outputで構造化する

2024/12/14に公開

TL;DR

前回の記事では、promptに質問文を入れて、検索結果を含めた回答を生成AIが返す関数を作成しました。本稿では

  • 生成AIの回答を構造化する
    を行います。

実装

主な変更点

主な変更点は以下です。

  • pydanticを使って、構造化出力を定義
  • gpt-4o-miniを使って、回答を構造化出力に変換
  • 構造化した回答をdict形式にし、その後pandasのDataFrameに変換
  • streamlitのtableを使って、DataFrameを表示
  • Dataframeをcsvファイルに保存

コード

import streamlit as st
from openai import OpenAI
from dotenv import load_dotenv
import os
import chromadb
from langchain_chroma import Chroma
import openai
from pydantic import BaseModel, Field
from typing import List
from langchain_openai import OpenAIEmbeddings
from langchain_community.document_loaders import PyMuPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
import glob
from langchain_community.document_loaders import UnstructuredHTMLLoader
import pandas as pd

# .envファイルから環境変数を読み込む
load_dotenv(".env")
zotero_path = os.environ['ZOTERO_PATH']

# OpenAIのAPIクライアントを初期化

client = OpenAI(
    api_key=os.environ['OPENAI_API_KEY']
    )


# text splitterの定義
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=512,
    chunk_overlap=100,
    length_function=len,
    separators=["\n", " ", ".", ",", ";", ":", "(", ")", "[", "]", "{", "}", "<", ">", '"', "'", "、", "。", ",", ";", ":", "(", ")", "【", "】", "「", "」", "『", "』", "〈", "〉", "《", "》", "“", "”"],
    is_separator_regex=False,
)


# ベクトルDBのためのEmbeddingsの定義
embeddings = OpenAIEmbeddings(model="text-embedding-3-large")

# Chromaの初期化
vector_store = Chroma(
    collection_name="example_collection",
    embedding_function=embeddings,
    persist_directory="./chroma_langchain_db",  # Where to save data locally, remove if not necessary
)


def pdf_to_vector():
    # 特定フォルダのpdfのリスト化
    path = zotero_path
    # pdf_list_stored.txt を読み込んで、pdf_list に格納
    pdf_list = []
    if os.path.exists("pdf_list_stored.txt"):
        with open("pdf_list_stored.txt", "r") as f:
            for line in f:
                pdf_list.append(line.strip())
    else:
        # pdf_list_stored.txt が存在しない場合は、ファイルを作成
        with open("pdf_list_stored.txt", "w") as f:
            pass


    # globを使ってファイル名のリストを取得
    # サブフォルダも含める場合は、"**/*.pdf"とする

    pdf_list_current = glob.glob(path + "/**/*.pdf", recursive=True)
    pdf_list_new = list(set(pdf_list_current) - set(pdf_list))


    docs = []
    for pdf in pdf_list_new:
        loader = PyMuPDFLoader(pdf)
        text = loader.load_and_split(text_splitter)
        docs.append(text)
        # Vector storeに保存
        vector_store.add_documents(text)
        pdf_list.append(pdf)

    # pdf_list を pdf_list_stored.txt に保存
    with open("pdf_list_stored.txt", "w") as f:
        for pdf in pdf_list:
            f.write(pdf + "\n")
    

    # htmlファイルの場合
    html_list = []
    if os.path.exists("html_list_stored.txt"):
        with open("html_list_stored.txt", "r") as f:
            for line in f:
                html_list.append(line.strip())
    else:
        with open("html_list_stored.txt", "w") as f:
            pass

    html_list_current = glob.glob(path + "/**/*.html", recursive=True)
    html_list_new = list(set(html_list_current) - set(html_list))

    for html in html_list_new:
        loader = UnstructuredHTMLLoader(html)
        text = loader.load_and_split(text_splitter)
        docs.append(text)
        vector_store.add_documents(text)
        html_list.append(html)
    
    with open("html_list_stored.txt", "w") as f:
        for html in html_list:
            f.write(html + "\n")


## ベクトル検索のretireverの定義
retriever = vector_store.as_retriever(
    search_type="mmr", search_kwargs={"k": 1, "fetch_k": 3}
)


class structured_RAG_output(BaseModel):
    query: str = Field (..., description="The query from user")
    source: str = Field (..., description="The source file or URL of the document")
    answer: str = Field (..., description="The answer to the question")
    summary: str = Field (..., description="The summary of the answer")

class list_of_output(BaseModel):
    output: List[structured_RAG_output]



# プロンプトを入力すると、チャットボットが返答を返す関数を定義
# 入力はOpenAIのAPIクライアントとプロンプト
def default_chat(client, prompt):
    response = client.chat.completions.create(
        model="gpt-4o-mini", # 好きなモデルを選択
        messages=[
            {"role": "system", "content": "You are AI assistant."},
            {"role": "user", "content": prompt}
        ]
    )
    return response.choices[0].message.content


# プロンプトとRetrieverの結果を使って、回答を返す関数を定義
def chat_with_retriever(client, retriever, prompt):
    retriever_response = retriever.invoke(prompt)
    RAG_response = client.beta.chat.completions.parse(
        model="gpt-4o-mini", # 好きなモデルを選択
        messages=[
            {"role": "system", "content": """
             You are AI assistant.
             User will give you documents and question.
             You need to find the answer in the documents.
             You should NOT answer besides the documents.
             If you can not find the answer, you can ask the user for more information.
             """},
            {"role": "user", "content": f"""
             You are given document as below.
             {retriever_response}
             """},
            {"role": "assistant", "content": "I will find the answer in the document."},
            {"role": "user", "content": prompt},
        ],
        response_format=list_of_output,
        temperature=0.0,
    )
    lists = RAG_response.choices[0].message.parsed
    return lists

# streamlitのsession_stateにチャット履歴を保存する
# もしチャット履歴がなければ、空のリストを作成
if 'chat_history' not in st.session_state:
    st.session_state.chat_history = []

# チャット履歴を表示する関数
def display_chat_history():
    for chat in st.session_state.chat_history:
        if chat["role"] == "user":
            st.markdown(
                # 背景をグレーにして、角を丸くする
                f'<div style="background-color: #f0f0f0; border-radius: 10px; padding: 10px;">'
                f"ユーザー: {chat['content']}"
                '</div>', unsafe_allow_html=True)

        else:
            st.markdown(
                # 背景を青にして、角を丸くする
                f'<div style="background-color: #cfe2ff; border-radius: 10px; padding: 10px;">'
                f"チャットボット: {chat['content']}"
                '</div>', unsafe_allow_html=True)

st.title('RAG機能付きチャットボットを作ろう')
st.write('streamlitを使ったUIの作成')

# チャット履歴を表示
display_chat_history()

prompt = st.text_area('プロンプト入力欄', )

button1, button2, button3 = st.columns(3)
if button1.button('チャット'):
    chat_response = default_chat(client, prompt)
    # チャット履歴に追加
    # ユーザーの入力を追加、roleはuser
    st.session_state.chat_history.append({"role": "user", "content": prompt})
    # チャットボットの返答を追加、roleはsystem
    st.session_state.chat_history.append({"role": "system", "content": chat_response})
    st.rerun()
if button2.button('RAG'):
    # 
    RAG_response = chat_with_retriever(client, retriever, prompt)
    list = RAG_response.output
    st.session_state.chat_history.append({"role": "user", "content": prompt})
    st.session_state.chat_history.append({"role": "system", 
                                          "content": list})
    if 'RAG_list' not in st.session_state:
        st.session_state.RAG_list = []
    
    for i in list:
        st.session_state.RAG_list.append(i.dict())
    
    st.session_state.df = pd.DataFrame(st.session_state.RAG_list)
    
    
    st.rerun()

if button3.button('PDFをベクトル化'):
    pdf_to_vector()
    st.rerun()

button_to_csv = st.button('CSVに保存')
if button_to_csv:
    st.session_state.df.to_csv("output.csv", index=False)

if 'df' in st.session_state:
    st.table(st.session_state.df)

次のような画面で、構造化された回答を表示します。
テーブルを含めて表示できていれば、成功です。

Pydanticを使って、構造化出力を定義

以下のように、構造化出力を定義します。
ここでは、query, source, answer, summaryの4つの要素を持つ構造を定義し、それをリスト化しています。
なお、BaseModelはPydanticのクラスで、FieldはPydanticのフィールドを定義するためのクラスです。BaseModelでは、複数要素を持つリストなどネスト(入れ子)構造を一度には定義できず、まずstructured_RAG_outputのような構造で定義し、そのあとに、要素がstructured_RAG_outputであるリストをoutputとして定義しています。

class structured_RAG_output(BaseModel):
    query: str = Field (..., description="The query from user")
    source: str = Field (..., description="The source file or URL of the document")
    answer: str = Field (..., description="The answer to the question")
    summary: str = Field (..., description="The summary of the answer")

class list_of_output(BaseModel):
    output: List[structured_RAG_output]

gpt-4o-miniを使って、回答を構造化出力に変換

以下のように、client.beta.chat.completions.parseを使って、回答を構造化出力に変換しています。

def chat_with_retriever(client, retriever, prompt):
    retriever_response = retriever.invoke(prompt)
    RAG_response = client.beta.chat.completions.parse(
        model="gpt-4o-mini", # 好きなモデルを選択
        messages=[
            {"role": "system", "content": """
            You are AI assistant.
            User will give you documents and question.
            You need to find the answer in the documents.
            You should NOT answer besides the documents.
            If you can not find the answer, you can ask the user for more information.
            """},
            {"role": "user", "content": f"""
            You are given document as below.
            {retriever_response}
            """},
            {"role": "assistant", "content": "I will find the answer in the document."},
            {"role": "user", "content": prompt},
        ],
        response_format=list_of_output,
        temperature=0.0,
    )
    lists = RAG_response.choices[0].message.parsed
    return lists

なお、response_format=list_of_outputを指定しています。
本機能はgpt-4o-2024-08-06以降で利用可能です。
通常のchat.completions.createでは
RAG_response.choices[0].message.contentによって回答を取得していましたが、Structured Outputの場合は、RAG_response.choices[0].message.parsedによって構造化された回答を取得しています。

構造化した回答をdict形式にし、その後pandasのDataFrameに変換

以下のように、構造化した回答をdict形式にし、その後pandasのDataFrameに変換しています。
RAG_response.choices[0].message.parsed.outputがリストになっており、その各要素をdict()で辞書に変換しながら空のリストに追加しています。
その後、pd.DataFrameでDataFrameに変換しています。

list = RAG_response.output

if 'RAG_list' not in st.session_state:
    st.session_state.RAG_list = []

for i in list:
    st.session_state.RAG_list.append(i.dict())

st.session_state.df = pd.DataFrame(st.session_state.RAG_list)

streamlitのtableを使って、DataFrameを表示

以下のように、streamlitのtableを使って、DataFrameを表示しています。

if 'df' in st.session_state:
    st.table(st.session_state.df)

Dataframeをcsvファイルに保存

以下のように、csvファイルに保存するボタンを追加しています。

button_to_csv = st.button('CSVに保存')
if button_to_csv:
    st.session_state.df.to_csv("output.csv", index=False)

名称を例えば Datetimeなどで日付を取得して、その日付をファイル名にすることで、ファイル名を変更することもできます。

リンク

Discussion