😗

【Python】Ollama 向けにAI Chat UIを作成してみた

2024/08/17に公開

はじめに

ローカルLLMでは主にOllamaを使っていますが、gemma2などのモデルを試す時に自作のチャットUIがあったらいいなと思って、AIと一緒に作成してみました。このアプリはPythonでLangChainを使用し、UIはcustomtkinterを使っています。

モデル選択はollama listで取得したものが選択できます。

環境

  • Mac mini M2 16GB
  • python 3.11.6
  • LangChain 0.2.14

インストール

venv で環境作成してからインストール

pip install langchain langchain-community langchain-core
pip install customtkinter

コード全文

import asyncio
import subprocess
import threading
import uuid

import customtkinter as ctk
from langchain_community.chat_models import ChatOllama
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_community.chat_message_histories import ChatMessageHistory


class ChatApp(ctk.CTk):
    def __init__(self):
        super().__init__()

        self.title("AI Chat")
        self.geometry("600x600")

        self.model_list = self.get_ollama_models()

        if not self.model_list:
            self.show_warning(
                "No Ollama models found. Please install Ollama and download at least one model."
            )
            self.current_model = None
            self.model = None
        else:
            self.current_model = self.model_list[0]
            self.model = ChatOllama(model=self.current_model, streaming=True)

        self.prompt = ChatPromptTemplate.from_messages(
            [
                SystemMessage(content="あなたは素敵なアシスタントです"),
                MessagesPlaceholder(variable_name="history"),
                ("human", "{input}"),
            ]
        )

        if self.model:
            self.chain = self.prompt | self.model | StrOutputParser()
        else:
            self.chain = None

        self.initialize_session()

        self.create_widgets()

        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)

        self.thread = threading.Thread(target=self.run_event_loop, daemon=True)
        self.thread.start()

    def initialize_session(self):
        self.session_id = str(uuid.uuid4())
        self.chat_history = ChatMessageHistory()

    def get_ollama_models(self):
        try:
            result = subprocess.run(["ollama", "list"], capture_output=True, text=True)
            models = [line.split()[0] for line in result.stdout.splitlines()[1:]]
            return models
        except Exception:
            return []

    def create_widgets(self):
        top_frame = ctk.CTkFrame(self)
        top_frame.pack(fill="x", padx=10, pady=10)

        if self.model_list:
            self.model_var = ctk.StringVar(value=self.current_model)
            model_menu = ctk.CTkOptionMenu(
                top_frame,
                values=self.model_list,
                variable=self.model_var,
                command=self.change_model,
            )
            model_menu.pack(side="left")

        clear_button = ctk.CTkButton(
            top_frame, text="Clear History", command=self.clear_history
        )
        clear_button.pack(side="right")

        self.chat_display = ctk.CTkTextbox(self, state="disabled", wrap="word")
        self.chat_display.pack(expand=True, fill="both", padx=10, pady=10)

        input_frame = ctk.CTkFrame(self)
        input_frame.pack(fill="x", padx=10, pady=10)

        self.input_field = ctk.CTkEntry(input_frame)
        self.input_field.pack(side="left", expand=True, fill="x")

        send_button = ctk.CTkButton(input_frame, text="Send", command=self.send_message)
        send_button.pack(side="right", padx=(10, 0))

        self.input_field.bind("<Return>", lambda event: self.send_message())

    def change_model(self, selected_model):
        self.current_model = selected_model
        self.model = ChatOllama(model=self.current_model, streaming=True)
        self.chain = self.prompt | self.model | StrOutputParser()

    def clear_history(self):
        self.initialize_session()
        self.chat_display.configure(state="normal")
        self.chat_display.delete("1.0", "end")
        self.chat_display.configure(state="disabled")

    def send_message(self):
        user_input = self.input_field.get()
        if user_input:
            if not self.model:
                self.show_warning(
                    "No Ollama model available. Please install Ollama and download at least one model."
                )
                return

            self.input_field.delete(0, "end")
            self.display_message("You: " + user_input + "\n", new_message=True)
            self.chat_history.add_message(HumanMessage(content=user_input))

            asyncio.run_coroutine_threadsafe(
                self.get_ai_response(user_input), self.loop
            )

    async def get_ai_response(self, user_input):
        self.display_message("AI: ", new_message=True, end="")
        full_response = ""
        buffer = ""
        async for chunk in self.chain.astream(
            {"input": user_input, "history": self.chat_history.messages}
        ):
            full_response += chunk
            buffer += chunk
            if chunk.strip():
                self.display_message(buffer, new_message=False, end="")
                buffer = ""
                await asyncio.sleep(0.01)

        if buffer:
            self.display_message(buffer, new_message=False, end="")

        self.remove_last_newline()

        self.chat_history.add_message(AIMessage(content=full_response.strip()))

    def display_message(self, message, new_message=False, end=""):
        self.chat_display.configure(state="normal")
        if new_message and self.chat_display.index("end-1c") != "1.0":
            current_text = self.chat_display.get("1.0", "end-1c")
            if not current_text.endswith("\n"):
                self.chat_display.insert("end", "\n")

        self.chat_display.insert("end", message + end)
        self.chat_display.configure(state="disabled")
        self.chat_display.see("end")
        self.update()

    def remove_last_newline(self):
        self.chat_display.configure(state="normal")
        if self.chat_display.get("end-2c", "end-1c") == "\n":
            self.chat_display.delete("end-2c", "end-1c")
        self.chat_display.configure(state="disabled")

    def show_warning(self, message):
        ctk.CTkMessagebox(title="Warning", message=message, icon="warning")

    def run_event_loop(self):
        asyncio.set_event_loop(self.loop)
        self.loop.run_forever()

    def on_closing(self):
        self.loop.call_soon_threadsafe(self.loop.stop)
        self.thread.join()
        self.destroy()


