❄️

Snowflake で簡単に LLM のファインチューニングをしよう Part1 -教師データ生成アプリを作ってみよう-

2024/12/04に公開

はじめに

本記事は Snowflake Advent Calendar の4日目です。是非 Snowflake の進化を楽しんでいってください!

データ活用の世界でも LLM の利用は当たり前になりつつあります。例えば SQL クエリの生成、データの分析やアドバイスから自律的な分析とレポーティングまでアイディア次第で様々なことが実現可能です。Snowflake では SQL や Python で簡単かつセキュアに利用できる Cortex AI でそれらを実現することが可能となります。

LLM の性能は猛烈な勢いで上がってきており、現状次々に新しいモデルが出てきているため、ベースモデルに手を加えずに Few-Shot や RAG のようなプロンプトエンジニアリングの手法で活用するのが手軽なのでオススメではあります。Snowflake における RAG については以前投稿した以下の記事をご参照ください。

https://zenn.dev/tsubasa_tech/articles/200e72d6039acf

しかし、ビジネスにおいてはベースモデルでは解決が難しいシーンがあります。例えば以下のようなシーンではプロンプトエンジニアリング単体だと求めた結果が得られないことが想定されます。

  • プロンプトエンジニアリングでは毎回大量の背景情報を渡す必要がある場合
  • 業界特有の用語や知識 (=ドメイン知識) が求められる場合
  • 品質を高めるために巨大なパラメータ数のモデルが必要となりコストが増大する場合

この状況を打破する選択肢の1つがファインチューニングとなります。ファインチューニングとはベースモデルを追加のデータで調整する手法で、ビジネスの目的に合わせて LLM の性質を変えることが可能です。

ただし一般にファインチューニングは大変です。LLM 開発に関する専門的なスキル、教師データの準備とデータ変換、複雑な AI/Ops パイプラインの構築と管理、ファインチューニングモデルの管理など様々なハードルが存在します。加えて、ファインチューニングが求めたレベルの品質に達するまでにはトライ&エラーが必要となるため、満足な結果を得られる前にプロジェクトが頓挫しがちです。

そこで Snowflake の Cortex Fine-tuning の機能が役立ちます。この機能を知った時にこれまでファインチューニング反対派だった私の考えが変わりました。何故なら想像以上に現実的かつ実用的だったからです!

Snowflake Cortex Fine-tuning

Cotrex Fine-tuning とは Snowflake の生成 AI 機能群である Cortex LLM の機能の1つです。Cortex LLM の一部のベースモデルを関数を呼び出すことで簡単にファインチューニングすることができ、通常の Cortex LLM と同じ使用感で推論することができます。

Cortex Fine-tuning のドキュメント

この Cortex Fine-tuning の特に優れたところはトレーニングのしやすさです。以下のような FINETUNE 関数を呼び出すだけでトレーニングが始まりファインチューニングモデルができあがります。

-- 学習クエリ
SELECT SNOWFLAKE.CORTEX.FINETUNE(
  'CREATE',
  '<ファインチューニングモデル名>',
  '<ベースモデル名>',
  'SELECT a AS prompt, d AS completion FROM <教師データテーブル>'
);

特に教師データの渡し方が楽です。Snowflake 上に質問にあたる prompt と 回答にあたる completion の最低2カラムのテーブルを用意しておけば良いので、JSON などにする必要すら無く、既存のデータから簡単な変換を噛ませるだけで教師データを作ることが可能です。そのため、従来大変だった AI/Ops パイプラインをかなり省略することができ、実用性が高いと言えます。

ちなみに推論する際は通常の Cortex LLM の COMPLETE 関数と同様に以下のようなクエリとなり、こちらも楽です。

-- 推論クエリ
SELECT SNOWFLAKE.CORTEX.COMPLETE(
  '<ファインチューニングモデル名>',
  'あなたは何ができますか?'
);

対応リージョン

まずトレーニングについては以下のリージョンが対応しています。


2024/12/4現在の対応リージョン

ただし推論については上記の限りではありません。このファインチューニングしたモデルについては共有やレプリケーションすることもできるため、推論については Cotrex LLM でベースモデルが配備されているリージョン であれば行うことが可能です。今回は実際に AWS/Tokyo リージョンでファインチューニングモデルを使ってみたいと思います。

対応ベースモデル

ファインチューニングに対応しているベースモデルは以下の通りです。

