😺

MLflowと一緒にローカルRAGを使ってみる

に公開

はじめに

これまでMLflowを使ってローカルでLLMを使い、評価する方法を調べてきました。

今日はRAGを使ってみます。

RAG

社内などドメイン固有の情報をソースとしてLLMの回答に組み込む手法と理解しています。
ChatGPTに聞いてみると以下の回答が返ってきました。

RAG(Retrieval-Augmented Generation)は、情報検索と 生成AI(LLMなど) を組み合わせた手法です。

  • Retrieval(検索):外部データベースなどから関連情報を取得
  • Augmented Generation(生成強化):取得した情報をもとにLLMが回答を生成

今回は以下の参照情報をベースに作業を進めます。

今回のコードは以下レポジトリにあります。

概要

今回作成するシステムはざっくり以下のイメージです。

  • インターネットから集めたTextコンテンツを分割して、FAISS で管理する
  • LLMに質問をした時にベクトル情報を使って関連する情報を取得する
  • 取得した情報と質問情報を合わせて最終的な回答を生成する

本記事で学べること

  • MLflowとLangChainを使ってRAGを構築する方法
  • RAGシステムに入力するためにドキュメントをスクレイピングする方法
  • 複雑な質問に回答するためのRAGモデルのデプロイ方法と使い方
  • LLMを直接使った場合とRAGとの応答の実際的な意味合いの違い

環境

本記事の動作確認は以下の環境で行いました。

  • MacBook Pro
  • 14 インチ 2021
  • チップ:Apple M1 Pro
  • メモリ:32GB
  • macOS:15.5(24F74)

FAISSの導入

ベクターデータを保存するためのデータベースとして、FAISSを使います。FAISSはシンプルな無料で使える Meta 製のフレームワークです。

依存関係のインストール

以下コマンドで依存関係をインストールします。
参照記事ではバージョンを明示的に指定していますが、今回は最新にすることを目指してみます。

pipenv install beautifulsoup4 faiss-cpu langchain langchain-community langchain-openai openai tiktoken

beautifulsoup4 はHTMLやXMLからデータを抽出するためのライブラリです。

mlflow_rag.py を作り、以下のコードを実装します。

import os
import shutil  # for deleting temporary files
import tempfile

import mlflow
import requests
from bs4 import BeautifulSoup  # HTMLのパースに使う
from langchain.chains import RetrievalQA  # データベースからの検索に使う
from langchain.document_loaders import TextLoader  # テキストファイルを読み込む
from langchain.text_splitter import CharacterTextSplitter  # テキストを分割する
from langchain.vectorstores import FAISS  # ベクトルデータベースを作る
from langchain_openai import OpenAI, OpenAIEmbeddings  # OpenAI APIを使う

assert (
    "OPENAI_API_KEY" in os.environ
), "Please set the OPENAI_API_KEY environment variable."

連邦ドキュメントのスクレイピング

RAGシステムで使用するために、連邦政府の文書ページからコンテンツをスクレイピングする方法を説明します。特定のウェブページセクションから議事録(トランスクリプト)を抽出することに焦点を当て、それをRAGモデルに入力します。このプロセスは、RAGシステムに関連する外部データを提供する上で重要です。

mlflow_rag.py に以下を実行する fetch_federal_document 関数を追加します。

  • fetch_federal_document 関数は、特定の連邦政府文書の議事録(トランスクリプト)をスクレイピングして返すように設計されています。
  • この関数は次の2つの引数を取ります、url(ウェブページのURL)、div_class(トランスクリプトを含むdiv要素のクラス名)
  • 関数はウェブリクエストを処理し、HTMLコンテンツを解析し、目的のトランスクリプトテキストを抽出します。
