Zenn
📁

Streamlitの新機能:chat_input()でファイルアップロードが可能に!マルチモーダルチャットアプリの作り方

2025/03/21に公開

はじめに

こんにちは或いはこんばんは。Ryuzakiです。

先日(2025年3月4日)に行われたStreamlitのアップデートで、ついにchat_input()コンポーネントがファイルアップロードに対応しました!これにより、公式のコンポーネントだけを使って簡易的なマルチモーダルチャットアプリが作れるようになりました。

ただし、このコンポーネントには若干のクセがあり、自分で実装してみた際に少し苦戦しました。せっかくの機会なので、新しくなったchat_input()の使い方を説明するとともに、これを用いた簡単なマルチモーダルチャットアプリの作り方をゼロから解説していきたいと思います。

前提条件

まずは今回の検証環境について説明します。OSやPythonのバージョンによる差異はないと思いますが、念のため記載しておきます。

また、以下の準備が済んでいることを前提とします。

  • Poetryがインストール済み
  • OpenAIのAPIキーを取得済み

セットアップ

それでは早速、プロジェクトのセットアップを行っていきましょう。

1. Poetryプロジェクトの初期化

poetry init  # 初期設定はすべて適当でOK

2. 各種ライブラリをインストール

poetry add streamlit openai tiktoken
検証環境のライブラリバージョン
  • streamlit: 1.43.2
  • openai: 1.66.3
  • tiktoken: 0.9.0

3. 環境変数の設定

.envファイルを作成し、OpenAIのAPIキーを設定します。

.env
OPENAI_API_KEY="sk-proj-..."

.envファイルが自動的に読み込まれるように設定します。

poetry self add poetry-dotenv-plugin

アプリケーション概要

今回作成するアプリケーションの概要を説明します。

ディレクトリ構成

streamlit-multimodal-sample/
├── app.py                      # Streamlitアプリケーションのメインコード
├── pyproject.toml              # Poetryの設定ファイル(依存関係の定義)
├── poetry.lock                 # 依存関係の正確なバージョンを固定するlockファイル
├── .env                        # 環境変数を定義するファイル(APIキーなど)
├── .gitignore                  # Git管理から除外するファイル指定
└── README.md                   # プロジェクトの説明

クラス図

完成形のソースコードはGitHubで公開していますので、詳細はそちらをご覧ください。

https://github.com/RyuzakiShinji/streamlit-multimodal-sample

実装手順

それでは、実際の実装手順を説明していきます。

1. セッション管理(チャット履歴の保存処理)実装

まずは、チャット履歴を保存するためのセッション管理機能を実装します。

1-1. セッションの初期化

チャット履歴を保管するセッションキーであるchat_messagesを初期化する関数(initialize_session())を実装します。

セッション初期化の実装コード
app.py
import streamlit as st

class SessionManager:
    @staticmethod
    def initialize_session() -> None:
        if "chat_messages" not in st.session_state:
            st.session_state.chat_messages = []

1-2. チャット履歴の追加

チャット履歴の追加を行う関数(add_message())を実装します。

チャット履歴追加の実装コード
app.py
import datetime
from typing import Any, Dict, List, Optional, Tuple
import uuid

# Type definitions for better code clarity
ChatMessage = Dict[str, Any]

class SessionManager:    
    @staticmethod
    def add_message(role: str, content: str, files: Optional[List[Any]] = None) -> None:
        message_id = str(uuid.uuid4())
        timestamp = datetime.datetime.now().isoformat()

        chat_message: ChatMessage = {
            "id": message_id,
            "role": role,
            "content": content,
            "timestamp": timestamp,
        }

        if files:
            chat_message["uploaded_filenames"] = [file.name for file in files]

        st.session_state.chat_messages.append(chat_message)

ファイルの保存方法に関しては要件に合わせて設定します。ここでは、シンプルにアップロードされたファイル名のみを保持しています。

1-3. チャット履歴の取得

チャット履歴の取得を行う関数(get_chat_history())を実装します。

チャット履歴取得の実装コード
app.py
class SessionManager: 
    @staticmethod
    def get_chat_history() -> List[ChatMessage]:
        return st.session_state.chat_messages

2. アップロードされたファイルの処理実装

次に、アップロードされたファイルを処理する機能を実装します。

2-1. ユーザー入力とファイルの取得

chat_input()からユーザーが入力した値とアップロードしたファイルを取得する関数(process_user_input())を実装します。

ユーザー入力処理の実装コード
app.py
from streamlit.elements.widgets.chat import ChatInputValue