Cortex Fine-tuning の対応ベースモデル


2024/12/4現在の対応ベースモデル

llama3.1-70b なども対応しているため、日本語に強いモデルも作ることが可能です。

機能概要

今回やること

  • ★Cortex Fine-tuning を試してみる
  • ★(Option) 教師データ生成アプリを作成する
  • ファインチューニングモデルに対応したチャットアプリを作成する
  • ファインチューニングモデルを他のリージョンで使ってみる

★:Part1で実現する範囲

教師データ生成アプリの機能一覧

  • ユーザーが入力したトピックから教師データを自動生成
  • 自動生成された教師データの全て or 一部をテーブルに保存
  • ユーザーによる手動での教師データ作成
  • ユーザーによる既存の教師データの編集
  • 登録された教師データの確認

Part1 の完成イメージ


教師データの自動生成機能


教師データの手動作成機能


教師データの手動編集機能


教師データの確認機能

前提条件

  • Cortex Fine-tuning に対応しているリージョンの Snowflake アカウント
  • Streamlit in Snowflake のインストールパッケージ
    • snowflake-ml-python 1.6.4

まずは Cortex Fine-tuning を試してみる

それでは本格的な作業を始める前に、Cortex Fine-tuning を軽く触ってみましょう。テスト用の教師データとして Hugging Face で公開されている japanese_alpaca_data を使わせていただきます。

適当なデータベース、スキーマ、ロールを準備した上で Hugging Face から Parquet 形式のファイルをダウンロードし、以下の手順を参考にして Snowflake のテーブルにデータをロードします。ここではテーブル名を japanese_alpaca_data とします。

https://zenn.dev/tsubasa_tech/articles/c0a2b8793a5d1f#jsts-データをテーブルにアップロード

ワークシートから以下のクエリを実行し加工前のデータを確認します

-- オリジナルのデータセットを確認
SELECT * FROM japanese_alpaca_data LIMIT 10;


加工前のデータを確認

INSTRUCTION が質問、INPUT が質問の補足、OUTPUT が回答に当たりそうなので、以下のクエリを実行し教師データとして少し加工します。

-- 教師データ用テーブルの作成
CREATE TABLE japanese_alpaca_data_train (
  prompt varchar,
  completion varchar
);

-- オリジナルのデータセットから教師データに変換
INSERT INTO japanese_alpaca_data_train (prompt, completion)
SELECT 
    CONCAT(COALESCE(instruction, ''), COALESCE(input, '')) AS prompt,
    output AS completion
FROM japanese_alpaca_data;

-- 教師データからNULLが含まれたレコードの削除
DELETE FROM japanese_alpaca_data_train
WHERE prompt IS NULL OR completion IS NULL;

-- 教師データの確認
SELECT * FROM japanese_alpaca_data_train LIMIT 10;


加工後の教師データを確認

教師データが揃ったら後は FINETUNE 関数でトレーニングをするだけです。短時間で試すために比較的小さなモデル llama3-8b で試してみましょう。

-- 学習クエリ
SELECT SNOWFLAKE.CORTEX.FINETUNE(
  'CREATE',
  'llama3_8b_finetuned',
  'llama3-8b',
  'SELECT prompt, completion FROM japanese_alpaca_data_train'
);

以下クエリで学習状況を確認します。

-- トレーニングジョブの一覧からジョブIDを確認
SELECT SNOWFLAKE.CORTEX.FINETUNE('SHOW');

このクエリで確認したジョブ ID を使って以下クエリで学習の詳細情報を確認します。

-- ジョブIDからトレーニングジョブの詳細確認
SELECT SNOWFLAKE.CORTEX.FINETUNE(
  'DESCRIBE',
  '<ジョブID>'
);

progress の値は学習の進捗を表しており、ここの値が1になったら学習完了です。学習が完了すると statusIN_PROGRESS から SUCCESS に変わります。

学習が完了したら推論をしてみましょう!推論は Cortex LLM の COMPLETE 関数で行うことができます。試しにベースモデルと比較してみましょう。

-- 推論クエリ (ベースモデル)
SELECT SNOWFLAKE.CORTEX.COMPLETE(
  'llama3-8b',
  'あなたは何ができますか?'
);

-- 推論クエリ (ファインチューニングモデル)
SELECT SNOWFLAKE.CORTEX.COMPLETE(
  'llama3_8b_finetuned',
  'あなたは何ができますか?'
);


