🪄

LLMファクトリークラスを実装してLangchainのモデル変更を簡単にした話

2024/09/17に公開

tips: 9割Claudeで記述してます。すごく便利です!

はじめに

大規模言語モデル(LLM)の急速な発展により、多くの開発者が複数のLLMプロバイダーやモデルを使用するようになっています。しかし、これらの異なるAPIや設定を管理することは、しばしば煩雑になります。この記事では、この課題に対処するための効果的なソリューション、LLMファクトリークラスの実装について解説します。

LLMファクトリークラスの概要

LLMファクトリークラスは、ファクトリーパターンを応用した設計手法です。このアプローチにより、異なるLLMプロバイダーやモデルを統一的なインターフェースで扱うことができ、コードの可読性と保守性が大幅に向上します。

実装の詳細

今回のLLMファクトリークラスは、主に2つのクラスで構成されています:

  1. LLMFactory: チャットモデルとLLMオブジェクトの生成を担当
  2. EmbeddingFactory: 埋め込みモデルの生成を担当

LLMFactory

class LLMFactory:
    @staticmethod
    def get_chat(mode: str, model: str) -> BaseChatModel:
        # チャットモデルの生成ロジック

    @staticmethod
    def get_llm(mode: str, model: str) -> OpenAI | Anthropic | None:
        # LLMオブジェクトの生成ロジック

LLMFactoryクラスは、get_chatget_llmという2つの静的メソッドを提供します。これらのメソッドは、指定されたモードとモデルに基づいて適切なチャットモデルまたはLLMオブジェクトを返します。

EmbeddingFactory

class EmbeddingFactory:
    @staticmethod
    def get_embedding(mode: str, model: str) -> Embeddings | None:
        # 埋め込みモデルの生成ロジック

EmbeddingFactoryクラスは、get_embeddingメソッドを通じて、指定されたモードとモデルに基づいて適切な埋め込みモデルを返します。

主要な特徴

  1. モジュール性: 各プロバイダーやモデルの初期化ロジックがカプセル化されており、新しいプロバイダーの追加が容易になっています。

  2. 環境変数の活用: APIキーなどの機密情報は環境変数から取得し、セキュリティを確保しています。

  3. エラーハンドリング: サポートされていないモードに対してはNotImplementedErrorを発生させ、適切なエラー処理を促します。

  4. 柔軟性: チャット、LLM、埋め込みなど、異なるタイプのモデルに対応し、プロジェクトの要件に応じて柔軟に使用できます。

使用例

以下は、LLMファクトリークラスを使用する簡単な例です:

# チャットモデルの取得
chat_model = LLMFactory.get_chat("gpt", "gpt-3.5-turbo")

# LLMオブジェクトの取得
llm = LLMFactory.get_llm("claude", "claude-2")

# 埋め込みモデルの取得
embedding_model = EmbeddingFactory.get_embedding("gpt", "text-embedding-ada-002")

この例では、異なるプロバイダーのモデルを統一的なインターフェースで簡単に取得できることがわかります。

まとめと今後の展望

LLMファクトリークラスの実装により、異なるLLMプロバイダーやモデルを効率的に管理することが可能になります。これにより、コードの可読性が向上し、新しいモデルやプロバイダーの追加も容易になります。また、環境変数を利用することでセキュリティも確保されています。

今後のNLP開発において、このようなファクトリークラスの活用は、効率的で柔軟なシステム設計に大きく貢献するでしょう。さらに、この設計パターンを拡張し、モデルのバージョン管理や性能モニタリングなどの機能を追加することで、より強力なLLM管理システムを構築することも可能です。

LLMの急速な進化に伴い、このようなツールの重要性はますます高まっていくと予想されます。開発者の皆さんも、ぜひ自身のプロジェクトでLLMファクトリークラスの導入を検討してみてください。

コード全体

import os

from langchain_anthropic import Anthropic, ChatAnthropic
from langchain_community.chat_models import ChatOllama, ChatPerplexity
from langchain_community.embeddings import CohereEmbeddings, OllamaEmbeddings
from langchain_core.embeddings import Embeddings
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_openai import ChatOpenAI, OpenAI, OpenAIEmbeddings


class LLMFactory:
    @staticmethod
    def get_chat(mode: str, model: str) -> BaseChatModel:
        """Get chat object based on the mode and model"""
        if mode == "gpt":
            return ChatOpenAI(
                openai_api_key=os.environ.get("OPENAI_API_KEY"),
                model=model,
            )
        if mode == "claude":
            return ChatAnthropic(
                model=model,
                anthropic_api_key=os.environ.get("ANTHROPIC_API_KEY"),
            )
        if mode == "perplexity":
            return ChatPerplexity(
                model=model,
                api_key=os.environ.get("PERPLEXITY_API_KEY"),
            )
        if mode == "ollama":
            return ChatOllama(model=model, base_url=os.environ.get("AI_APP_URL"))
        raise NotImplementedError(f"Invalid model name: {mode}")

    @staticmethod
    def get_llm(mode: str, model: str) -> OpenAI | Anthropic | None:
        """Get LLM object based on the mode and model"""
        if mode == "gpt":
            return OpenAI(openai_api_key=os.environ.get("OPENAI_API_KEY"), model=model)
        if mode == "claude":
            return Anthropic(model=model, api_key=os.environ.get("ANTHROPIC_API_KEY"))
        if mode in {"perplexity", "ollama"}:
            return None
        raise NotImplementedError(f"Invalid model name: {mode}")


class EmbeddingFactory:
    @staticmethod
    def get_embedding(mode: str, model: str) -> Embeddings | None:
        """Get embedding object based on the mode"""
        if mode == "gpt":
            return OpenAIEmbeddings(openai_api_key=os.environ.get("OPENAI_API_KEY"))
        if mode == "ollama":
            return OllamaEmbeddings(base_url=os.environ.get("AI_APP_URL"), model=model)
        if mode == "cohere":
            return CohereEmbeddings(
                cohere_api_key=os.environ.get("COHERE_API_KEY"),
                model=model,
            )
        if mode in {"claude", "perplexity"}:
            return None
        raise NotImplementedError(f"Embedding not implemented for mode: {mode}")

Discussion