❄️

Streamlit in Snowflake (SiS) で画像検索を実現しよう Part3 -画像のベクトル検索機能を追加する-

2024/09/24に公開

はじめに

本記事は Part3 となります。以下 Part1 と Part2 の続きの記事となりますので、まだお読みでない方は最初に以下をご覧ください。(本記事に完成版のコードがありますので、慣れている方は Part3 だけ読んでいただいて問題ございません)

https://zenn.dev/tsubasa_tech/articles/1e6dd562777481

https://zenn.dev/tsubasa_tech/articles/b6b9928d33ae58

Part1 では Streamlit in Snowflake でアプリのデフォルト内部ステージに格納した画像を表示する画像ギャラリーアプリを作成し、Part2 では画像ギャラリーアプリをベースに、各画像のキャプションを生成する機能を追加してきました。

そしていよいよ Part3 では画像検索機能を追加しアプリケーションを完成させていきます!画像検索にはベクトル検索を用いることで曖昧なキーワードでも妥当な検索結果が出るようにしていきたいと思います。

機能概要

実現したいこと

  • (済) Streamlit in Snowflake で画像データを表示する
  • (済) Streamlit in Snowflake で画像に説明文を追加する
  • ★画像の説明文を元にベクトルデータを生成する
  • ★Streamlit in Snowflake で画像検索を行う

★:Part3 で実現する範囲

Part3 で実装する機能一覧

  • 画像キャプションからベクトルデータを生成する機能
  • 画像ギャラリーから画像の曖昧検索をする機能

Part3 の完成イメージ


Part1で作成したイメージギャラリー


Part2で作成した画像キャプションの編集機能


Part2で作成した画像キャプションの自動生成機能


Part2で作成した画像キャプションの自動生成結果


Part2で作成した画像キャプションの一括生成機能


ベクトルデータ自動生成機能


ベクトル検索を用いた画像のあいまい検索機能

前提条件

  • Snowflake
    • Snowflake アカウントがあること
    • Streamlit in Snowflake のインストールパッケージ
      • boto3 1.28.64
  • AWS
    • Amazon Bedrock が利用できる AWS アカウント (今回の手順では Claude 3.5 Sonnet を利用するため us-east-1 リージョンを利用しています)

基本の確認

Snowflake のベクトル化の選択肢

今回は画像のキャプションからベクトルデータを作ることで検索機能を実装していきます。ベクトル化は概念や実装が難しいイメージがあると思いますが、Snowflake では簡単にベクトル化とベクトル検索を実装することができます。是非今回の記事で「ベクトル検索って思ったより簡単だし使えそう!」と思っていただけれたらとても嬉しいです。

Snowflake のベクトル化の方法や性能については別途記事を書いておりますので詳細はそちらでご確認ください。

https://zenn.dev/tsubasa_tech/articles/c0a2b8793a5d1f

手順

(省略) Streamlit in Snowflake のアプリを作成し画像をアップロードする

まだ実施していない場合はまず Part1 の記事を実施してください。

(省略) Streamlit in Snowflake のアプリから Amazon Bedrock にアクセスできるようにする

画像に自動的にキャプションを入れるには以下に挙げるように多くの選択肢が考えられます。

  1. 画像処理系の Python ライブラリで処理を実装する
  2. 画像認識系の ML モデルを作成し画像にキャプションを生成する
  3. BLIP-2 などの既存の 画像用 AI モデルを使って画像にキャプションを生成する
  4. マルチモーダル系の GenAI に画像を渡して画像のキャプションを生成する
  5. キャプション生成の SaaS サービスを利用する

今回は以前の記事で Amazon Bedrock に接続する方法を紹介しているため、4のマルチモーダル系の GenAI として Amazon Bedrock の anthropic.claude-3-5-sonnet を使って画像のキャプションを生成してみたいと思います。

以下記事の手順『Streamlit オブジェクトに外部アクセス統合を紐付ける』までを参考に実施してください。

https://zenn.dev/tsubasa_tech/articles/ea53b5e37705cb

Streamlit in Snowflake のアプリを実行

Streamlit in Snowflake アプリの編集画面で以下コードをコピー&ペーストで貼り付けて完了です。

import streamlit as st
import pandas as pd
import os
import base64
import boto3
import json
from snowflake.snowpark.context import get_active_session
from snowflake.snowpark.functions import col, when_matched, when_not_matched, lit, call_udf
import _snowflake
from PIL import Image
import io

