❄️

Cortex AnalystでText to SQLのStreamlitアプリを作成する

2024/10/01に公開

はじめに

Cortex AnalystはLLMを用いて自然言語の入力からSQLのクエリを出力するText to SQLの機能をREST APIで提供しており、Streamlitなどに組み込んで使うことができます。これを使うとSnowflake内にあるデータに対してText to SQLで分析することができます。

https://docs.snowflake.com/user-guide/snowflake-cortex/cortex-analyst

公式のチュートリアルを参考にして、以下のステップでアプリを作成しました。

  • セマンティックモデルを作成する
  • セマンティックモデルをステージにアップロードする
  • スタンドアロンのStreamlitアプリを作成して実行する
  • Streamlitアプリを操作する

参考にしたチュートリアル↓
https://docs.snowflake.com/user-guide/snowflake-cortex/cortex-analyst/tutorials/tutorial-1

今回はStreamlitアプリ作成の部分の話をします。Cortex Analystの概要、セマンティックモデルの作成については↓の記事に書いたので気になる方は見てみてください。
https://zenn.dev/tree_and_tree/articles/fe6f261748b560

Cortex AnalystのRSET APIの仕様

https://docs.snowflake.com/en/user-guide/snowflake-cortex/cortex-analyst/rest-api

アプリ作成の前にREST APIの仕様を見てみます。
まず、APIはステートレスで、単一ターンの会話のみがサポートされているみたいです。外側で上手く機能を盛り込めば会話履歴を使ったり複数ターンの会話も実現できそうでしたが、今回はとりあえず単一ターンで実装しました。

request body

user、ユーザーの質問文、セマンティックモデルのyamlファイルが含まれます。

  • messages[].role:userのみサポート
  • messages[].content[]
    • type:コンテンツタイプ。textのみサポート
    • text:ユーザーの質問
  • semantic_model_file:セマンティックモデルファイルのパス

request bodyの例

{
    "messages": [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "which company had the most revenue?"
                }
            ]
        },
    ],
    "semantic_model_file": "@my_stage/my_semantic_model.yaml"
}

response

responseにはAPIから返ってくるメッセージなどの内容が含まれます。

  • message: ユーザーとアナリスト間の会話
  • message(object): チャット内のメッセージ
    • message.role: userかanalystのいずれか
  • message.content[] (object): メッセージの一部であるコンテンツオブジェクト
    • message.content[].type: メッセージのコンテンツ タイプ。text、suggestion、sqlのいずれか
      • suggestionとsqlは相互に排他で、両方含まれることはない
      • sqlを返せなかった時にsuggestionになる
    • message.content[].text: コンテンツのテキスト。textに対してのみ返される
    • message.content[].statement: SQL ステートメント。sqlに対してのみ返される
    • message.content[].suggestions: SQL を生成できない場合、セマンティックモデルがsqlを生成できる質問のリスト。suggestionに対してのみ返される

responseの例

{
    "message": {
        "role": "analyst",
        "content": [
            {
                "type": "text",
                "text": "We interpreted your question as ..."
            },
            {
                "type": "sql",
                "statement": "SELECT * FROM table"
            }
        ]
    }
}

Streamlitでアプリ作成

チュートリアルのスクリプトをベースにして作りました。

https://docs.snowflake.com/user-guide/snowflake-cortex/cortex-analyst/tutorials/tutorial-1#step-3-create-a-streamlit-app-to-talk-to-your-data-through-cortex-analyst

ローカルで動かす場合

以下は自分で設定する箇所

  • 8-12行目
    • セマンティックモデルのファイルがあるデータベース、スキーマ、ステージ、ファイル名
    • 使用するウェアハウス
  • 15-19行目:アカウント情報
    • ACCOUNT:[組織名]-[アカウント名]
      • HOST:アカウントURLのhttps:// 以降の部分
from typing import Any, Dict, List, Optional

import pandas as pd
import requests
import snowflake.connector
import streamlit as st

DATABASE = "CORTEX_ANALYST_DEMO"
SCHEMA = "REVENUE_TIMESERIES"
STAGE = "RAW_DATA"
FILE = "revenue_timeseries.yaml"
WAREHOUSE = "cortex_analyst_wh"

