🦙

LlamaIndexとGradioでお手軽チャットボットを作る【Windows 11 + WSL 2 + JupyterLab】

2024/04/16に公開

はじめに

前回、LlamaIndexを使用したRAGを構築しました。今回は、Gradioを使ってWebアプリ上でChatGPTのようなチャットボットを開発しました。

前回の記事
https://zenn.dev/toki_mwc/articles/1485f655611d54

Gradioとは

Gradioは、Pythonで書かれた機械学習モデルをWebアプリとして簡単にデモできるライブラリです。数行のコードで、予測や分類などの機能を持つインターフェースを作成し、他の人と共有することができます。Hugging Face Spacesで公開することも可能です。複雑な開発は不要で、データサイエンスに興味がある方にとって、手軽にモデルを試す方法を提供します。

画像生成でStable Diffusion WebUIを使用している方にとっては馴染みのあるものですね。

準備

Gradioのインストール

pipでインストールする場合

pip install gradio

condaでインストールする場合

conda install -c conda-forge gradio

また、llamaindexのインストールについては前回の記事を参照してください。

チャットボットを実装する

コードの全容はこちらです。実行すると出力画面にRunning on local URL: http://127.0.0.1:7870と表示されますのでURLをクリックしてください。

import gradio as gr
from llama_index.core import SimpleDirectoryReader, VectorStoreIndex
from llama_index.llms.ollama import Ollama
from llama_index.embeddings.huggingface import HuggingFaceEmbedding

# Ollamaクライアントの初期化
llm = Ollama(model="llama2", request_timeout=30.0)

# HuggingFaceEmbeddingの初期化
embed_model = HuggingFaceEmbedding(model_name="all-MiniLM-L6-v2")

# テキストファイルの読み込みとインデックス作成
reader = SimpleDirectoryReader(input_files=["./data/document.pdf"])
data = reader.load_data()
index = VectorStoreIndex.from_documents(data, embed_model=embed_model)

# クエリエンジンの初期化
query_engine = index.as_query_engine(llm=llm, streaming=True, similarity_top_k=3)

def add_text(history, text):
    history = history + [(text, None)]
    return history, gr.Textbox(value="", interactive=False)

def bot(history):
    query = history[-1][0]
    response = query_engine.query(query)
    
    result = []
    source_text = []
    for node in response.source_nodes:
        source_text.append(f"Text: {node.node.get_content().strip().replace('\\n', ' ')[:100]}...\nScore: {node.score:.3f}")
    
    history[-1][1] = ""
    for character in str(response):
        history[-1][1] += character
        yield history
    
    yield history + [("ソースノード:", "\n\n".join(source_text))]

with gr.Blocks() as demo:
    chatbot = gr.Chatbot([],
        bubble_full_width=False,
        avatar_images=(None, "icon_logo.png"),
    )
    with gr.Row():
        txt = gr.Textbox(
            scale=4,
            show_label=False,
            container=False
        )
    clear = gr.Button("Clear")

    txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(bot, chatbot, chatbot)
    txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
    clear.click(lambda: None, None, chatbot, queue=False)

demo.queue()
demo.launch(debug=True)

LlamaIndexとOllamaの初期化

import gradio as gr
from llama_index.core import SimpleDirectoryReader, VectorStoreIndex
from llama_index.llms.ollama import Ollama
from llama_index.embeddings.huggingface import HuggingFaceEmbedding

# Ollamaクライアントの初期化
llm = Ollama(model="llama2", request_timeout=30.0)
# HuggingFaceEmbeddingの初期化
embed_model = HuggingFaceEmbedding(model_name="all-MiniLM-L6-v2")
  • Ollamaクラスを使用して、Ollamaクライアントを初期化します。model引数で使用するモデル(ここでは"llama2")を指定し、request_timeout引数でリクエストのタイムアウト時間を設定します。
  • HuggingFaceEmbeddingクラスを使用して、HuggingFaceのEmbeddingモデルを初期化します。model_name引数で使用するモデル(ここでは"all-MiniLM-L6-v2")を指定します。