# カスタムテーマの設定
st.set_page_config(
    page_title="Image Gallery",
    layout="wide",
    initial_sidebar_state="expanded",
)

# カスタムCSSの追加
st.markdown("""
<style>
    .reportview-container {
        background: #f0f2f6;
    }
    .main .block-container {
        padding-top: 2rem;
        padding-bottom: 2rem;
        padding-left: 5rem;
        padding-right: 5rem;
    }
    .stButton>button {
        background-color: #4CAF50;
        color: white;
        padding: 10px 20px;
        border: none;
        border-radius: 5px;
        cursor: pointer;
        transition: background-color 0.3s;
    }
    .stButton>button:hover {
        background-color: #45a049;
    }
    .stTextInput>div>div>input {
        border-radius: 5px;
    }
    .stSelectbox>div>div>select {
        border-radius: 5px;
    }
    h1, h2, h3 {
        color: #2c3e50;
    }
    .stProgress > div > div > div > div {
        background-color: #4CAF50;
    }
</style>
""", unsafe_allow_html=True)

# 画像フォルダのパス
IMAGE_FOLDER = "image"

# Snowflakeセッションの取得
session = get_active_session()

# テーブルの作成(初回起動時のみ)
@st.cache_resource
def create_table_if_not_exists():
    session.sql("""
    CREATE TABLE IF NOT EXISTS IMAGE_METADATA (
        FILE_NAME STRING,
        DESCRIPTION STRING,
        VECTOR VECTOR(FLOAT, 1024)
    )
    """).collect()

create_table_if_not_exists()

# AWSの認証情報を取得する関数
def get_aws_credentials():
    aws_key_object = _snowflake.get_username_password('bedrock_key')
    region = 'us-east-1'
    return {
        'aws_access_key_id': aws_key_object.username,
        'aws_secret_access_key': aws_key_object.password,
        'region_name': region
    }, region

# Bedrockクライアントの設定
boto3_session_args, region = get_aws_credentials()
boto3_session = boto3.Session(**boto3_session_args)
bedrock = boto3_session.client('bedrock-runtime', region_name=region)

# 画像データの取得
@st.cache_data
def get_image_data():
    image_files = [f for f in os.listdir(IMAGE_FOLDER) if f.lower().endswith(('.png', '.jpg', '.jpeg', '.gif'))]
    return [{"FILE_NAME": f, "IMG_PATH": os.path.join(IMAGE_FOLDER, f)} for f in image_files]

# メタデータの取得
@st.cache_data
def get_metadata():
    return session.table("IMAGE_METADATA").select("FILE_NAME", "DESCRIPTION").to_pandas()

# 画像をサムネイルに変換してbase64エンコード
@st.cache_data
def get_thumbnail_base64(img_path, max_size=(300, 300)):
    with Image.open(img_path) as img:
        img.thumbnail(max_size)
        buffered = io.BytesIO()
        img.save(buffered, format="JPEG")
        return base64.b64encode(buffered.getvalue()).decode('utf-8')

# 画像データとメタデータの初期化
if 'img_df' not in st.session_state:
    st.session_state.img_df = get_image_data()
if 'metadata_df' not in st.session_state:
    st.session_state.metadata_df = get_metadata()

# 画像ギャラリーの表示
def show_image_gallery():
    st.title("🖼️ Image Gallery")

    # 検索ボックスの追加
    search_query = st.text_input("画像を検索 (関連度の高い上位10件が表示されます)", "")

    if search_query:
        # 検索クエリのエスケープ処理(基本的なSQLインジェクション対策)
        escaped_query = search_query.replace("'", "''")
        
        # 検索クエリのベクトル化と類似度計算
        search_results = session.sql(f"""
        WITH search_vector AS (
            SELECT SNOWFLAKE.CORTEX.EMBED_TEXT_1024('voyage-multilingual-2', '{escaped_query}') as embedding
        )
        SELECT 
            i.FILE_NAME, 
            i.DESCRIPTION, 
            VECTOR_COSINE_SIMILARITY(i.VECTOR, s.embedding) as similarity
        FROM 
            IMAGE_METADATA i, 
            search_vector s
        WHERE 
            i.VECTOR IS NOT NULL
        ORDER BY 
            similarity DESC
        LIMIT 10
        """).collect()

        # 検索結果の表示
        st.subheader("検索結果")
        for result in search_results:
            file_name = result['FILE_NAME']
            description = result['DESCRIPTION']
            similarity = result['SIMILARITY']
            
            img_path = next((img['IMG_PATH'] for img in st.session_state.img_df if img['FILE_NAME'] == file_name), None)
            if img_path:
                col1, col2 = st.columns([1, 3])
                with col1:
                    st.image(img_path, width=150)
                with col2:
                    st.write(f"ファイル名: {file_name}")
                    st.write(f"説明: {description}")
                    st.write(f"一致率: {similarity:.1%}")
                st.markdown("---")
    else:
        # 通常のギャラリー表示
        num_columns = st.slider("Width:", min_value=1, max_value=5, value=4)
        cols = st.columns(num_columns)
        for i, img in enumerate(st.session_state.img_df):
            with cols[i % num_columns]:
                st.image(img["IMG_PATH"], caption=None, use_column_width=True)