# replace values below with your Snowflake connection information
HOST = "<host>"
ACCOUNT = "<account>"
USER = "<user>"
PASSWORD = "<password>"
ROLE = "<role>"

if 'CONN' not in st.session_state or st.session_state.CONN is None:
    st.session_state.CONN = snowflake.connector.connect(
        user=USER,
        password=PASSWORD,
        account=ACCOUNT,
        host=HOST,
        port=443,
        warehouse=WAREHOUSE,
        role=ROLE,
    )

def send_message(prompt: str) -> Dict[str, Any]:
    """Calls the REST API and returns the response."""
    request_body = {
        "messages": [{"role": "user", "content": [{"type": "text", "text": prompt}]}],
        "semantic_model_file": f"@{DATABASE}.{SCHEMA}.{STAGE}/{FILE}",
    }
    resp = requests.post(
        url=f"https://{HOST}/api/v2/cortex/analyst/message",
        json=request_body,
        headers={
            "Authorization": f'Snowflake Token="{st.session_state.CONN.rest.token}"',
            "Content-Type": "application/json",
        },
    )
    request_id = resp.headers.get("X-Snowflake-Request-Id")
    if resp.status_code < 400:
        return {**resp.json(), "request_id": request_id}  # type: ignore[arg-type]
    else:
        raise Exception(
            f"Failed request (id: {request_id}) with status {resp.status_code}: {resp.text}"
        )

def process_message(prompt: str) -> None:
    """Processes a message and adds the response to the chat."""
    st.session_state.messages.append(
        {"role": "user", "content": [{"type": "text", "text": prompt}]}
    )
    with st.chat_message("user"):
        st.markdown(prompt)
    with st.chat_message("assistant"):
        with st.spinner("Generating response..."):
            response = send_message(prompt=prompt)
            request_id = response["request_id"]
            content = response["message"]["content"]
            display_content(content=content, request_id=request_id)  # type: ignore[arg-type]
    st.session_state.messages.append(
        {"role": "assistant", "content": content, "request_id": request_id}
    )

def display_content(
    content: List[Dict[str, str]],
    request_id: Optional[str] = None,
    message_index: Optional[int] = None,
) -> None:
    """Displays a content item for a message."""
    message_index = message_index or len(st.session_state.messages)
    if request_id:
        with st.expander("Request ID", expanded=False):
            st.markdown(request_id)
    for item in content:
        if item["type"] == "text":
            st.markdown(item["text"])
        elif item["type"] == "suggestions":
            with st.expander("Suggestions", expanded=True):
                for suggestion_index, suggestion in enumerate(item["suggestions"]):
                    if st.button(suggestion, key=f"{message_index}_{suggestion_index}"):
                        st.session_state.active_suggestion = suggestion
        elif item["type"] == "sql":
            with st.expander("SQL Query", expanded=False):
                st.code(item["statement"], language="sql")
            with st.expander("Results", expanded=True):
                with st.spinner("Running SQL..."):
                    df = pd.read_sql(item["statement"], st.session_state.CONN)
                    if len(df.index) > 1:
                        data_tab, line_tab, bar_tab = st.tabs(
                            ["Data", "Line Chart", "Bar Chart"]
                        )
                        data_tab.dataframe(df)
                        if len(df.columns) > 1:
                            df = df.set_index(df.columns[0])
                        with line_tab:
                            st.line_chart(df)
                        with bar_tab:
                            st.bar_chart(df)
                    else:
                        st.dataframe(df)

st.title("Cortex Analyst")
st.markdown(f"Semantic Model: `{FILE}`")

if "messages" not in st.session_state:
    st.session_state.messages = []
    st.session_state.suggestions = []
    st.session_state.active_suggestion = None

for message_index, message in enumerate(st.session_state.messages):
    with st.chat_message(message["role"]):
        display_content(
            content=message["content"],
            request_id=message.get("request_id"),
            message_index=message_index,
        )

if user_input := st.chat_input("What is your question?"):
    process_message(prompt=user_input)