def fetch_federal_document(url, div_class):  # noqa: D417
    """
    Scrapes the transcript of the Act Establishing Yellowstone National Park from the given URL.

    Args:
    url (str): URL of the webpage to scrape.

    Returns:
    str: The transcript text of the Act.
    """
    # Sending a request to the URL
    response = requests.get(url)
    if response.status_code == 200:
        # Parsing the HTML content of the page
        soup = BeautifulSoup(response.text, "html.parser")

        # Finding the transcript section by its HTML structure
        transcript_section = soup.find("div", class_=div_class)
        if transcript_section:
            transcript_text = transcript_section.get_text(separator="\n", strip=True)
            return transcript_text
        else:
            return "Transcript section not found."
    else:
        return f"Failed to retrieve the webpage. Status code: {response.status_code}"

ドキュメントの取得とFAISSデータベースの生成

以下の作業を実施します。

  1. ドキュメントを取得する関数 fetch_and_save_documents を定義します。
    • 特定のURLからドキュメントを取得します
    • この関数はドキュメントを取得するURLのリストとファイルのパスを引数として受け付けます
    • 各URLからドキュメントを取得し、指定したパスのファイルに保存します
  2. FAISSデータベースを生成する関数 create_faiss_database を定義します。
    • fetch_and_save_documents を使って取得したドキュメントを使ってFAISSデータベースを作ります
    • TextLoader を使ってドキュメントを読み込み、CharacterTextSplitter を使ってドキュメントを分割します
    • 特定のディレクトに保存された効率的な類似性検索ができるFAISSデータベースを返します
def fetch_and_save_documents(url_list, doc_path):
  """
  Fetches documents from given URLs and saves them to a specified file path.

  Args:
      url_list (list): List of URLs to fetch documents from.
      doc_path (str): Path to the file where documents will be saved.
  """
  for url in url_list:
      document = fetch_federal_document(url, "col-sm-9")
      with open(doc_path, "a") as file:
          file.write(document)


def create_faiss_database(document_path, database_save_directory, chunk_size=500, chunk_overlap=10):
  """
  Creates and saves a FAISS database using documents from the specified file.

  Args:
      document_path (str): Path to the file containing documents.
      database_save_directory (str): Directory where the FAISS database will be saved.
      chunk_size (int, optional): Size of each document chunk. Default is 500.
      chunk_overlap (int, optional): Overlap between consecutive chunks. Default is 10.

  Returns:
      FAISS database instance.
  """
  # Load documents from the specified file
  document_loader = TextLoader(document_path)
  raw_documents = document_loader.load()

  # Split documents into smaller chunks with specified size and overlap
  document_splitter = CharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
  document_chunks = document_splitter.split_documents(raw_documents)

  # Generate embeddings for each document chunk
  embedding_generator = OpenAIEmbeddings()
  faiss_database = FAISS.from_documents(document_chunks, embedding_generator)

  # Save the FAISS database to the specified directory
  faiss_database.save_local(database_save_directory)

  return faiss_database

環境を構築する

FAISSデータベースを作成し、RAGアプリケーションをセットアップします。この作業によりすべての文書が1か所に統合され、MLflow対応のRAGアプリケーションで検索に使用できるFAISSデータベースが準備されます。

  1. 一時ディレクトリの作成
    • tempfile.mkdtemp() を使って一時ディレクトリ作り、ここで作業をします
  2. ドキュメントのパスとFAISSインデックスのディレクトリ
    • 取得したドキュメントとFAISSデータベースは作成した一時ディレクトリに保存します
  3. ドキュメントの取得
    • ドキュメントを取得するurlのリストを定義します
    • fetch_and_save_documents 関数を使ってドキュメントを取得します
  4. FAISSデータベースの作成
    • デフォルトのchunk_sizechunk_overlap を使って create_faiss_database 関数を呼び出します
    • ここで作成する vector_db を使って類似性検索をします
temporary_directory = tempfile.mkdtemp()

doc_path = os.path.join(temporary_directory, "docs.txt")
persist_dir = os.path.join(temporary_directory, "faiss_index")

url_listings = [
  "https://www.archives.gov/milestone-documents/act-establishing-yellowstone-national-park#transcript",
  "https://www.archives.gov/milestone-documents/sherman-anti-trust-act#transcript",
]