# 画像の説明編集
def edit_image_descriptions():
    st.title("✏️ 画像キャプションの編集")
    st.session_state.metadata_df = get_metadata()

    # 新しい画像をメタデータに追加
    for img in st.session_state.img_df:
        if img["FILE_NAME"] not in st.session_state.metadata_df["FILE_NAME"].values:
            new_row = pd.DataFrame({"FILE_NAME": [img["FILE_NAME"]], "DESCRIPTION": [""]})
            st.session_state.metadata_df = pd.concat([st.session_state.metadata_df, new_row], ignore_index=True)

    merged_df = pd.merge(st.session_state.metadata_df, pd.DataFrame(st.session_state.img_df), on="FILE_NAME", how="left")

    with st.form("edit_descriptions"):
        for _, row in merged_df.iterrows():
            col1, col2 = st.columns([1, 3])
            with col1:
                st.image(row["IMG_PATH"], width=100)
            with col2:
                new_description = st.text_input(f"ファイル名: {row['FILE_NAME']}", value=row["DESCRIPTION"], key=row['FILE_NAME'])
                merged_df.loc[merged_df["FILE_NAME"] == row["FILE_NAME"], "DESCRIPTION"] = new_description

        submit_button = st.form_submit_button("変更を保存")

    if submit_button:
        update_snowflake_table(merged_df[['FILE_NAME', 'DESCRIPTION']])
        st.success("Changes saved successfully!")
        st.cache_data.clear()
        st.session_state.metadata_df = get_metadata()

# 画像の説明を生成する関数
def generate_description(image_path):
    image_base64 = get_thumbnail_base64(image_path)
    prompt = """
    この画像について日本語で400文字以内で、一行で説明してください。
    返事は不要で出力される内容は画像の説明文だけにしてください。
    """
    
    request_body = {
        "anthropic_version": "bedrock-2023-05-31",
        "max_tokens": 200000,
        "messages": [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "source": {
                            "type": "base64",
                            "media_type": "image/jpeg",
                            "data": image_base64
                        }
                    },
                    {
                        "type": "text",
                        "text": prompt
                    }
                ]
            }
        ]
    }
    
    response = bedrock.invoke_model(
        body=json.dumps(request_body),
        modelId="anthropic.claude-3-5-sonnet-20240620-v1:0",
        accept='application/json',
        contentType='application/json'
    )
    
    response_body = json.loads(response.get('body').read())
    return response_body["content"][0]["text"]

# Snowflakeテーブルを更新する関数
def update_snowflake_table(update_df):
    snow_df = session.create_dataframe(update_df)
    
    session.table("IMAGE_METADATA").merge(
        snow_df,
        (session.table("IMAGE_METADATA").FILE_NAME == snow_df.FILE_NAME),
        [
            when_matched().update({
                "DESCRIPTION": snow_df.DESCRIPTION
            }),
            when_not_matched().insert({
                "FILE_NAME": snow_df.FILE_NAME,
                "DESCRIPTION": snow_df.DESCRIPTION
            })
        ]
    )