if st.session_state.active_suggestion:
    process_message(prompt=st.session_state.active_suggestion)
    st.session_state.active_suggestion = None

Streamlit in Snowflakeで動かす場合

一部を書き換える必要があります。

  • インポートするライブラリ
  • send_message()の中身
import _snowflake
import json
import streamlit as st
import time
from snowflake.snowpark.context import get_active_session

DATABASE = "CORTEX_ANALYST_DEMO"
SCHEMA = "REVENUE_TIMESERIES"
STAGE = "RAW_DATA"
FILE = "revenue_timeseries.yaml"

def send_message(prompt: str) -> dict:
    """Calls the REST API and returns the response."""
    request_body = {
        "messages": [
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": prompt
                    }
                ]
            }
        ],
        "semantic_model_file": f"@{DATABASE}.{SCHEMA}.{STAGE}/{FILE}",
    }
    resp = _snowflake.send_snow_api_request(
        "POST",
        f"/api/v2/cortex/analyst/message",
        {},
        {},
        request_body,
        {},
        30000,
    )
    if resp["status"] < 400:
        return json.loads(resp["content"])
    else:
        raise Exception(
            f"Failed request with status {resp['status']}: {resp}"
        )

def process_message(prompt: str) -> None:
    """Processes a message and adds the response to the chat."""
    st.session_state.messages.append(
        {"role": "user", "content": [{"type": "text", "text": prompt}]}
    )
    with st.chat_message("user"):
        st.markdown(prompt)
    with st.chat_message("assistant"):
        with st.spinner("Generating response..."):
            response = send_message(prompt=prompt)
            content = response["message"]["content"]
            display_content(content=content)
    st.session_state.messages.append({"role": "assistant", "content": content})


def display_content(content: list, message_index: int = None) -> None:
    """Displays a content item for a message."""
    message_index = message_index or len(st.session_state.messages)
    for item in content:
        if item["type"] == "text":
            st.markdown(item["text"])
        elif item["type"] == "suggestions":
            with st.expander("Suggestions", expanded=True):
                for suggestion_index, suggestion in enumerate(item["suggestions"]):
                    if st.button(suggestion, key=f"{message_index}_{suggestion_index}"):
                        st.session_state.active_suggestion = suggestion
        elif item["type"] == "sql":
            with st.expander("SQL Query", expanded=False):
                st.code(item["statement"], language="sql")
            with st.expander("Results", expanded=True):
                with st.spinner("Running SQL..."):
                    session = get_active_session()
                    df = session.sql(item["statement"]).to_pandas()
                    if len(df.index) > 1:
                        data_tab, line_tab, bar_tab = st.tabs(
                            ["Data", "Line Chart", "Bar Chart"]
                        )
                        data_tab.dataframe(df)
                        if len(df.columns) > 1:
                            df = df.set_index(df.columns[0])
                        with line_tab:
                            st.line_chart(df)
                        with bar_tab:
                            st.bar_chart(df)
                    else:
                        st.dataframe(df)


st.title("Cortex analyst")
st.markdown(f"Semantic Model: `{FILE}`")

if "messages" not in st.session_state:
    st.session_state.messages = []
    st.session_state.suggestions = []
    st.session_state.active_suggestion = None

for message_index, message in enumerate(st.session_state.messages):
    with st.chat_message(message["role"]):
        display_content(content=message["content"], message_index=message_index)

if user_input := st.chat_input("What is your question?"):
    process_message(prompt=user_input)

if st.session_state.active_suggestion:
    process_message(prompt=st.session_state.active_suggestion)
    st.session_state.active_suggestion = None

Streamlit in Snowflakeの制約にも注意が必要です。
https://docs.snowflake.com/ja/developer-guide/streamlit/limitations

おわりに

基本的に公式ドキュメントをなぞるだけでText to SQLの動くアプリケーションが作れるのはとても魅力的でした。
Text to SQLはデータアナリストやビジネスユーザーとエンジニアを繋ぐポテンシャルがあり、データの民主化をサポートできる技術なんじゃないかなと思います。
2024年10月1日現在ではプレビューで料金も11月15日まで無料とのことですが、実際はコストがどれくらいかかるのかが気になるところです。

Discussion