🦉

Generative AI Studio でファインチューニングする

2023/06/16に公開

こんにちは、クラウドエース Data/ML ディビジョン所属の坂田です。

先日、Google I/O 2023 が開催され、Google からさまざまな新サービス・新商品について発表されました。その中で多くの注目を浴びたのは、生成 AI に関する発表です。Google I/O 2023 では、新たな大規模言語モデル(以下、LLM)である「PaLM 2」や様々なサービスに生成 AI が組み込まれることが発表されました。

この記事では、今回発表された Google Cloud の生成 AI のサービスの中から、「Generative AI Studio」の「ファインチューニング」の機能について解説・検証します。

1. Generative AI Studio とは

Generative AI Studio は、Google が持つ基盤モデルを API として利用できるサービスであり、Vertex AI のサービス群の 1 つです。

Generative AI Studio では「言語」と「音声」について以下のような機能が提供されています(執筆時点)。

【言語】

  • 機能
    • AI とチャット形式で会話
    • プロンプトのテスト
    • サンプルのプロンプトの使用
    • ファインチューニングによる調整済みモデルの作成・デプロイ
  • 利用できる基盤モデル
    • PaLM 2:テキスト関連(チャット、テキスト生成など)
    • Codey:コード関連(コード生成、コードに関するチャット形式での QA 対応など)

【音声】

  • 機能
    • テキストの読み上げ(Text-to-Speech)
    • 音声の文字変換(Speech-to-Text)
  • 利用できる基盤モデル
    • Chirp

言語において、基盤モデルは特定のユースケースに応じてファインチューニングされており、チャット形式の会話は「chat-bison」「codechat-bison」、テキスト・コード生成は「text-bison」「code-bison」というモデルが用意されています。

利用可能なモデルの詳細はこちら。
https://cloud.google.com/vertex-ai/docs/generative-ai/learn/models

2. ファインチューニングとは

生成 AI から出力されるテキスト・画像などは、生成 AI を作成する際に使用された「学習データ」がもとになっています。これはつまり、学習データに含まれていない事柄について生成 AI が生成することは難しいということです。

一方で、以下のような「学習データに含まれていない特定のデータ・タスク」を AI に生成・実行させたい場合があります。

  • 特定分野の専門用語(医療、法律など)を理解したチャットボット
  • 任意の形式でテキストを生成させる
  • 特定の画家の表現を含んだイラストを生成させる

大規模言語モデル(以下、LLM)とよばれるような生成 AI は汎用性の高い自然言語処理能力・画像処理能力を持ちます。これらの能力を活かしつつ、特定のタスクをこなせるように、既存の生成 AI をさらに学習させるのがファインチューニング(fine-tuning、日本語で微調整)です。

3. Generative AI Studio のファインチューニング

Generative AI Studio では、テキスト生成のモデル(text-bison-001)をファインチューニングすることができます。

https://cloud.google.com/vertex-ai/docs/generative-ai/models/tune-models

ファインチューニングにより、以下のようなことを実現できます。

  • 任意の構造のプロンプトを学習する
  • 分類・要約・抽出的 QA といった特定のタスクのみを実行する
  • 任意の形式に沿った出力を行う

なお、公式ドキュメントによると、特定分野の専門用語(医療、法律など)を理解させることは現状難しいとのことです。

4. ファインチューニングしてみる

今回は「text-bison-001」に対して、テキスト分類を行うようにファインチューニングします。

扱うテーマ

データ分析コンペの Kaggle に「Natural Language Processing with Disaster Tweets(災害時のツイートに関する自然言語処理)」という入門コンペがあります。このコンペは Twitter のツイートの文章から、そのツイートが実際の災害についてツイートしたものかどうかを分類するコンペです。災害時に特有の言葉や感情をツイートから抽出するという、自然言語処理が必要になります。

今回はこのコンペのデータを使い、自然言語処理の分類タスクを LLM に適用します。
入出力の関係は以下となります。

  • 入力:ツイート文
  • 出力:災害である、災害ではない(disaster、not disaster)

流れ

以下の流れでファインチューニングを行います。

  1. 訓練データ・テストデータを作成する
  2. 訓練データを Cloud Storage バケットにアップロードする
  3. Generative AI Studio でファインチューニングを実行する
  4. テストデータを使ってテストを行う