ベースモデルの推論結果


ファインチューニングモデルの推論結果

このようにベースモデルでは日本語で聞いても英語で回答が返ってしまいますが、ファインチューニングモデルでは日本語で返してくれるなど明らかにモデルが調整されていることが分かります。

(Option) 教師データ生成アプリを作成する

ここからは本格的なファインチューニングに取り組んでいきます。まずはファインチューニングの最も重要な要素の1つである教師データを準備していきましょう。本手順では LLM を駆使して教師データを生成するアプリを作っていきますので、既にファインチューニング用の学習データが手元に揃っている場合は本手順は飛ばしていただいて問題ございません。

いつも通り雑ではありますが Streamlit in Snowflake のアプリを用意しましたので、新規で Streamlit in Snowflake のアプリを作成し、前提条件に記載のパッケージ snowflake-ml-python をインストールした上で以下コードをコピー&ペーストで貼り付けて実行してください。

このアプリを利用することで finetune_training_data という教師データテーブルができあがります。特定のテーマを入力すると Cortex LLM が自動的に教師データを生成してくれますので、そのデータを元に手動で修正したり追加したりすることでオリジナルの教師データを準備することが可能です。

テクニックとしてはなるべくパラメータ数が多く賢いモデルで教師データを作ると良いです。ファインチューニングの価値の1つとして、少ないパラメータのモデルでも高品質な結果を出すというところもあるため、mistral-large2 を先生として選んであげるといいかもしれません。

from snowflake.snowpark.context import get_active_session
import streamlit as st
from snowflake.cortex import Complete as CompleteText
from datetime import datetime
import pandas as pd
import time

# Snowflakeセッションの取得
session = get_active_session()

# テーブル作成用のSQL
CREATE_TABLE_SQL = """
CREATE TABLE IF NOT EXISTS FINETUNE_TRAINING_DATA (
    id NUMBER AUTOINCREMENT,
    prompt VARCHAR,
    completion VARCHAR,
    created_at TIMESTAMP_NTZ DEFAULT CURRENT_TIMESTAMP(),
    updated_at TIMESTAMP_NTZ DEFAULT CURRENT_TIMESTAMP()
)
"""

# 教師データテーブルの作成
def init_table():
    try:
        session.sql(CREATE_TABLE_SQL).collect()
    except Exception as e:
        st.error(f"テーブルの作成中にエラーが発生しました: {str(e)}")

# 教師データの読み込み
def load_training_data():
    try:
        return session.sql("SELECT * FROM FINETUNE_TRAINING_DATA ORDER BY id").collect()
    except Exception as e:
        st.error(f"データ読み込み中にエラーが発生しました: {str(e)}")
        return []

# 既存教師データのプロンプトを取得
def get_existing_prompts():
    try:
        result = session.sql("SELECT prompt FROM FINETUNE_TRAINING_DATA").collect()
        return [row['PROMPT'] for row in result]
    except Exception as e:
        st.error(f"既存教師データのプロンプト取得中にエラーが発生しました: {str(e)}")
        return []

# 教師データの保存
def save_training_data(prompt, completion):
    if not prompt or not completion:
        st.warning("プロンプトと回答の両方を入力してください。")
        return False
    
    try:
        escaped_prompt = prompt.replace("'", "''").replace("\n", " ")
        escaped_completion = completion.replace("'", "''").replace("\n", " ")
        
        sql = f"""
        INSERT INTO FINETUNE_TRAINING_DATA (prompt, completion)
        SELECT '{escaped_prompt}', '{escaped_completion}'
        """
        session.sql(sql).collect()
        return True
    except Exception as e:
        st.error(f"データ保存中にエラーが発生しました: {str(e)}")
        return False

# 生成されたペアを一括保存
def save_all_pairs(pairs):
    success_count = 0
    for pair in pairs:
        try:
            if save_training_data(pair['prompt'], pair['completion']):
                success_count += 1
        except Exception as e:
            st.error(f"データ保存中にエラーが発生しました: {str(e)}")
    return success_count

# 教師データの更新
def update_training_data(id, prompt, completion):
    if not prompt or not completion:
        st.warning("プロンプトと回答の両方を入力してください。")
        return False
    
    try:
        escaped_prompt = prompt.replace("'", "''").replace("\n", " ")
        escaped_completion = completion.replace("'", "''").replace("\n", " ")
        
        sql = f"""
        UPDATE FINETUNE_TRAINING_DATA
        SET prompt = '{escaped_prompt}',
            completion = '{escaped_completion}',
            updated_at = CURRENT_TIMESTAMP()
        WHERE id = {id}
        """
        session.sql(sql).collect()
        return True
    except Exception as e:
        st.error(f"データ更新中にエラーが発生しました: {str(e)}")
        return False

