❄️

Streamlit in Snowflake (SiS) で画像検索を実現しよう Part2 -画像のキャプションを生成する-

2024/09/23に公開

はじめに

本記事は Part2 となります。以下 Part1 の続きの記事となりますので、まだお読みでない方は最初に以下をご覧ください。

https://zenn.dev/tsubasa_tech/articles/1e6dd562777481

Part1 では Streamlit in Snowflake でアプリのデフォルト内部ステージに格納した画像を表示する画像ギャラリーアプリを作成しました。Part2 では画像ギャラリーアプリをベースに、各画像のキャプションを生成し非構造化データを活用しやすくしていきます。

機能概要

実現したいこと

  • (済) Streamlit in Snowflake で画像データを表示する
  • ★Streamlit in Snowflake で画像に説明文を追加する
  • 画像の説明文を元にベクトルデータを生成する
  • Streamlit in Snowflake で画像検索を行う

★:Part2で実現する範囲

Part2 で実装する機能一覧

  • 画像のキャプションを手動で作成したり編集する機能
  • 画像のキャプションを自動で個別に作成する機能
  • キャプションが無い画像について一括でキャプションを生成する機能

Part2 の完成イメージ


Part1で作成したイメージギャラリー部分


画像のキャプションを編集する画面


画像のキャプションを自動生成する画面


生成 AI が作成したキャプション


キャプションが無い画像に一括でキャプションを生成する機能

前提条件

  • Snowflake
    • Snowflake アカウントがあること
    • Streamlit in Snowflake のインストールパッケージ
      • boto3 1.28.64
  • AWS
    • Amazon Bedrock が利用できる AWS アカウント (今回の手順では Claude 3.5 Sonnet を利用するため us-east-1 リージョンを利用しています)

手順

(省略) Streamlit in Snowflake のアプリを作成し画像をアップロードする

まだ実施していない場合はまず Part1 の記事を実施してください。

(省略) Streamlit in Snowflake のアプリから Amazon Bedrock にアクセスできるようにする

画像に自動的にキャプションを入れるには以下に挙げるように多くの選択肢が考えられます。

  1. 画像処理系の Python ライブラリで処理を実装する
  2. 画像認識系の ML モデルを作成し画像にキャプションを生成する
  3. BLIP-2 などの既存の 画像用 AI モデルを使って画像にキャプションを生成する
  4. マルチモーダル系の GenAI に画像を渡して画像のキャプションを生成する
  5. キャプション生成の SaaS サービスを利用する

今回は以前の記事で Amazon Bedrock に接続する方法を紹介しているため、4のマルチモーダル系の GenAI として Amazon Bedrock の anthropic.claude-3-5-sonnet を使って画像のキャプションを生成してみたいと思います。

以下記事の手順『Streamlit オブジェクトに外部アクセス統合を紐付ける』までを参考に実施してください。

https://zenn.dev/tsubasa_tech/articles/ea53b5e37705cb

Streamlit in Snowflake のアプリを実行

コードが少々長くなってしまいましたが Streamlit in Snowflake アプリの編集画面で以下コードをコピー&ペーストで貼り付けて完了です。

import streamlit as st
import pandas as pd
import os
import base64
import boto3
import json
from snowflake.snowpark.context import get_active_session
from snowflake.snowpark.functions import col, when_matched, when_not_matched
import _snowflake
from PIL import Image
import io

# Streamlitのページ設定
st.set_page_config(layout="wide", page_title="Image Gallery")

# 画像フォルダのパス
IMAGE_FOLDER = "image"

# Snowflakeセッションの取得
session = get_active_session()

# テーブルの作成(初回起動時のみ)
@st.cache_resource
def create_table_if_not_exists():
    session.sql("""
    CREATE TABLE IF NOT EXISTS IMAGE_METADATA (
        FILE_NAME STRING,
        DESCRIPTION STRING,
        VECTOR VECTOR(FLOAT, 1024)
    )
    """).collect()

create_table_if_not_exists()

# AWSの認証情報を取得する関数
def get_aws_credentials():
    aws_key_object = _snowflake.get_username_password('bedrock_key')
    region = 'us-east-1'
    return {
        'aws_access_key_id': aws_key_object.username,
        'aws_secret_access_key': aws_key_object.password,
        'region_name': region
    }, region

# Bedrockクライアントの設定
boto3_session_args, region = get_aws_credentials()
boto3_session = boto3.Session(**boto3_session_args)
bedrock = boto3_session.client('bedrock-runtime', region_name=region)

# 画像データの取得
@st.cache_data
def get_image_data():
    image_files = [f for f in os.listdir(IMAGE_FOLDER) if f.lower().endswith(('.png', '.jpg', '.jpeg', '.gif'))]
    return [{"FILE_NAME": f, "IMG_PATH": os.path.join(IMAGE_FOLDER, f)} for f in image_files]