class InputHandler:
    @staticmethod
    def process_user_input(user_input: Any) -> Tuple[str, List[Any]]:
        if isinstance(user_input, str):
            return user_input, []
        elif isinstance(user_input, ChatInputValue):
            return user_input.text, user_input.files or []
        else:
            raise ValueError(f"Unexpected input type: {type(user_input)}")

2-2. 画像のエンコード

アップロードされたファイルをそれぞれbase64形式に変換する関数(encode_images())を実装します。

画像エンコードの実装コード
app.py
import base64
import logging

# Configure logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

# Type definitions for better code clarity
ImageData = Dict[str, str]

class InputHandler:
    @staticmethod
    def encode_images(files: List[Any]) -> List[ImageData]:
        if not files:
            return []

        image_data = []
        for file in files:
            try:
                encoded_data = base64.b64encode(file.read()).decode("utf-8")
                file_type = file.type.split("/")[-1]  # Extract format from MIME type
                image_data.append({"type": file_type, "data": encoded_data})
            except Exception as e:
                logger.error(f"Error processing file {file.name}: {str(e)}")
                # Continue processing other files even if one fails

        return image_data

3. マルチモーダルのリクエスト処理実装

マルチモーダルのリクエスト処理を実装します。

3-1. メッセージの構築

OpenAI APIへ渡すメッセージの構築を行う関数(generate_user_message())を実装します。

メッセージ構築の実装コード
app.py
import tiktoken
from tiktoken.core import Encoding

# Type definitions for better code clarity
ContentItem = Dict[str, Any]

class TokenManager:
    def __init__(self, tokenizer: Encoding, max_tokens: int) -> None:
        self.tokenizer = tokenizer
        self.max_tokens = max_tokens

    def count_tokens(self, text: str) -> int:
        return len(self.tokenizer.encode(text))

    def format_chat_history(self, chat_messages: List[ChatMessage]) -> str:
        token_count: int = 0
        formatted_history: str = ""

        # Process messages from newest to oldest to prioritize recent context
        for message in reversed(chat_messages):
            message_text = f"{message['role']}: {message['content']}\n"
            message_tokens = self.count_tokens(message_text)

            # Check if adding this message would exceed the token limit
            if token_count + message_tokens > self.max_tokens:
                break

            # Add message to history and update token count
            token_count += message_tokens
            formatted_history = message_text + formatted_history

        return formatted_history

class PromptGenerator:
    def __init__(self, token_manager: TokenManager) -> None:
        self.token_manager = token_manager

    def generate_enhanced_prompt(self, prompt: str, chat_history: Optional[List[ChatMessage]] = None) -> str:
        if not chat_history:
            return prompt

        formatted_history = self.token_manager.format_chat_history(chat_history)
        enhanced_prompt = f"""
        Generate a response for the user considering the prompt and the conversation history.

        Chat history:
        {formatted_history}

        User's prompt:
        {prompt}
        """
        return enhanced_prompt

    def sanitize_prompt(self, prompt: str) -> str:
        # This is a basic implementation - a production system would need more robust checks
        dangerous_patterns = [
            "ignore previous instructions",
            "ignore all previous prompts",
            "disregard your instructions",
        ]

        sanitized_prompt = prompt
        for pattern in dangerous_patterns:
            if pattern.lower() in prompt.lower():
                logger.warning(f"Potentially dangerous prompt detected: {pattern}")
                sanitized_prompt = sanitized_prompt.replace(pattern, "[FILTERED]")

        return sanitized_prompt

class MessageProcessor:
    def __init__(self, token_manager: TokenManager) -> None:
        self.token_manager = token_manager
        self.prompt_generator = PromptGenerator(token_manager)

    def generate_user_message(
        self, prompt: str, chat_history: Optional[List[ChatMessage]] = None, images: Optional[List[ImageData]] = None
    ) -> Dict[str, Any]:
        content: List[ContentItem] = []

        # Add images if provided
        if images:
            for image in images:
                content.append(
                    {"type": "image_url", "image_url": {"url": f"data:image/{image['type']};base64,{image['data']}"}}
                )

        # Generate enhanced prompt with context and sanitization
        enhanced_prompt = self.prompt_generator.generate_enhanced_prompt(prompt, chat_history)
        sanitized_prompt = self.prompt_generator.sanitize_prompt(enhanced_prompt)

        # Add text content as the first item
        content.insert(0, {"type": "text", "text": sanitized_prompt})

        return {"role": "user", "content": content}

3-2. OpenAI APIのリクエスト処理

OpenAI APIのリクエスト処理を行う関数(generate_response())を実装します。

OpenAI API リクエスト処理の実装コード
app.py
# Constants
MODEL_NAME = "gpt-4o-mini"