fetch_and_save_documents(url_listings, doc_path)

vector_db = create_faiss_database(doc_path, persist_dir)

MLflowでRetrievalQAチェーンとログを可視化をする

「RetrievalQAチェーン」とは、検索機能を持ったチャットボットを実現するためのプロセスのことを指しています。ここでは、RetrievalQAチェーンをMLflowと連携します。MLflowでログを保存することで、モデルのパフォーマンスを評価し、モデルの改善に役立てることができます。

  1. RetrievalQAチェーンを初期化する
    • OpenAIのLLMを使ってRetrievalQAチェーンと先ほど作成したAFISSデータベースを初期化します。
    • このチェーンは、OpenAIのモデルを使用して応答を作成し、FAISSで情報を検索します。
  2. 検索のためのローダ load_retriever を実装します
    • FAISSデータベースの検索機能を使えるようにロードします
  3. Mlflowでログを保存する
    • mlflow.langchanin.log_model を使って、RetrievalQAチェーンをMLflowにログします
    • archive_pathloader_fn 及びFAISSが保存されている場所を示す parsist_dir を指定する必要があります
mlflow.set_experiment("Legal RAG")

retrievalQA = RetrievalQA.from_llm(llm=OpenAI(), retriever=vector_db.as_retriever())


# Log the retrievalQA chain
def load_retriever(persist_directory):
  embeddings = OpenAIEmbeddings()
  vectorstore = FAISS.load_local(
      persist_directory,
      embeddings,
      allow_dangerous_deserialization=True,  # This is required to load the index from MLflow
  )
  return vectorstore.as_retriever()


with mlflow.start_run() as run:
  model_info = mlflow.langchain.log_model(
      retrievalQA,
      name="retrieval_qa",
      loader_fn=load_retriever,
      persist_dir=persist_dir,
  )

ローカル用の変更をする

さて、ここまでMLflowのチュートリアルに従って作業をしてきました。ローカルで動作させるためには、ベクトル化の処理とLLMモデルをローカル様に変更します。以下の変更を加えます。

  • ローカルでHuggingFaceEmbeddingを使うように変更
  • LLMモデルはLMStudioで動かしているgemma3を使う

また、このタイミングでdeprecatedなモジュールの更新も行いました。

以下のインストールが必要です。

pipenv install langchain-huggingface

以下が mlflow_rag.py の最終的な状態です。Streamlitのアプリから利用するための関数 create_retrieval_qa も追加してあります。

import os
import tempfile
import warnings

import mlflow
import requests
from bs4 import BeautifulSoup  # HTMLのパースに使う
from langchain._api import LangChainDeprecationWarning
from langchain.chains import RetrievalQA  # データベースからの検索に使う
from langchain.text_splitter import CharacterTextSplitter  # テキストを分割する
from langchain_community.document_loaders import (
    TextLoader,  # テキストファイルを読み込む
)
from langchain_community.vectorstores import FAISS  # ベクトルデータベースを作る
from langchain_huggingface import HuggingFaceEmbeddings  # ローカル埋め込み用
from langchain_openai import ChatOpenAI  # LM Studio用

# LangChainの非推奨警告を抑制
warnings.filterwarnings("ignore", category=LangChainDeprecationWarning)

# HuggingFaceトークナイザーの並列処理警告を抑制
os.environ["TOKENIZERS_PARALLELISM"] = "false"


def fetch_federal_document(url, div_class):  # noqa: D417
    """
    Scrapes the transcript of the Act Establishing Yellowstone National Park from the given URL.

    Args:
    url (str): URL of the webpage to scrape.

    Returns:
    str: The transcript text of the Act.
    """
    # Sending a request to the URL
    response = requests.get(url)
    if response.status_code == 200:
        # Parsing the HTML content of the page
        soup = BeautifulSoup(response.text, "html.parser")

        # Finding the transcript section by its HTML structure
        transcript_section = soup.find("div", class_=div_class)
        if transcript_section:
            transcript_text = transcript_section.get_text(separator="\n", strip=True)
            return transcript_text
        else:
            return "Transcript section not found."
    else:
        return f"Failed to retrieve the webpage. Status code: {response.status_code}"