# 画像の説明生成
def generate_image_descriptions():
    st.title("🤖 画像キャプションの自動生成")

    if 'generated_description' not in st.session_state:
        st.session_state.generated_description = None
    if 'selected_image' not in st.session_state:
        st.session_state.selected_image = None

    # 個別の画像に対する説明生成
    with st.form("generate_description"):
        selected_image = st.selectbox("対象の画像を選択してください:", options=[img["FILE_NAME"] for img in st.session_state.img_df])
        generate_button = st.form_submit_button("画像キャプションの生成")

    if generate_button:
        image_info = next(img for img in st.session_state.img_df if img['FILE_NAME'] == selected_image)
        generated_description = generate_description(image_info['IMG_PATH'])

        st.session_state.generated_description = generated_description
        st.session_state.selected_image = selected_image

        st.image(image_info['IMG_PATH'], width=300)
        st.write("生成されたキャプション:")
        st.write(generated_description)

    if st.session_state.generated_description is not None:
        if st.button("キャプションを保存"):
            update_snowflake_table(pd.DataFrame({'FILE_NAME': [st.session_state.selected_image], 'DESCRIPTION': [st.session_state.generated_description]}))
            st.success("キャプションが正常に保存されました")
            st.cache_data.clear()
            st.session_state.metadata_df = get_metadata()
            
            st.session_state.generated_description = None
            st.session_state.selected_image = None

    # 説明がない画像をまとめて処理
    st.subheader("キャプション未設定の画像の一括キャプション生成")
    
    images_without_description = [
        img for img in st.session_state.img_df 
        if img["FILE_NAME"] not in st.session_state.metadata_df[
            st.session_state.metadata_df["DESCRIPTION"].notna() & 
            (st.session_state.metadata_df["DESCRIPTION"] != "")
        ]["FILE_NAME"].values
    ]
    
    if images_without_description:
        st.write(f"{len(images_without_description)}枚の画像にキャプションが設定されていません。")
        if st.button("画像キャプションを一括生成"):
            progress_bar = st.progress(0)
            for i, img in enumerate(images_without_description):
                generated_description = generate_description(img['IMG_PATH'])
                update_snowflake_table(pd.DataFrame({'FILE_NAME': [img['FILE_NAME']], 'DESCRIPTION': [generated_description]}))
                progress_bar.progress((i + 1) / len(images_without_description))

            st.success("すべての画像のキャプションが生成され、保存されました!")
            st.cache_data.clear()
            st.session_state.metadata_df = get_metadata()
    else:
        st.write("すべての画像にキャプションが設定されています。")

    # デバッグ情報の表示
    st.subheader("メタデータ情報")
    st.write(st.session_state.metadata_df)

# Cortex LLMのEmbedding関数を使用してベクトルデータを生成する関数
def generate_embedding(text):
    if text and text.strip():
        result = session.sql(f"SELECT SNOWFLAKE.CORTEX.EMBED_TEXT_1024('voyage-multilingual-2', '{text}') as embedding").collect()
        return result[0]['EMBEDDING']
    return None

# ベクトルデータを生成して保存する関数
def generate_and_save_vectors():
    st.title("🧬 ベクトルデータの自動生成")

    # メタデータを取得(ベクトルデータの情報も含める)
    full_metadata = session.table("IMAGE_METADATA").select("FILE_NAME", "DESCRIPTION", "VECTOR").to_pandas()

    # ベクトルデータが未生成の画像を抽出
    images_without_vector = full_metadata[
        (full_metadata['DESCRIPTION'].notna()) & 
        (full_metadata['DESCRIPTION'] != "") & 
        (full_metadata['VECTOR'].isna())  # ベクトルデータが存在しない行のみ
    ]

    if images_without_vector.empty:
        st.write("すべての画像にベクトルデータが設定されています。")
    else:
        st.write(f"{len(images_without_vector)}枚の画像にベクトルデータを生成できます。")
        if st.button("ベクトルデータを生成"):
            progress_bar = st.progress(0)
            for i, (_, row) in enumerate(images_without_vector.iterrows()):
                # SQLを直接実行してベクトルを生成し保存
                session.sql(f"""
                UPDATE IMAGE_METADATA
                SET VECTOR = SNOWFLAKE.CORTEX.EMBED_TEXT_1024('voyage-multilingual-2', '{row['DESCRIPTION']}')
                WHERE FILE_NAME = '{row['FILE_NAME']}' AND VECTOR IS NULL
                """).collect()
                progress_bar.progress((i + 1) / len(images_without_vector))

            st.success("すべての対象画像のベクトルデータが生成され、保存されました!")
            st.cache_data.clear()

    # デバッグ情報の表示
    st.subheader("メタデータ情報")
    updated_full_metadata = session.table("IMAGE_METADATA").select("FILE_NAME", "DESCRIPTION", "VECTOR").to_pandas()
    st.write(updated_full_metadata)