class OpenAIClient:
    def __init__(self) -> None:
        self.client = OpenAI()

    def generate_response(self, user_message: Dict[str, Any]) -> str:
        try:
            response = self.client.chat.completions.create(model=MODEL_NAME, messages=[user_message])
            return response.choices[0].message.content
        except Exception as e:
            logger.error(f"Error generating response from OpenAI: {str(e)}")
            raise Exception(f"Failed to get response from AI model: {str(e)}")

4. UI実装

最後に、UIを実装していきます。

4-1. ページ情報の設定

ページ情報の設定を行う関数(setup_page())を実装します。

ページ設定の実装コード
app.py
# Constants
MODEL_MAX_INPUT_TOKEN = 128000
TOKENIZER = tiktoken.encoding_for_model("gpt-4o")

class ChatUI:
    def __init__(self) -> None:
        self.openai_client = OpenAIClient()
        token_manager = TokenManager(TOKENIZER, MODEL_MAX_INPUT_TOKEN)
        self.message_processor = MessageProcessor(token_manager)
        self.file_handler = InputHandler()

    def setup_page(self) -> None:
        st.set_page_config(page_title="Multimodal Chat Application", page_icon="💬", layout="centered")
        st.title("Multimodal Chat Application")
        st.subheader("Chat with AI using text and images")

4-2. チャット履歴の表示

チャット履歴の表示を行う関数(display_chat_history())を実装します。

チャット履歴表示の実装コード
app.py
class ChatUI:
    def display_chat_history(self) -> None:
        for message in SessionManager.get_chat_history():
            with st.chat_message(message["role"]):
                st.markdown(message["content"])

                # Display attached files if present
                if message["role"] == "user" and message.get("uploaded_filenames"):
                    with st.expander("Attached Files", expanded=False):
                        for filename in message["uploaded_filenames"]:
                            st.caption(filename)

ユーザーがアップロードしたファイル情報の表示方法は要件に合わせて設定します。ここではシンプルにアップロードされたファイル名のみをst.expander()を用いて表示しています。

4-3. ユーザー入力処理

ユーザーの入力欄の表示、および、LLMへのリクエストを行う関数(handle_user_input())を実装します。

ユーザー入力処理の実装コード
app.py
# Constants
ALLOWED_FILE_TYPES = ["jpg", "jpeg", "png"]

class ChatUI:
    def handle_user_input(self) -> None:
        user_input = st.chat_input(
            "Type your message here...",
            accept_file="multiple",
            file_type=ALLOWED_FILE_TYPES,
        )

        if not user_input:
            return

        try:
            # Process the user input
            prompt, files = self.file_handler.process_user_input(user_input)

            # Extract and encode images
            images = self.file_handler.encode_images(files)

            # Display user message
            with st.chat_message("user"):
                st.markdown(prompt)
                if files:
                    with st.expander("Attached Files", expanded=False):
                        for file in files:
                            st.caption(file.name)

            # Generate and display AI response
            with st.spinner("Generating response..."):
                try:
                    # Prepare message for the API
                    user_message = self.message_processor.generate_user_message(
                        prompt, SessionManager.get_chat_history(), images
                    )

                    # Get response from OpenAI
                    response_text = self.openai_client.generate_response(user_message)

                    # Display the response
                    with st.chat_message("assistant"):
                        st.markdown(response_text)

                    # Update chat history
                    SessionManager.add_message("user", prompt, files)
                    SessionManager.add_message("assistant", response_text)

                except Exception as e:
                    st.error(f"Error: {str(e)}")
                    logger.error(f"Error in AI response generation: {str(e)}")

        except Exception as e:
            st.error(f"An error occurred: {str(e)}")
            logger.error(f"Error in input processing: {str(e)}")

4-4. メイン関数

上記の関数をmain()で順々に呼び出します。

メイン関数の実装コード
app.py
def main() -> None:
    try:
        # Initialize session state
        SessionManager.initialize_session()

        # Set up and run the chat UI
        chat_ui = ChatUI()
        chat_ui.setup_page()
        chat_ui.display_chat_history()
        chat_ui.handle_user_input()

    except Exception as e:
        st.error(f"Application error: {str(e)}")
        logger.critical(f"Critical application error: {str(e)}")

if __name__ == "__main__":
    main()

おわりに

今回は、先日のStreamlitアップデートで追加されたchat_input()のファイルアップロード機能を使って、マルチモーダルチャットアプリの実装方法を解説しました。

コードの全体像を見たい方は、以下のGitHubリポジトリをご覧ください。このサンプルがマルチモーダルアプリケーションを開発する際に、少しでも参考になれば嬉しいです。

https://github.com/RyuzakiShinji/streamlit-multimodal-sample

ここまでお読みいただき、ありがとうございました。

参考リソース

Discussion

ログインするとコメントできます