Streamlitの新機能:chat_input()でファイルアップロードが可能に!マルチモーダルチャットアプリの作り方
はじめに
こんにちは或いはこんばんは。Ryuzakiです。
先日(2025年3月4日)に行われたStreamlitのアップデートで、ついにchat_input()
コンポーネントがファイルアップロードに対応しました!これにより、公式のコンポーネントだけを使って簡易的なマルチモーダルチャットアプリが作れるようになりました。
ただし、このコンポーネントには若干のクセがあり、自分で実装してみた際に少し苦戦しました。せっかくの機会なので、新しくなったchat_input()
の使い方を説明するとともに、これを用いた簡単なマルチモーダルチャットアプリの作り方をゼロから解説していきたいと思います。
前提条件
まずは今回の検証環境について説明します。OSやPythonのバージョンによる差異はないと思いますが、念のため記載しておきます。
また、以下の準備が済んでいることを前提とします。
- Poetryがインストール済み
- OpenAIのAPIキーを取得済み
セットアップ
それでは早速、プロジェクトのセットアップを行っていきましょう。
1. Poetryプロジェクトの初期化
poetry init # 初期設定はすべて適当でOK
2. 各種ライブラリをインストール
poetry add streamlit openai tiktoken
検証環境のライブラリバージョン
- streamlit: 1.43.2
- openai: 1.66.3
- tiktoken: 0.9.0
3. 環境変数の設定
.env
ファイルを作成し、OpenAIのAPIキーを設定します。
OPENAI_API_KEY="sk-proj-..."
.env
ファイルが自動的に読み込まれるように設定します。
poetry self add poetry-dotenv-plugin
アプリケーション概要
今回作成するアプリケーションの概要を説明します。
ディレクトリ構成
streamlit-multimodal-sample/
├── app.py # Streamlitアプリケーションのメインコード
├── pyproject.toml # Poetryの設定ファイル(依存関係の定義)
├── poetry.lock # 依存関係の正確なバージョンを固定するlockファイル
├── .env # 環境変数を定義するファイル(APIキーなど)
├── .gitignore # Git管理から除外するファイル指定
└── README.md # プロジェクトの説明
クラス図
完成形のソースコードはGitHubで公開していますので、詳細はそちらをご覧ください。
実装手順
それでは、実際の実装手順を説明していきます。
1. セッション管理(チャット履歴の保存処理)実装
まずは、チャット履歴を保存するためのセッション管理機能を実装します。
1-1. セッションの初期化
チャット履歴を保管するセッションキーであるchat_messages
を初期化する関数(initialize_session()
)を実装します。
セッション初期化の実装コード
import streamlit as st
class SessionManager:
@staticmethod
def initialize_session() -> None:
if "chat_messages" not in st.session_state:
st.session_state.chat_messages = []
1-2. チャット履歴の追加
チャット履歴の追加を行う関数(add_message()
)を実装します。
チャット履歴追加の実装コード
import datetime
from typing import Any, Dict, List, Optional, Tuple
import uuid
# Type definitions for better code clarity
ChatMessage = Dict[str, Any]
class SessionManager:
@staticmethod
def add_message(role: str, content: str, files: Optional[List[Any]] = None) -> None:
message_id = str(uuid.uuid4())
timestamp = datetime.datetime.now().isoformat()
chat_message: ChatMessage = {
"id": message_id,
"role": role,
"content": content,
"timestamp": timestamp,
}
if files:
chat_message["uploaded_filenames"] = [file.name for file in files]
st.session_state.chat_messages.append(chat_message)
ファイルの保存方法に関しては要件に合わせて設定します。ここでは、シンプルにアップロードされたファイル名のみを保持しています。
1-3. チャット履歴の取得
チャット履歴の取得を行う関数(get_chat_history()
)を実装します。
チャット履歴取得の実装コード
class SessionManager:
@staticmethod
def get_chat_history() -> List[ChatMessage]:
return st.session_state.chat_messages
2. アップロードされたファイルの処理実装
次に、アップロードされたファイルを処理する機能を実装します。
2-1. ユーザー入力とファイルの取得
chat_input()
からユーザーが入力した値とアップロードしたファイルを取得する関数(process_user_input()
)を実装します。
ユーザー入力処理の実装コード
from streamlit.elements.widgets.chat import ChatInputValue
class InputHandler:
@staticmethod
def process_user_input(user_input: Any) -> Tuple[str, List[Any]]:
if isinstance(user_input, str):
return user_input, []
elif isinstance(user_input, ChatInputValue):
return user_input.text, user_input.files or []
else:
raise ValueError(f"Unexpected input type: {type(user_input)}")
2-2. 画像のエンコード
アップロードされたファイルをそれぞれbase64形式に変換する関数(encode_images()
)を実装します。
画像エンコードの実装コード
import base64
import logging
# Configure logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
# Type definitions for better code clarity
ImageData = Dict[str, str]
class InputHandler:
@staticmethod
def encode_images(files: List[Any]) -> List[ImageData]:
if not files:
return []
image_data = []
for file in files:
try:
encoded_data = base64.b64encode(file.read()).decode("utf-8")
file_type = file.type.split("/")[-1] # Extract format from MIME type
image_data.append({"type": file_type, "data": encoded_data})
except Exception as e:
logger.error(f"Error processing file {file.name}: {str(e)}")
# Continue processing other files even if one fails
return image_data
3. マルチモーダルのリクエスト処理実装
マルチモーダルのリクエスト処理を実装します。
3-1. メッセージの構築
OpenAI APIへ渡すメッセージの構築を行う関数(generate_user_message()
)を実装します。
メッセージ構築の実装コード
import tiktoken
from tiktoken.core import Encoding
# Type definitions for better code clarity
ContentItem = Dict[str, Any]
class TokenManager:
def __init__(self, tokenizer: Encoding, max_tokens: int) -> None:
self.tokenizer = tokenizer
self.max_tokens = max_tokens
def count_tokens(self, text: str) -> int:
return len(self.tokenizer.encode(text))
def format_chat_history(self, chat_messages: List[ChatMessage]) -> str:
token_count: int = 0
formatted_history: str = ""
# Process messages from newest to oldest to prioritize recent context
for message in reversed(chat_messages):
message_text = f"{message['role']}: {message['content']}\n"
message_tokens = self.count_tokens(message_text)
# Check if adding this message would exceed the token limit
if token_count + message_tokens > self.max_tokens:
break
# Add message to history and update token count
token_count += message_tokens
formatted_history = message_text + formatted_history
return formatted_history
class PromptGenerator:
def __init__(self, token_manager: TokenManager) -> None:
self.token_manager = token_manager
def generate_enhanced_prompt(self, prompt: str, chat_history: Optional[List[ChatMessage]] = None) -> str:
if not chat_history:
return prompt
formatted_history = self.token_manager.format_chat_history(chat_history)
enhanced_prompt = f"""
Generate a response for the user considering the prompt and the conversation history.
Chat history:
{formatted_history}
User's prompt:
{prompt}
"""
return enhanced_prompt
def sanitize_prompt(self, prompt: str) -> str:
# This is a basic implementation - a production system would need more robust checks
dangerous_patterns = [
"ignore previous instructions",
"ignore all previous prompts",
"disregard your instructions",
]
sanitized_prompt = prompt
for pattern in dangerous_patterns:
if pattern.lower() in prompt.lower():
logger.warning(f"Potentially dangerous prompt detected: {pattern}")
sanitized_prompt = sanitized_prompt.replace(pattern, "[FILTERED]")
return sanitized_prompt
class MessageProcessor:
def __init__(self, token_manager: TokenManager) -> None:
self.token_manager = token_manager
self.prompt_generator = PromptGenerator(token_manager)
def generate_user_message(
self, prompt: str, chat_history: Optional[List[ChatMessage]] = None, images: Optional[List[ImageData]] = None
) -> Dict[str, Any]:
content: List[ContentItem] = []
# Add images if provided
if images:
for image in images:
content.append(
{"type": "image_url", "image_url": {"url": f"data:image/{image['type']};base64,{image['data']}"}}
)
# Generate enhanced prompt with context and sanitization
enhanced_prompt = self.prompt_generator.generate_enhanced_prompt(prompt, chat_history)
sanitized_prompt = self.prompt_generator.sanitize_prompt(enhanced_prompt)
# Add text content as the first item
content.insert(0, {"type": "text", "text": sanitized_prompt})
return {"role": "user", "content": content}
3-2. OpenAI APIのリクエスト処理
OpenAI APIのリクエスト処理を行う関数(generate_response()
)を実装します。
OpenAI API リクエスト処理の実装コード
# Constants
MODEL_NAME = "gpt-4o-mini"
class OpenAIClient:
def __init__(self) -> None:
self.client = OpenAI()
def generate_response(self, user_message: Dict[str, Any]) -> str:
try:
response = self.client.chat.completions.create(model=MODEL_NAME, messages=[user_message])
return response.choices[0].message.content
except Exception as e:
logger.error(f"Error generating response from OpenAI: {str(e)}")
raise Exception(f"Failed to get response from AI model: {str(e)}")
4. UI実装
最後に、UIを実装していきます。
4-1. ページ情報の設定
ページ情報の設定を行う関数(setup_page()
)を実装します。
ページ設定の実装コード
# Constants
MODEL_MAX_INPUT_TOKEN = 128000
TOKENIZER = tiktoken.encoding_for_model("gpt-4o")
class ChatUI:
def __init__(self) -> None:
self.openai_client = OpenAIClient()
token_manager = TokenManager(TOKENIZER, MODEL_MAX_INPUT_TOKEN)
self.message_processor = MessageProcessor(token_manager)
self.file_handler = InputHandler()
def setup_page(self) -> None:
st.set_page_config(page_title="Multimodal Chat Application", page_icon="💬", layout="centered")
st.title("Multimodal Chat Application")
st.subheader("Chat with AI using text and images")
4-2. チャット履歴の表示
チャット履歴の表示を行う関数(display_chat_history()
)を実装します。
チャット履歴表示の実装コード
class ChatUI:
def display_chat_history(self) -> None:
for message in SessionManager.get_chat_history():
with st.chat_message(message["role"]):
st.markdown(message["content"])
# Display attached files if present
if message["role"] == "user" and message.get("uploaded_filenames"):
with st.expander("Attached Files", expanded=False):
for filename in message["uploaded_filenames"]:
st.caption(filename)
ユーザーがアップロードしたファイル情報の表示方法は要件に合わせて設定します。ここではシンプルにアップロードされたファイル名のみをst.expander()
を用いて表示しています。
4-3. ユーザー入力処理
ユーザーの入力欄の表示、および、LLMへのリクエストを行う関数(handle_user_input()
)を実装します。
ユーザー入力処理の実装コード
# Constants
ALLOWED_FILE_TYPES = ["jpg", "jpeg", "png"]
class ChatUI:
def handle_user_input(self) -> None:
user_input = st.chat_input(
"Type your message here...",
accept_file="multiple",
file_type=ALLOWED_FILE_TYPES,
)
if not user_input:
return
try:
# Process the user input
prompt, files = self.file_handler.process_user_input(user_input)
# Extract and encode images
images = self.file_handler.encode_images(files)
# Display user message
with st.chat_message("user"):
st.markdown(prompt)
if files:
with st.expander("Attached Files", expanded=False):
for file in files:
st.caption(file.name)
# Generate and display AI response
with st.spinner("Generating response..."):
try:
# Prepare message for the API
user_message = self.message_processor.generate_user_message(
prompt, SessionManager.get_chat_history(), images
)
# Get response from OpenAI
response_text = self.openai_client.generate_response(user_message)
# Display the response
with st.chat_message("assistant"):
st.markdown(response_text)
# Update chat history
SessionManager.add_message("user", prompt, files)
SessionManager.add_message("assistant", response_text)
except Exception as e:
st.error(f"Error: {str(e)}")
logger.error(f"Error in AI response generation: {str(e)}")
except Exception as e:
st.error(f"An error occurred: {str(e)}")
logger.error(f"Error in input processing: {str(e)}")
4-4. メイン関数
上記の関数をmain()
で順々に呼び出します。
メイン関数の実装コード
def main() -> None:
try:
# Initialize session state
SessionManager.initialize_session()
# Set up and run the chat UI
chat_ui = ChatUI()
chat_ui.setup_page()
chat_ui.display_chat_history()
chat_ui.handle_user_input()
except Exception as e:
st.error(f"Application error: {str(e)}")
logger.critical(f"Critical application error: {str(e)}")
if __name__ == "__main__":
main()
おわりに
今回は、先日のStreamlitアップデートで追加されたchat_input()
のファイルアップロード機能を使って、マルチモーダルチャットアプリの実装方法を解説しました。
コードの全体像を見たい方は、以下のGitHubリポジトリをご覧ください。このサンプルがマルチモーダルアプリケーションを開発する際に、少しでも参考になれば嬉しいです。
ここまでお読みいただき、ありがとうございました。
Discussion