# メタデータの取得
@st.cache_data
def get_metadata():
    return session.table("IMAGE_METADATA").select("FILE_NAME", "DESCRIPTION").to_pandas()

# 画像をサムネイルに変換してbase64エンコード
@st.cache_data
def get_thumbnail_base64(img_path, max_size=(300, 300)):
    with Image.open(img_path) as img:
        img.thumbnail(max_size)
        buffered = io.BytesIO()
        img.save(buffered, format="JPEG")
        return base64.b64encode(buffered.getvalue()).decode('utf-8')

# 画像データとメタデータの初期化
if 'img_df' not in st.session_state:
    st.session_state.img_df = get_image_data()
if 'metadata_df' not in st.session_state:
    st.session_state.metadata_df = get_metadata()

# 画像ギャラリーの表示
def show_image_gallery():
    st.title("Image Gallery")
    num_columns = st.slider("Width:", min_value=1, max_value=5, value=4)
    cols = st.columns(num_columns)
    for i, img in enumerate(st.session_state.img_df):
        with cols[i % num_columns]:
            st.image(img["IMG_PATH"], caption=None, use_column_width=True)

# 画像の説明編集
def edit_image_descriptions():
    st.title("Edit Image Descriptions")
    st.session_state.metadata_df = get_metadata()

    # 新しい画像をメタデータに追加
    for img in st.session_state.img_df:
        if img["FILE_NAME"] not in st.session_state.metadata_df["FILE_NAME"].values:
            new_row = pd.DataFrame({"FILE_NAME": [img["FILE_NAME"]], "DESCRIPTION": [""]})
            st.session_state.metadata_df = pd.concat([st.session_state.metadata_df, new_row], ignore_index=True)

    merged_df = pd.merge(st.session_state.metadata_df, pd.DataFrame(st.session_state.img_df), on="FILE_NAME", how="left")

    with st.form("edit_descriptions"):
        for _, row in merged_df.iterrows():
            col1, col2 = st.columns([1, 3])
            with col1:
                st.image(row["IMG_PATH"], width=100)
            with col2:
                new_description = st.text_input(f"Description for {row['FILE_NAME']}", value=row["DESCRIPTION"], key=row['FILE_NAME'])
                merged_df.loc[merged_df["FILE_NAME"] == row["FILE_NAME"], "DESCRIPTION"] = new_description

        submit_button = st.form_submit_button("Save Changes")

    if submit_button:
        update_snowflake_table(merged_df[['FILE_NAME', 'DESCRIPTION']])
        st.success("Changes saved successfully!")
        st.cache_data.clear()
        st.session_state.metadata_df = get_metadata()

# 画像の説明を生成する関数
def generate_description(image_path):
    image_base64 = get_thumbnail_base64(image_path)
    prompt = """
    この画像について日本語で400文字以内で、一行で説明してください。
    返事は不要で出力される内容は画像の説明文だけにしてください。
    """
    
    request_body = {
        "anthropic_version": "bedrock-2023-05-31",
        "max_tokens": 200000,
        "messages": [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "source": {
                            "type": "base64",
                            "media_type": "image/jpeg",
                            "data": image_base64
                        }
                    },
                    {
                        "type": "text",
                        "text": prompt
                    }
                ]
            }
        ]
    }
    
    response = bedrock.invoke_model(
        body=json.dumps(request_body),
        modelId="anthropic.claude-3-5-sonnet-20240620-v1:0",
        accept='application/json',
        contentType='application/json'
    )
    
    response_body = json.loads(response.get('body').read())
    return response_body["content"][0]["text"]

# Snowflakeテーブルを更新する関数
def update_snowflake_table(update_df):
    snow_df = session.create_dataframe(update_df)
    
    session.table("IMAGE_METADATA").merge(
        snow_df,
        (session.table("IMAGE_METADATA").FILE_NAME == snow_df.FILE_NAME),
        [
            when_matched().update({
                "DESCRIPTION": snow_df.DESCRIPTION
            }),
            when_not_matched().insert({
                "FILE_NAME": snow_df.FILE_NAME,
                "DESCRIPTION": snow_df.DESCRIPTION
            })
        ]
    )

