Open5

RAG の評価管理を weave で実装する

HD-21HD-21

プロジェクトの作成

プロジェクトは以下のように作成。

import weave

PROJECT = "sample"
weave.init(PROJECT)

https://wandb.ai/{username}/projects にアクセスしてプロジェクトが作成されていることを確認

HD-21HD-21

追跡

関数の追跡

OpenAI API の応答生成関数 generate_response を作成
追跡対象の関数に @weave.op のデコレータを付与

# -*- coding: utf_8 -*-
from typing import Literal, TypedDict

import weave
from loguru import logger
from openai import OpenAI

weave.init("sample")


class TypeMessage(TypedDict):
    role: Literal["assistant", "function", "system", "user"]
    content: str


@weave.op()
def generate_response(
    messages: list[dict],
    client: OpenAI = OpenAI(),
    model_name: str = "gpt-4o",
    **kwargs,
):
    response = client.chat.completions.create(
        model=model_name,
        messages=messages,
        **kwargs,
    )
    response_content = response.choices[0].message.content
    logger.trace(f"response_content: {response_content}")
    return response_content



if __name__ == "__main__":
    client = OpenAI()
    messages: list[TypeMessage] = [
        {"role": "user", "content": "Hello"},
    ]
    response_content = generate_response(messages, client)

上記を実行すると以下のメッセージが出力される

🍩 https://wandb.ai/{username}/sample/r/call/{project_id}

アクセスするとこんな感じ

  • Call: 関数呼び出しを追跡可能(引数と返り値を見ることができる)
  • Code: 実行時の関数コード
  • Feedback: フィードバックを表示
  • Summary: 実行時間、OSバージョン、トークン数などが表示される
  • Use: more information

またエラー発生時は、こんな感じに traceback が表示される

クラスの追跡

メソッドに @weave.op を用いるとクラスの属性値も追跡対象となる

HD-21HD-21

オブジェクトの追跡

General Object Tracking

オブジェクトを追跡したい場合は weave.Object を publish することでバージョン管理する。
プロンプト管理したいときなどに有用。

import weave
weave.init(PROJECT)

class SystemPrompt(weave.Object):
    prompt: str

system_prompt = SystemPrompt(
    prompt="You are a grammar checker, correct the following user input."
)
weave.publish(system_prompt)

Model Tracking

weave.Model を継承してクラスも publish することが可能。
モデルを publish する場合は、predict メソッドを定義する必要がある。

class OpenAIGrammarCorrector(weave.Model):
    # Properties are entirely user-defined
    openai_model_name: str
    system_message: str

    @weave.op(name="hello")
    def predict(self, user_input):
        client = OpenAI()
        response = client.chat.completions.create(
            model=self.openai_model_name,
            messages=[
                {"role": "system", "content": self.system_message},
                {"role": "user", "content": user_input},
            ],
            temperature=0,
        )
        return response.choices[0].message.content

corrector = OpenAIGrammarCorrector(
    openai_model_name="gpt-3.5-turbo-1106",
    system_message="You are a grammar checker, correct the following user input.",
)
ref = weave.publish(corrector)

Dataset Tracking

データセットを管理する場合は weave.Dataset を publish することでバージョン管理する。
評価セットなどに用いる。

dataset = weave.Dataset(
    name="grammar-correction",
    rows=[
        {
            "user_input": "   That was so easy, it was a piece of pie!   ",
            "expected": "That was so easy, it was a piece of cake!",
        },
        {"user_input": "  I write good   ", "expected": "I write well"},
        {
            "user_input": "  GPT-3 is smartest AI model.   ",
            "expected": "GPT-3 is the smartest AI model.",
        },
    ],
)
weave.publish(dataset)
HD-21HD-21

Publish したオブジェクトの取得

publish

corrector = OpenAIGrammarCorrector(
    openai_model_name="gpt-3.5-turbo-1106",
    system_message="You are a grammar checker, correct the following user input.",
)
ref = weave.publish(corrector)
logger.debug(ref.uri())

取得して実行

ref_url = f"weave:///{ref.entity}/{PROJECT}/object/{ref.name}:{ref.digest}"
fetched_collector = weave.ref(ref_url).get()
result = fetched_collector.predict("That was so easy, it was a piece of pie!")
logger.debug(result)