# メインのアプリケーション実行部分
if __name__ == "__main__":
    st.sidebar.title("Navigation")
    page = st.sidebar.radio(
        "使用する機能を選択してください:", 
        ["Image Gallery", "キャプションの編集", "キャプションの自動生成", "ベクトルデータの自動生成"]
    )

    if page == "Image Gallery":
        show_image_gallery()
    elif page == "キャプションの編集":
        edit_image_descriptions()
    elif page == "キャプションの自動生成":
        generate_image_descriptions()
    elif page == "ベクトルデータの自動生成":
        generate_and_save_vectors()

一部コードの解説

以下の部分でカスタム CSS を適用し少しデザインをカスタムしています。Streamlit in Snowflake では HTML CSS JavaScript で Web アプリケーションのカスタムを行うことが可能です。詳細はこちらのドキュメントをご確認ください。

# カスタムCSSの追加
st.markdown("""
<style>
    .reportview-container {
        background: #f0f2f6;
    }
    .main .block-container {
        padding-top: 2rem;
        padding-bottom: 2rem;
        padding-left: 5rem;
        padding-right: 5rem;
    }
    .stButton>button {
        background-color: #4CAF50;
        color: white;
        padding: 10px 20px;
        border: none;
        border-radius: 5px;
        cursor: pointer;
        transition: background-color 0.3s;
    }
    .stButton>button:hover {
        background-color: #45a049;
    }
    .stTextInput>div>div>input {
        border-radius: 5px;
    }
    .stSelectbox>div>div>select {
        border-radius: 5px;
    }
    h1, h2, h3 {
        color: #2c3e50;
    }
    .stProgress > div > div > div > div {
        background-color: #4CAF50;
    }
</style>
""", unsafe_allow_html=True)

以下の部分でユーザーの検索文字列をベクトル化し、画像のベクトルデータとの類似度を計算しています。この類似度が高いほど検索文字列と画像が近いと言えるため、類似度の上位10件を取得するようにしています。

        # 検索クエリのベクトル化と類似度計算
        search_results = session.sql(f"""
        WITH search_vector AS (
            SELECT SNOWFLAKE.CORTEX.EMBED_TEXT_1024('voyage-multilingual-2', '{escaped_query}') as embedding
        )
        SELECT 
            i.FILE_NAME, 
            i.DESCRIPTION, 
            VECTOR_COSINE_SIMILARITY(i.VECTOR, s.embedding) as similarity
        FROM 
            IMAGE_METADATA i, 
            search_vector s
        WHERE 
            i.VECTOR IS NOT NULL
        ORDER BY 
            similarity DESC
        LIMIT 10
        """).collect()

以下の部分で画像のキャプションからベクトルデータを生成しています。3行のシンプルな SQL クエリで実現できている点がポイントです。

        if st.button("ベクトルデータを生成"):
            progress_bar = st.progress(0)
            for i, (_, row) in enumerate(images_without_vector.iterrows()):
                # SQLを直接実行してベクトルを生成し保存
                session.sql(f"""
                UPDATE IMAGE_METADATA
                SET VECTOR = SNOWFLAKE.CORTEX.EMBED_TEXT_1024('voyage-multilingual-2', '{row['DESCRIPTION']}')
                WHERE FILE_NAME = '{row['FILE_NAME']}' AND VECTOR IS NULL
                """).collect()
                progress_bar.progress((i + 1) / len(images_without_vector))

            st.success("すべての対象画像のベクトルデータが生成され、保存されました!")
            st.cache_data.clear()

最後に

前回までで画像にキャプションを生成し画像活用する土台が整い、今回の記事で画像をあいまいな検索キーワードからでも検索できるようになりました。Snowflake の強力なベクトル検索の仕組みを用いているため、膨大な画像データがある場合でも瞬時に検索結果を取得することが可能です。

また本アプリの発展アイディアとしては以下のようなことも考えられます。

  • ユーザーの検索文字列だけではなく、指定した画像の類似画像を検索できるようにする
  • 画像だけではなくドキュメントや音楽など他の非構造化データも検索できるようにする

皆様も色んなアイディアが頭の中にあるのではないでしょうか?是非本記事を参考にしていただき、皆様のアイディアを実現してみていただければとても嬉しく思います。

宣伝

X で Snowflake の What's new の配信してます

X で Snowflake の What's new の更新情報を配信しておりますので、是非お気軽にフォローしていただければ嬉しいです。

日本語版

Snowflake の What's New Bot (日本語版)
https://x.com/snow_new_jp

English Version

Snowflake What's New Bot (English Version)
https://x.com/snow_new_en

変更履歴

(20240924) 新規投稿

Discussion