if __name__ == "__main__":
    app = ChatApp()
    app.protocol("WM_DELETE_WINDOW", app.on_closing)
    app.mainloop()

例えばchat_app.pyとして保存し、
python chat_app.pyすると、アプリが起動します。

※ 以下、コード解説はAIに書いてもらったものをベースにしています。

1. 概要

このアプリケーションは以下の主要な機能を持っています:

  • Ollama モデルを使用したチャット機能
  • カスタム Tkinter を使用した GUI インターフェース
  • 非同期処理によるレスポンシブなユーザーエクスペリエンス
  • チャット履歴の管理
  • 複数の Ollama モデル間の切り替え機能

2. 主要なライブラリとフレームワーク

import asyncio
import subprocess
import threading
import uuid

import customtkinter as ctk
from langchain_community.chat_models import ChatOllama
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_community.chat_message_histories import ChatMessageHistory
  • asynciothreading: 非同期処理とマルチスレッディングのサポート
  • uuid: ユニークなセッション ID の生成に使用
  • customtkinter: 現代的なルックアンドフィールを持つ GUI を作成するために使用
  • langchain: Ollama モデルとのインタラクション、プロンプト管理、チャット履歴の処理に使用

3. アプリケーションの構造

アプリケーションはChatAppクラスとして実装されています。主要なコンポーネントは以下の通りです:

3.1 初期化とセットアップ

class ChatApp(ctk.CTk):
    def __init__(self):
        super().__init__()
        # ... (初期化コード)
  • ウィンドウのタイトルとサイズの設定
  • 利用可能な Ollama モデルのリストの取得
  • LangChain を使用したチャットモデルとプロンプトテンプレートの設定
  • GUI ウィジェットの作成
  • 非同期イベントループのセットアップ

3.2 Ollama モデルの管理

def get_ollama_models(self):
    try:
        result = subprocess.run(["ollama", "list"], capture_output=True, text=True)
        models = [line.split()[0] for line in result.stdout.splitlines()[1:]]
        return models
    except Exception:
        return []

この関数はollama listコマンドを実行して、利用可能な Ollama モデルのリストを取得します。

3.3 GUI の構築

def create_widgets(self):
    # ... (GUIウィジェットの作成コード)

この関数は以下のような GUI 要素を作成します:

  • モデル選択ドロップダウン
  • チャット履歴クリアボタン
  • チャット表示エリア
  • メッセージ入力フィールドと送信ボタン

3.4 チャット機能の実装

def send_message(self):
    # ... (メッセージ送信の処理)

async def get_ai_response(self, user_input):
    # ... (AI応答の非同期処理)

これらの関数は、ユーザー入力の処理、AI モデルからの応答の取得、そしてそれらの表示を担当します。get_ai_response関数は非同期で動作し、モデルからのストリーミングレスポンスを処理します。

3.5 非同期処理の管理

def run_event_loop(self):
    asyncio.set_event_loop(self.loop)
    self.loop.run_forever()

この関数は別スレッドで非同期イベントループを実行し、GUI のレスポンシブ性を維持します。

4. 主要な機能の詳細

4.1 モデルの切り替え

def change_model(self, selected_model):
    self.current_model = selected_model
    self.model = ChatOllama(model=self.current_model, streaming=True)
    self.chain = self.prompt | self.model | StrOutputParser()

ユーザーが異なるモデルを選択した際に、この関数が呼び出されます。新しいモデルで ChatOllama インスタンスを再作成し、処理チェーンを更新します。

4.2 チャット履歴の管理

def initialize_session(self):
    self.session_id = str(uuid.uuid4())
    self.chat_history = ChatMessageHistory()

def clear_history(self):
    self.initialize_session()
    self.chat_display.configure(state="normal")
    self.chat_display.delete("1.0", "end")
    self.chat_display.configure(state="disabled")

これらの関数は、新しいセッションの開始とチャット履歴のクリアを管理します。

4.3 ストリーミングレスポンスの処理

async for chunk in self.chain.astream(
    {"input": user_input, "history": self.chat_history.messages}
):
    full_response += chunk
    buffer += chunk
    if chunk.strip():
        self.display_message(buffer, new_message=False, end="")
        buffer = ""
        await asyncio.sleep(0.01)

この部分は、AI モデルからのストリーミングレスポンスを処理し、リアルタイムでユーザーに表示します。

まとめ

とりあえずのチャットUIをOllama用に自作するなら最低限として、

  • モデル選択
  • チャット履歴管理
  • ストリーミング回答

は欲しいなと思ってAIと一緒に作ってみました。
さらに機能を追加していくなら、SQLiteでデータ管理したり、Web検索なども取り入れていくと良いかもですね。

Discussion