# 画像の説明生成
def generate_image_descriptions():
    st.title("画像の説明を生成")

    if 'generated_description' not in st.session_state:
        st.session_state.generated_description = None
    if 'selected_image' not in st.session_state:
        st.session_state.selected_image = None

    # 個別の画像に対する説明生成
    with st.form("generate_description"):
        selected_image = st.selectbox("画像を選択してください", options=[img["FILE_NAME"] for img in st.session_state.img_df])
        generate_button = st.form_submit_button("説明を生成")

    if generate_button:
        image_info = next(img for img in st.session_state.img_df if img['FILE_NAME'] == selected_image)
        generated_description = generate_description(image_info['IMG_PATH'])

        st.session_state.generated_description = generated_description
        st.session_state.selected_image = selected_image

        st.image(image_info['IMG_PATH'], width=300)
        st.write("生成された説明:")
        st.write(generated_description)

    if st.session_state.generated_description is not None:
        if st.button("説明を保存"):
            update_snowflake_table(pd.DataFrame({'FILE_NAME': [st.session_state.selected_image], 'DESCRIPTION': [st.session_state.generated_description]}))
            st.success("説明が正常に保存されました!")
            st.cache_data.clear()
            st.session_state.metadata_df = get_metadata()
            
            st.session_state.generated_description = None
            st.session_state.selected_image = None

    # 説明がない画像をまとめて処理
    st.subheader("説明が未設定の画像の一括処理")
    
    images_without_description = [
        img for img in st.session_state.img_df 
        if img["FILE_NAME"] not in st.session_state.metadata_df[
            st.session_state.metadata_df["DESCRIPTION"].notna() & 
            (st.session_state.metadata_df["DESCRIPTION"] != "")
        ]["FILE_NAME"].values
    ]
    
    if images_without_description:
        st.write(f"{len(images_without_description)}枚の画像に説明が設定されていません。")
        if st.button("一括で説明を生成"):
            progress_bar = st.progress(0)
            for i, img in enumerate(images_without_description):
                generated_description = generate_description(img['IMG_PATH'])
                update_snowflake_table(pd.DataFrame({'FILE_NAME': [img['FILE_NAME']], 'DESCRIPTION': [generated_description]}))
                progress_bar.progress((i + 1) / len(images_without_description))

            st.success("すべての画像の説明が生成され、保存されました!")
            st.cache_data.clear()
            st.session_state.metadata_df = get_metadata()
    else:
        st.write("すべての画像に説明が設定されています。")

    # デバッグ情報の表示
    st.subheader("メタデータ情報")
    st.write(st.session_state.metadata_df)

# メインのアプリケーション実行部分
if __name__ == "__main__":
    page = st.sidebar.selectbox(
        "Choose a page", 
        ["Image Gallery", "Edit Descriptions", "Generate Descriptions"]
    )

    if page == "Image Gallery":
        show_image_gallery()
    elif page == "Edit Descriptions":
        edit_image_descriptions()
    elif page == "Generate Descriptions":
        generate_image_descriptions()

一部コードの解説

以下の部分で画像のメタデータを保存するためのテーブルを作成しています。3カラム目の VECTOR は次回 Part3 でベクトルデータを格納する予定のカラムです。

# テーブルの作成(初回起動時のみ)
@st.cache_resource
def create_table_if_not_exists():
    session.sql("""
    CREATE TABLE IF NOT EXISTS IMAGE_METADATA (
        FILE_NAME STRING,
        DESCRIPTION STRING,
        VECTOR VECTOR(FLOAT, 1024)
    )
    """).collect()

create_table_if_not_exists()

以下の部分は Amazon Bedrock に画像データを渡す時に画像を縮小して Base64 にエンコードする機能です。なるべく画像を小さくしてキャプション生成のパフォーマンスを高める狙いと、Amazon Bedrock では恐らく直接画像のバイナリデータを渡すことができないためエンコーディングの処理を入れています。

# 画像をサムネイルに変換してbase64エンコード
@st.cache_data
def get_thumbnail_base64(img_path, max_size=(300, 300)):
    with Image.open(img_path) as img:
        img.thumbnail(max_size)
        buffered = io.BytesIO()
        img.save(buffered, format="JPEG")
        return base64.b64encode(buffered.getvalue()).decode('utf-8')

以下の部分で Amazon Bedrock にキャプション生成のプロンプトを渡しています。次の Part3 の検索精度を高めるためにこのプロンプトの調整が必要になる可能性があります。

# 画像の説明を生成する関数
def generate_description(image_path):
    image_base64 = get_thumbnail_base64(image_path)
    prompt = """
    この画像について日本語で400文字以内で、一行で説明してください。
    返事は不要で出力される内容は画像の説明文だけにしてください。
    """

最後に

前回までは単なる画像ギャラリーでしたが、今回の改良により画像のキャプションを自動的に生成できるようになりました。今回の取り組みにおけるゴールとしては画像検索を行えるようにすることではありますが、画像に紐づいたキャプションがあるとアイディア次第で画像活用の幅がグンと広がると思います。

次回 Part3 は、いよいよ画像検索の実装です。画像のキャプションを利用してベクトル化を行い、ベクトル検索できるようにしていきますので楽しみにしていてください。

X で Snowflake の What's new の配信してます

X で Snowflake の What's new の更新情報を配信しておりますので、是非お気軽にフォローしていただければ嬉しいです。

日本語版

Snowflake の What's New Bot (日本語版)
https://x.com/snow_new_jp

English Version

Snowflake What's New Bot (English Version)
https://x.com/snow_new_en

変更履歴

(20240923) 新規投稿

Discussion