PDFファイルの読み込みとインデックス作成:

# テキストファイルの読み込みとインデックス作成
reader = SimpleDirectoryReader(input_files=["./data/document.pdf"])
data = reader.load_data()
index = VectorStoreIndex.from_documents(data, embed_model=embed_model)
  • SimpleDirectoryReaderクラスを使用して、指定したディレクトリ内のPDFファイルを読み込みます。テキストファイルでもOK。
  • 読み込んだデータをVectorStoreIndexクラスに渡し、Embeddingモデルを使用してインデックスを作成します。

クエリエンジンの初期化

# クエリエンジンの初期化
query_engine = index.as_query_engine(llm=llm, streaming=True, similarity_top_k=3)
  • index.as_query_engineメソッドを使用して、作成したインデックスからクエリエンジンを初期化します。
  • llm引数で初期化したOllamaクライアントを指定し、streaming引数でストリーミングを有効にします。similarity_top_k引数で、類似度の高い上位3つの結果を返すように設定します。

ユーザー入力の処理

def add_text(history, text):
    history = history + [(text, None)]
    return history, gr.Textbox(value="", interactive=False)
  • add_text関数を定義して、ユーザーの入力をチャットボットの履歴に追加します。
  • txt.submitイベントリスナーを使用して、ユーザーが入力を送信したときにadd_text関数を呼び出します。

チャットボットの応答生成

def bot(history):
    query = history[-1][0]
    response = query_engine.query(query)
    result = []
    source_text = []
    for node in response.source_nodes:
        source_text.append(f"Text: {node.node.get_content().strip().replace('\\n', ' ')[:100]}...\nScore: {node.score:.3f}")
    history[-1][1] = ""
    for character in str(response):
        history[-1][1] += character
        yield history
    yield history + [("ソースノード:", "\n\n".join(source_text))]
  • bot関数を定義して、ユーザーの質問に対する応答を生成します。
  • query_engine.queryメソッドを使用して、ユーザーの質問に対する応答を取得します。
  • 応答のソースノードを解析し、関連する情報を抽出します。
  • 応答を1文字ずつ生成し、チャットボットの履歴に追加します。
  • 使用されたソースノードの情報を新しいメッセージとしてチャットボットに追加します。

Gradioアプリケーションの構築

with gr.Blocks() as demo:
    chatbot = gr.Chatbot([],
        bubble_full_width=False,
        avatar_images=(None, "logo.png"),
    )
    with gr.Row():
        txt = gr.Textbox(
            scale=4,
            show_label=False,
            container=False
        )
    clear = gr.Button("Clear")
  • gr.Blocksを使用してアプリケーションの全体構成を定義します。
  • gr.Chatbotを使用してチャットボットを作成し、avatar_images引数でアバター画像を指定します。チャットボットのアイコンです。画像をこのコードと同じフォルダ内に入れると表示されました。
  • gr.Rowとgr.Textboxを使用して、ユーザー入力用のテキストボックスを作成します。
  • gr.Buttonを使用して、チャットボットをクリアするためのボタンを作成します。

イベントリスナーの設定

    txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(bot, chatbot, chatbot)
    txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
    clear.click(lambda: None, None, chatbot, queue=False)
  • txt_msg変数を使用して、txt.submitイベントとbot関数の結果をチェーンします。
  • txt_msg.thenメソッドを使用して、bot関数の実行後にテキストボックスを再度インタラクティブにします。
  • clear.clickイベントリスナーを使用して、クリアボタンがクリックされたときにチャットボットの履歴をクリアします。

アプリケーションの起動

demo.queue()
demo.launch(debug=True)
  • demo.queueメソッドを使用して、リクエストのキューイングを有効にします。
  • demo.launchメソッドを使用して、アプリケーションを起動します。debug引数をTrueに設定して、デバッグモードを有効にします。

参考記事

https://bou7254.com/posts/google-colab-web-app-gradio
https://qiita.com/DeepTama/items/1a44ddf6325c2b2cd030
https://www.gradio.app/docs/chatbot

Discussion