def fetch_and_save_documents(url_list, doc_path):
    """
    Fetches documents from given URLs and saves them to a specified file path.

    Args:
        url_list (list): List of URLs to fetch documents from.
        doc_path (str): Path to the file where documents will be saved.
    """
    for url in url_list:
        document = fetch_federal_document(url, "col-sm-9")
        with open(doc_path, "a") as file:
            file.write(document)


def create_faiss_database(
    document_path, database_save_directory, chunk_size=500, chunk_overlap=10
):
    """
    Creates and saves a FAISS database using documents from the specified file.

    Args:
        document_path (str): Path to the file containing documents.
        database_save_directory (str): Directory where the FAISS database will be saved.
        chunk_size (int, optional): Size of each document chunk. Default is 500.
        chunk_overlap (int, optional): Overlap between consecutive chunks. Default is 10.

    Returns:
        FAISS database instance.
    """
    # Load documents from the specified file
    document_loader = TextLoader(document_path)
    raw_documents = document_loader.load()

    # Split documents into smaller chunks with specified size and overlap
    document_splitter = CharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap
    )
    document_chunks = document_splitter.split_documents(raw_documents)

    # HuggingFace埋め込みモデルを使用
    embedding_generator = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2",
        model_kwargs={"device": "cpu"},  # CPUを使用(GPUがある場合は'cuda')
    )
    faiss_database = FAISS.from_documents(document_chunks, embedding_generator)

    # Save the FAISS database to the specified directory
    faiss_database.save_local(database_save_directory)

    return faiss_database


# Log the retrievalQA chain
def load_retriever(persist_directory):
    # 同じHuggingFace埋め込みモデルを使用
    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2",
        model_kwargs={"device": "cpu"},
    )
    vectorstore = FAISS.load_local(
        persist_directory,
        embeddings,
        allow_dangerous_deserialization=True,  # This is required to load the index from MLflow
    )
    return vectorstore.as_retriever()


# チャットアプリから使えるように一連の作業を関数化する
def create_retrieval_qa():
    """
    RetrievalQAチェーンを作成して返す

    Returns:
        RetrievalQA: 設定済みのRetrievalQAチェーン
    """
    # ドキュメントの取得と保存
    temporary_directory = tempfile.mkdtemp()

    doc_path = os.path.join(temporary_directory, "docs.txt")
    persist_dir = os.path.join(temporary_directory, "faiss_index")

    url_listings = [
        "https://www.archives.gov/milestone-documents/act-establishing-yellowstone-national-park#transcript",
        "https://www.archives.gov/milestone-documents/sherman-anti-trust-act#transcript",
    ]

    fetch_and_save_documents(url_listings, doc_path)

    vectorstore = create_faiss_database(doc_path, persist_dir)

    # LLMの設定
    llm = ChatOpenAI(
        base_url="http://localhost:1234/v1",
        api_key=None,
        temperature=0.7,
        model_name="google/gemma-3-12b",
    )

    # RetrievalQAチェーンの作成
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=vectorstore.as_retriever(),
        return_source_documents=True,
    )

    return qa_chain, persist_dir


# MLflowサーバーのURIを設定
mlflow.set_tracking_uri("http://localhost:5001")
mlflow.set_experiment("Legal RAG")
retrievalQA, persist_dir = create_retrieval_qa()


with mlflow.start_run() as run:
    # 依存関係を明示的に指定
    # mlflow.utils.environment: Encountered an unexpected error while inferring pip requirements の警告を抑制するため
    pip_requirements = [
        "langchain==0.3.25",
        "langchain-community",
        "langchain-openai",
        "langchain-huggingface",
        "sentence-transformers",
        "faiss-cpu",
        "beautifulsoup4",
        "requests",
        "pydantic==2.11.7",
        "cloudpickle==3.1.1",
    ]

    model_info = mlflow.langchain.log_model(
        retrievalQA,
        name="retrieval_qa",
        loader_fn=load_retriever,
        persist_dir=persist_dir,
        pip_requirements=pip_requirements,  # 明示的に指定
    )