# 教師データの削除
def delete_training_data(id):
    try:
        session.sql(f"DELETE FROM FINETUNE_TRAINING_DATA WHERE id = {id}").collect()
        return True
    except Exception as e:
        st.error(f"データ削除中にエラーが発生しました: {str(e)}")
        return False

# プロンプトの生成
def generate_prompt(topic, index, previous_prompts, existing_prompts):
    all_prompts = existing_prompts + previous_prompts
    prompts_text = "\n".join([f"- {p}" for p in all_prompts]) if all_prompts else "なし"
    
    prompt_gen_template = f"""
    LLMのFinetuning用の教師データを作るにあたり、
    「{topic}」というテーマについて、{index}番目の質問を生成してください。
    
    以下の既存の質問と重複しないように、新しい質問を生成してください:
    {prompts_text}
    
    質問は具体的で明確な形で作成していただき、改行を含めず1行で出力してください。
    """
    return CompleteText(lang_model, prompt_gen_template).strip()

# 回答の生成
def generate_completion(prompt):
    completion_gen_template = f"""
    以下の質問に対して、簡潔で的確な回答を生成してください:
    質問:{prompt}

    回答は改行を含めず1行で出力してください。
    """
    return CompleteText(lang_model, completion_gen_template).strip()

# UI設定
st.set_page_config(layout="wide")
st.title("教師データ作成アプリ")

# テーブルの作成
init_table()

# セッション状態の初期化
if 'generated_pairs' not in st.session_state:
    st.session_state.generated_pairs = []

# サイドバー設定
st.sidebar.title("設定")
lang_model = st.sidebar.selectbox(
    "AIモデルの選択",
    ["snowflake-arctic",
     "reka-core", "reka-flash",
     "mistral-large2", "mistral-large", "mixtral-8x7b", "mistral-7b",
     "llama3.2-3b", "llama3.2-1b",
     "llama3.1-405b", "llama3.1-70b", "llama3.1-8b",
     "llama3-70b", "llama3-8b", "llama2-70b-chat",
     "jamba-1.5-large", "jamba-1.5-mini", "jamba-instruct",
     "gemma-7b"]
)

# メインコンテンツ
tab1, tab2, tab3 = st.tabs(["教師データ自動生成", "手動編集", "データ確認"])

with tab1:
    st.header("LLMを使用した教師データの自動生成")
    
    topic = st.text_input("トピック")
    num_pairs = st.number_input("生成するペア数", min_value=1, max_value=100, value=5)
    
    if st.button("教師データ生成", type="primary"):
        if not topic:
            st.warning("トピックを入力してください。")
        else:
            progress_bar = st.progress(0)
            status_text = st.empty()
            
            existing_prompts = get_existing_prompts()
            generated_prompts = []
            st.session_state.generated_pairs = []
            
            for i in range(num_pairs):
                try:
                    previous_prompts = generated_prompts
                    
                    status_text.text(f"プロンプト {i+1}/{num_pairs} を生成中...")
                    prompt = generate_prompt(topic, i+1, previous_prompts, existing_prompts)
                    generated_prompts.append(prompt)
                    time.sleep(1)
                    
                    status_text.text(f"回答 {i+1}/{num_pairs} を生成中...")
                    completion = generate_completion(prompt)
                    time.sleep(1)
                    
                    st.session_state.generated_pairs.append({
                        'prompt': prompt,
                        'completion': completion
                    })
                    
                    progress_bar.progress((i + 1) / num_pairs)
                    
                except Exception as e:
                    st.error(f"生成中にエラーが発生しました: {str(e)}")
                    break
            
            status_text.text("教師データの生成が完了しました!")
            st.snow()

    if st.session_state.generated_pairs:
        st.subheader("生成結果")
        
        if st.button("全ての結果を保存", type="primary"):
            success_count = save_all_pairs(st.session_state.generated_pairs)
            st.success(f"{success_count}件のデータを保存しました")
        
        for i, pair in enumerate(st.session_state.generated_pairs):
            with st.expander(f"生成結果 {i+1}/{len(st.session_state.generated_pairs)}", expanded=True):
                st.write("プロンプト:", pair['prompt'])
                st.write("回答:", pair['completion'])
                
                if st.button(f"この結果を保存 {i+1}", key=f"save_{i}"):
                    if save_training_data(pair['prompt'], pair['completion']):
                        st.success("保存しました")

