LLMファクトリークラスを実装してLangchainのモデル変更を簡単にした話
tips: 9割Claudeで記述してます。すごく便利です!
はじめに
大規模言語モデル(LLM)の急速な発展により、多くの開発者が複数のLLMプロバイダーやモデルを使用するようになっています。しかし、これらの異なるAPIや設定を管理することは、しばしば煩雑になります。この記事では、この課題に対処するための効果的なソリューション、LLMファクトリークラスの実装について解説します。
LLMファクトリークラスの概要
LLMファクトリークラスは、ファクトリーパターンを応用した設計手法です。このアプローチにより、異なるLLMプロバイダーやモデルを統一的なインターフェースで扱うことができ、コードの可読性と保守性が大幅に向上します。
実装の詳細
今回のLLMファクトリークラスは、主に2つのクラスで構成されています:
-
LLMFactory
: チャットモデルとLLMオブジェクトの生成を担当 -
EmbeddingFactory
: 埋め込みモデルの生成を担当
LLMFactory
class LLMFactory:
@staticmethod
def get_chat(mode: str, model: str) -> BaseChatModel:
# チャットモデルの生成ロジック
@staticmethod
def get_llm(mode: str, model: str) -> OpenAI | Anthropic | None:
# LLMオブジェクトの生成ロジック
LLMFactory
クラスは、get_chat
とget_llm
という2つの静的メソッドを提供します。これらのメソッドは、指定されたモードとモデルに基づいて適切なチャットモデルまたはLLMオブジェクトを返します。
EmbeddingFactory
class EmbeddingFactory:
@staticmethod
def get_embedding(mode: str, model: str) -> Embeddings | None:
# 埋め込みモデルの生成ロジック
EmbeddingFactory
クラスは、get_embedding
メソッドを通じて、指定されたモードとモデルに基づいて適切な埋め込みモデルを返します。
主要な特徴
-
モジュール性: 各プロバイダーやモデルの初期化ロジックがカプセル化されており、新しいプロバイダーの追加が容易になっています。
-
環境変数の活用: APIキーなどの機密情報は環境変数から取得し、セキュリティを確保しています。
-
エラーハンドリング: サポートされていないモードに対しては
NotImplementedError
を発生させ、適切なエラー処理を促します。 -
柔軟性: チャット、LLM、埋め込みなど、異なるタイプのモデルに対応し、プロジェクトの要件に応じて柔軟に使用できます。
使用例
以下は、LLMファクトリークラスを使用する簡単な例です:
# チャットモデルの取得
chat_model = LLMFactory.get_chat("gpt", "gpt-3.5-turbo")
# LLMオブジェクトの取得
llm = LLMFactory.get_llm("claude", "claude-2")
# 埋め込みモデルの取得
embedding_model = EmbeddingFactory.get_embedding("gpt", "text-embedding-ada-002")
この例では、異なるプロバイダーのモデルを統一的なインターフェースで簡単に取得できることがわかります。
まとめと今後の展望
LLMファクトリークラスの実装により、異なるLLMプロバイダーやモデルを効率的に管理することが可能になります。これにより、コードの可読性が向上し、新しいモデルやプロバイダーの追加も容易になります。また、環境変数を利用することでセキュリティも確保されています。
今後のNLP開発において、このようなファクトリークラスの活用は、効率的で柔軟なシステム設計に大きく貢献するでしょう。さらに、この設計パターンを拡張し、モデルのバージョン管理や性能モニタリングなどの機能を追加することで、より強力なLLM管理システムを構築することも可能です。
LLMの急速な進化に伴い、このようなツールの重要性はますます高まっていくと予想されます。開発者の皆さんも、ぜひ自身のプロジェクトでLLMファクトリークラスの導入を検討してみてください。
コード全体
import os
from langchain_anthropic import Anthropic, ChatAnthropic
from langchain_community.chat_models import ChatOllama, ChatPerplexity
from langchain_community.embeddings import CohereEmbeddings, OllamaEmbeddings
from langchain_core.embeddings import Embeddings
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_openai import ChatOpenAI, OpenAI, OpenAIEmbeddings
class LLMFactory:
@staticmethod
def get_chat(mode: str, model: str) -> BaseChatModel:
"""Get chat object based on the mode and model"""
if mode == "gpt":
return ChatOpenAI(
openai_api_key=os.environ.get("OPENAI_API_KEY"),
model=model,
)
if mode == "claude":
return ChatAnthropic(
model=model,
anthropic_api_key=os.environ.get("ANTHROPIC_API_KEY"),
)
if mode == "perplexity":
return ChatPerplexity(
model=model,
api_key=os.environ.get("PERPLEXITY_API_KEY"),
)
if mode == "ollama":
return ChatOllama(model=model, base_url=os.environ.get("AI_APP_URL"))
raise NotImplementedError(f"Invalid model name: {mode}")
@staticmethod
def get_llm(mode: str, model: str) -> OpenAI | Anthropic | None:
"""Get LLM object based on the mode and model"""
if mode == "gpt":
return OpenAI(openai_api_key=os.environ.get("OPENAI_API_KEY"), model=model)
if mode == "claude":
return Anthropic(model=model, api_key=os.environ.get("ANTHROPIC_API_KEY"))
if mode in {"perplexity", "ollama"}:
return None
raise NotImplementedError(f"Invalid model name: {mode}")
class EmbeddingFactory:
@staticmethod
def get_embedding(mode: str, model: str) -> Embeddings | None:
"""Get embedding object based on the mode"""
if mode == "gpt":
return OpenAIEmbeddings(openai_api_key=os.environ.get("OPENAI_API_KEY"))
if mode == "ollama":
return OllamaEmbeddings(base_url=os.environ.get("AI_APP_URL"), model=model)
if mode == "cohere":
return CohereEmbeddings(
cohere_api_key=os.environ.get("COHERE_API_KEY"),
model=model,
)
if mode in {"claude", "perplexity"}:
return None
raise NotImplementedError(f"Embedding not implemented for mode: {mode}")
Discussion