以下のコマンドでアプリケーションを生成します。

pipenv run streamlit run ./mlflow_rag.py

MLflowにアプリケーションが作られました!

RAGモデルをテストする

これまででMLflowにモデルをロードできました。pyfunc としてモデルを取得して質問することができるのですが、Streamlitとうまく連携ができなかったので、qa_chainを返す関数を作成して、それを使ってテストします。

以下で作成したチャットアプリケーションを使って、モデルをテストします。

Streamlitアプリのコードは以下です。完全なコードはGitHubリポジトリを参照してください。

import os
import sys

import mlflow
import streamlit as st
from langchain_openai import ChatOpenAI

# ローカルのPyTorchを無効化
# これをしないと `RuntimeError: Tried to instantiate class '__path__._path', but it does not exist! Ensure that it is registered via torch::class_` みたいなエラーになる
# streamlitがソースコードのファイル変更を監視する時に起きるエラーらしく、検索対象を無効化してしまえば良い
sys.modules["torch.classes"] = None

# mlflow_rag.pyからRetrievalQA作成関数をインポート
from mlflow_rag import create_retrieval_qa

st.title("🦜🔗 RAG Quickstart App")

# Set the active experiment
mlflow.set_experiment("Legal RAG")


# RetrievalQAチェーンを初期化(キャッシュして一度だけ実行する)
@st.cache_resource
def initialize_retrieval_qa():
    """RetrievalQAチェーンを初期化する"""
    return create_retrieval_qa()


# RetrievalQAチェーンを取得
retrieval_qa, _ = initialize_retrieval_qa()


@mlflow.trace(
    name="rag_llm_call",
    attributes={"model": "gemma-3-12b", "source": "local", "type": "RAG"},
)
def generate_response_with_rag(input_text):
    # RAGを使用してレスポンスを生成
    # RetrievalQAチェーンを使用して回答を生成
    result = retrieval_qa.invoke({"query": input_text})
    output = result["result"]

    # 参照されたドキュメントも表示(オプション)
    if "source_documents" in result:
        with st.expander("参照されたドキュメント"):
            for i, doc in enumerate(result["source_documents"]):
                st.write(f"**ソース {i+1}:**")
                st.write(
                    doc.page_content[:500] + "..."
                    if len(doc.page_content) > 500
                    else doc.page_content
                )

    st.info(output)
    return output


with st.form("my_form"):
    text = st.text_area(
        "Enter text:",
        "What does the document say about trespassers?",
    )
    submitted = st.form_submit_button("Submit")
    if submitted:
        generate_response_with_rag(text)

以下コマンドで実行します。

pipenv run streamlit run ./streamlit_app_with_rag.py

Streamlitのアプリケーションが立上が得るので質問してみると、無事それっぽい返信が返ってきました!

ちなみにRAGなしに聞いてみるとこんな感じです。「ドキュメントを見せてくれないとわかんないよ!」みたいなことを言っていますね。

以上でRAGの動作確認は終わりです!

おわりに

今回は、MLflow と LangChain を使ってローカルRAGアプリケーションを作ってみました。RetrievalQAをStreamlitアプリに連携する際にちょっとつまりましたが、最終的には動作確認できました。

FAISS を使ったベクターデータの管理、HuggingFace を使ったベクター埋め込み、BeautifulSoup を使ったデータの取得など、これまで触れる機会の少なかった技術を実践的に学ぶことができ、非常に有益でした。RAGは、ドメイン固有の知識を活用することで、LLMの回答精度を大幅に向上させる強力な技術であることが確認できました。

参照元のチュートリアルには、RAGに関する説明の続きがあります。ちょっと長いので別途以下にまとめました。

Discussion