手順 ① 訓練データ・テストデータを作成する

データをダウンロード

Kaggle からデータをダウンロードします。ダウンロードした CSV ファイルは以下のカラムとなっていました。

  • id
  • keyword
  • location
  • text
  • target

この内、text(ツイート)と target(1 が災害、0 が非災害を示す)のみを使用します。

データの形式

従来の自然言語処理の分類タスクでは、訓練時は前処理済みの自然言語処理対象のテキストをモデルに入力していましたが、生成 AI ではプロンプトを訓練データとして使用します

ドキュメントを参考に、以下のようなプロンプトを作成しました。
記事執筆時点では、Generative AI Studio の生成 AI は日本語に対応していないため、英語を使用しました。

Classify the category of the following text as disaster or not disaster.
Text: {}
Category:

{} の部分にはツイート文が入ります。
また、正解ラベルは、disaster または not disaster とします。こうすることで、AI に文章ではなく、disasternot disaster のみを出力させるようにします。

プロンプトを作成する際は以下のような点を考慮します。

  • AI に渡すデータには接頭辞(今回は Text:)を付ける
  • AI に出力して欲しい部分(今回は Category:)に接頭辞を付ける
  • 命令文内の言葉と接頭辞を統一する(今回は「Classify the category of the following text」)

Generative AI Studio では、訓練データのファイル形式は JSONL が指定されています。そのため、1 レコードで以下のような形式となります。

{"input_text": "Classify the category of the following text as disaster or not disaster.\nText: Forest fire near La Ronge Sask. Canada\nCategory:", "output_text": "disaster"}

入力は input_text、出力は output_text です。

Twitter はツイートの文章を改行することができるため、Kaggle からダウンロードしたデータも改行ありのツイートが多く含まれていました。ツイートの改行を表現するとプロンプトが崩れてしまい、学習に影響が出る可能性を考えたため、ツイート内の改行コードは空白文字に変換しました。一方で、プロンプト自体の改行は \n と表現しています。

データの個数

ドキュメントによると、分類タスクでは 100 個以上のデータを用意することが推奨されているとのことですので、500 個の訓練データを用意します。

訓練データ 500 個の内、disaster ラベルと not disaster ラベルの割合を 1 対 1 としました。また、テストデータも同様に 1 対 1 の割合で合計 200 個作成しました。

データ作成のコード

以下は訓練データ作成用の Python コードです。

import csv
import json

data = []
dict_list_disaster = []
dict_list_not_disaster = []

prompt_template = '''Classify the category of the following text as disaster or not disaster.
Text: {}
Category:'''

# データを読み込む
with open('./nlp-getting-started/train.csv', 'r') as f:
    # 改行コードはエスケープしてdict型にする
    reader = csv.DictReader(f, escapechar='\\')
    for row in reader:
        data.append(row)

# データを整形
for row in data:
    # ツイート内の改行コードを空白文字に変換
    text = row['text'].replace('\n', ' ')
    target = row['target']
    # ツイートをプロンプトに埋め込む
    input_text = prompt_template.format(text)
    
    if target == '1':
        dict_list_disaster.append({'input_text':input_text, 'output_text':'disaster'})
    elif target == '0':
        dict_list_not_disaster.append({'input_text':input_text, 'output_text':'not disaster'})

# trainとtestに分割
train_dict_list = dict_list_disaster[:250] + dict_list_not_disaster[:250]
test_dict_list = dict_list_disaster[250:350] + dict_list_not_disaster[250:350]

# jsonlとして出力
with open('train.jsonl', mode='w', encoding='utf-8') as f:
    for obj in train_dict_list:
        json.dump(obj, f, ensure_ascii=False)
        f.write('\n')

with open('test.jsonl', mode='w', encoding='utf-8') as f:
    for obj in test_dict_list:
        json.dump(obj, f, ensure_ascii=False)
        f.write('\n')

手順 ② 訓練データを Cloud Storage バケットにアップロードする

作成した訓練データを適当な Cloud Storage バケットにアップロードします。バケットのロケーションについて特に制限はないようです。

手順 ③ Generative AI Studio でファインチューニングを実行する