with tab2:
    st.header("教師データの手動編集")
    
    with st.expander("新規データを追加", expanded=True):
        new_prompt = st.text_area("プロンプト", height=100)
        new_completion = st.text_area("完了テキスト", height=200)
        if st.button("追加", type="primary"):
            if save_training_data(new_prompt, new_completion):
                st.success("追加しました!")
                st.experimental_rerun()
    
    st.subheader("既存データの編集・削除")
    training_data = load_training_data()
    
    for row in training_data:
        with st.expander(f"ID: {row['ID']} - {row['PROMPT'][:50]}..."):
            edited_prompt = st.text_area("プロンプト編集", row['PROMPT'], 
                                       key=f"prompt_{row['ID']}", height=100)
            edited_completion = st.text_area("完了テキスト編集", row['COMPLETION'], 
                                           key=f"completion_{row['ID']}", height=200)
            
            col1, col2 = st.columns(2)
            with col1:
                if st.button("更新", key=f"update_{row['ID']}", type="primary"):
                    if update_training_data(row['ID'], edited_prompt, edited_completion):
                        st.success("更新しました!")
                        st.experimental_rerun()
            with col2:
                if st.button("削除", key=f"delete_{row['ID']}", type="secondary"):
                    if delete_training_data(row['ID']):
                        st.success("削除しました!")
                        st.experimental_rerun()

with tab3:
    st.header("教師データの確認")
    if training_data:
        df = pd.DataFrame(training_data)
        st.dataframe(df, use_container_width=True)
    else:
        st.info("データがありません。")

# データ統計の表示
st.sidebar.markdown("---")
st.sidebar.subheader("データ統計")
total_records = len(training_data)
st.sidebar.metric("総データ数", total_records)

最後に

今回はまず Cortex Fine-tuning を公開データをベースに試してみました。Cortex Fine-tuning は最低2カラムの教師データテーブルさえあればすぐに学習を開始でき、また推論も Cortex LLM の COMPLETE 関数でサクッとできることが分かっていただけたのではないかと思います。

また Option で作成した教師データ生成アプリを用いることで、オリジナルの教師データも作りやすくなると思います。発展のアイディアとしてはファインチューニングしたモデルで教師データを生成させたりするとより特定領域に特化した教師データが作れたりするかもしれません。

Part2 ではオリジナルの教師データを用いてチャットボットとして利用したり、Cortex Fine-tuning が現状未対応の AWS/Tokyo リージョンで推論するなどを試していきたいと思いますのでご期待ください!

宣伝

生成AI Conf 様の Webinar で登壇します!

『生成AI時代を支えるプラットフォーム』というテーマの Webinar で NVIDIA 様、古巣の AWS 様と共に Snowflake 社員としてデータ*AI をテーマに LTをします!2024/12/16 12:00 - 13:00ですので皆様是非ご視聴いただければ嬉しいですmm

https://generative-ai-conf.connpass.com/event/337737/

SNOWFLAKE WORLD TOUR TOKYO のオンデマンド配信中!

Snowflake の最新情報を知ることができる大規模イベント『SNOWFLAKE WORLD TOUR TOKYO』が2024/9/11-12@ANAインターコンチネンタル東京で開催されました。
現在オンデマンド配信中ですので数々の最新のデータ活用事例をご覧ください。
また私が登壇させていただいた『今から知ろう!Snowflakeの基本概要』では、Snowflakeのコアの部分を30分で押さえられますので、Snowflake をイチから知りたい方、最新の Snowflake の特徴を知りたい方は是非ご視聴いただければ嬉しいですmm

https://www.snowflake.com/events/snowflake-world-tour-tokyo/

X で Snowflake の What's new の配信してます

X で Snowflake の What's new の更新情報を配信しておりますので、是非お気軽にフォローしていただければ嬉しいです。

日本語版

Snowflake の What's New Bot (日本語版)
https://x.com/snow_new_jp

English Version

Snowflake What's New Bot (English Version)
https://x.com/snow_new_en

変更履歴

(20241204) 新規投稿

Discussion