今回は WebUI のコンソール画面で操作します。Python や API でも操作可能とのことです。

  1. Generative AI Studio にアクセス
  2. 「言語」>「調整済みモデルの作成」の順に選択
  3. アップロードした訓練データの Cloud Storage 上のパスを入力
  4. 「モデル名」「ベースモデル」「ステップをトレーニング」「作業ディレクトリパス」を入力
    現在、ベースモデルは text-bison-001(テキスト生成モデル)のみ選択可能です。
    ドキュメントによると、ステップ数は分類タスクでは 100〜500 を設定することが推奨されているようです。今回は 500 を設定します。
  5. 「調整を開始」を押す

「調整を開始」を押すと、自動的に Vertex AI Pipelines でファインチューニング用の機械学習パイプラインが自動的に作成されます。今回はパイプラインの実行完了まで約 4 時間かかりました。

実行が完了すると、チューニングされたモデルが自動的に Vertex AI Model Registory に登録され、エンドポイントにデプロイされます。そのため、実行完了したらすぐにモデルのテストを行うことができます。

手順 ④ テストデータを使ってテストを行う

コンソール画面、Python、API の 3 種類の方法で予測リクエストを行うことができます。今回はコンソールと Python で予測リクエストを投げてみます。

コンソールから予測リクエスト

以下の手順で行います。

  1. Generative AI Studio にアクセス
  2. 「言語」>「TUNING」の順に選択
  3. 今回作成したモデルの欄で「テスト」を選択
  4. プロンプトを入力
  5. 「送信」ボタンを押す

【送信したプロンプト】

Classify the category of the following text as disaster or not disaster.
Text: I waited 2.5 hours to get a cab my feet are bleeding
Category:

Python で予測リクエスト

Vertex AI のクライアントライブラリで、エンドポイントにデプロイされたモデルに対して予測リクエストを行うことができます。

まず、認証を通します。

$ gcloud auth application-default login

必要な Python ライブラリをインストールします。

$ pip install google-cloud-aiplatform

以下は予測リクエストをエンドポイントに送信する Python コードです。以下を実行します。

import vertexai
from vertexai.preview.language_models import TextGenerationModel

def predict_large_language_model_sample(
    project_id: str,
    model_name: str,
    temperature: float,
    max_decode_steps: int,
    top_p: float,
    top_k: int,
    content: str,
    location: str,
    tuned_model_name: str,
    ) :
    vertexai.init(project=project_id, location=location)
    model = TextGenerationModel.from_pretrained(model_name)

    if tuned_model_name:
      model = model.get_tuned_model(tuned_model_name)
    
    response = model.predict(
        content,
        temperature=temperature,
        max_output_tokens=max_decode_steps,
        top_k=top_k,
        top_p=top_p,)
    
    print(f"Response from Model: {response.text}")

content = '''Classify the category of the following text as disaster or not disaster.\nText: I waited 2.5 hours to get a cab my feet are bleeding\nCategory:'''

predict_large_language_model_sample(project_id, "text-bison@001", 0.2, 256, 0.8, 40, content, location, tuned_model_name)

predict_large_language_model_sample() の引数 project_idlocationtuned_model_name は適宜変更する必要があります。また、content に送信するプロンプトを入れています。

この記事では触れませんが、model.predict() の引数にモデルのパラメータを渡すことができ、パラメータを調整することで出力結果の調整することが可能です。詳細はこちら。
https://cloud.google.com/vertex-ai/docs/generative-ai/learn/models#parameter_definitions

5. まとめ

データ作成 → モデルのファインチューニング → 予測 の流れを問題なく簡単に行うことができました。

ファインチューニングは 1 から機械学習モデルを作成する際と比較して、少ないデータ量で学習ができます。そのため、Vertex AI で用意されている強力な生成 AI の基盤モデルを、ユースケースに応じて簡単に調整できます。

ファインチューニングの機能は生成 AI をシステムの中で実運用していく際に核となる機能だと思います。今後の改良と GA が楽しみです。

6. 関連記事

他の Generative AI 関連プロダクトの紹介記事はこちら。

https://zenn.dev/cloud_ace/articles/726a56badddf65

https://zenn.dev/cloud_ace/articles/411f4fd7359094

Discussion