💬

【LLM × LangChain】Notion APIを用いた簡易RAGアプリの実装

2024/12/24に公開

はじめに

前から生成AIについては興味があったものの、特に「何か作りたい」というアイデアも思い浮かばなかった中、何やらRAGというものがあるらしいということを知り、早速簡単なアプリを作ってみました。

RAG(ラグ)とは「Retrieval Augmented Generation(検索拡張生成)」の略称で、生成AIをサポートする技術の一種です。ユーザーからの質問や指示に対して、外部のデータから情報を検索して回答を生成する仕組みがRAGと呼ばれます。

引用元|Udemyメディア RAGとは?LLMの欠点を補う仕組みとメリット・活用方法を解説

LLMで真っ先に思い浮かぶのはChatGPTだと思いますが、ChatGPTのAPIは、トークンの利用量に応じて課金されるシステムになっています。

ただ今回は、一銭も使用しない無料縛りかつNotionでデータを管理していると仮定し、どこまで正確に回答できるかに挑みました。

作ったもの

Notionにconnectome.designの人材募集要項に関する記事を作成し、RAGアプリを通じて正しく回答できるようにすることを目標としました。

外部データ (Notion)

外部データなしLLM


designのワードに引っ張られて、デザイン系の回答をそれっぽく返しています。なお肝心な質問に対する回答はありません。

外部データありLLM


Notionの内容に沿った回答がされています。

機能追加などに伴いコードが変更される可能性がありますが、本記事で作成したアプリは一応githubで公開しています。
https://github.com/seita-f/RAG-Notion-App

おおまかな流れ

  1. Notionからコンテンツを取得
  2. 取得したコンテンツをベクトル化しデータベースに保存
  3. ユーザーからの質問に対し、保存されたデータベース内で関連データを検索し回答を生成

ディレクトリ構造

プロジェクトのディレクトリ構造は以下になります。

.
├── chroma_db 
├── .env
├── embedding.py
├── llm.py
├── notion-api
│   ├── notion_contents.json
│   └── retrieve_data.py
├── requirements.txt
├── ui
    └── main.py

また安全にアプリを管理するため、.envでAPI等を管理していきます。それぞれのAPIの取得については後述してあります。

.env
NOTION_API_URL=https://api.notion.com/v1/blocks
NOTION_VERSION=2022-06-28
NOTION_API_KEY=<Your Notion API Key>
NOTION_PAGE_IDS="
<Your Page ID>, 
<Your Page ID>, 
"

HUGGING_FACE_API_KEY=<Your Hugging Face API KEY>

Notion API

Notionからコンテンツを取得するに当たって、Notionのインテグレーショントークンの取得、また作成したインテグレーションをページとコネクトする必要があります。
これについては、以下の記事で丁寧に解説されているので、今回は割愛します。
https://temp.co.jp/blog/2024-01-21-notion-integration-connect

ページIDについては、取得したいNotionのURLより取り出せます。
notion.so/<タイトル>-<ページID>

コンテンツ取得

最初に、LLMで回答を生成するために必要な独自のドメイン知識をNotionから引っ張ってきます。
事前に.envで定義した、コンテンツを取得したいページのID (URLの一部)を読み込み、それぞれのページからデータを取得していきます。

流れとしては、

  1. ページ内のブロックを取得
  2. ブロックに子要素(リスト等々)がある場合、それらを取得
  3. テキストデータを正規化
  4. 取得したコンテンツをdictに追加
  5. コンテンツをjsonに書き込み

となります。

retrieve_data.py
import requests
import json
import unicodedata
import os
import sys
from dotenv import load_dotenv

def get_all_blocks(page_id, headers, notion_api_url):
    all_blocks = []
    url = f"{notion_api_url}/{page_id}/children"
    while url:
        response = requests.get(url, headers=headers)
        if response.status_code == 200:
            data = response.json()
            all_blocks.extend(data.get("results", []))
            
            if data.get("has_more"):
                url = f"{notion_api_url}/{page_id}/children?start_cursor={data.get('next_cursor')}"
            else:
                url = None
        else:
            print(f"Error: {response.status_code}, {response.text}")
            break
    return all_blocks

# テキストデータの正規化
def normalize_text_data(content_list):
    normalized_content_list = [unicodedata.normalize('NFKC', text).strip().lower() for text in content_list]   
    return normalized_content_list 

def get_page_title(page_id, headers):
    url = f"https://api.notion.com/v1/pages/{page_id}"
    response = requests.get(url, headers=headers)
    if response.status_code == 200:
        page_data = response.json()
        properties = page_data.get("properties", {})
        for prop in properties.values():
            if prop.get("type") == "title":
                title_texts = prop.get("title", [])
                return "".join([text["plain_text"] for text in title_texts])
        return "Untitled"
    else:
        print(f"Error: {response.status_code}, {response.text}")
        return None

def extract_text_from_blocks(blocks, headers, notion_api_url):
    content_list = []
    for block in blocks:
        block_type = block.get("type")
        
        # テキストブロック
        if block_type in ["paragraph", "heading_1", "heading_2", "heading_3", "bulleted_list_item", "numbered_list_item"]:
            rich_texts = block[block_type].get("rich_text", [])
            content_list.extend([text["text"]["content"] for text in rich_texts if "text" in text])
        
        # テーブルブロック
        elif block_type == "table":
            table_id = block["id"]
            child_blocks = get_all_blocks(table_id, headers, notion_api_url)
            for row in child_blocks:
                if row.get("type") == "table_row":
                    row_cells = row["table_row"]["cells"]
                    row_text = ["".join([cell["text"]["content"] for cell in cell_texts if "text" in cell]) for cell_texts in row_cells]
                    content_list.append("\t".join(row_text))  # セルをタブ区切りで結合
        
        # 子ブロック 
        if block.get("has_children"):
            child_id = block["id"]
            child_blocks = get_all_blocks(child_id, headers, notion_api_url)
            content_list.extend(extract_text_from_blocks(child_blocks, headers, notion_api_url))
    
    return content_list


def main():
    try:
        # .env
        load_dotenv()
        NOTION_API_KEY = os.getenv('NOTION_API_KEY')
        NOTION_API_URL = os.getenv('NOTION_API_URL')
        NOTION_VERSION = os.getenv('NOTION_VERSION')
        PAGE_IDS = os.getenv('NOTION_PAGE_IDS')
        PAGE_IDS_LIST = []
        if PAGE_IDS:
            PAGE_IDS_LIST = [page_id.strip() for page_id in PAGE_IDS.split(",") if page_id.strip()]
        else:
            print("NOTION_PAGE_IDS is not set.")

        # ヘッダー設定
        headers = {
            "Authorization": f"Bearer {NOTION_API_KEY}",
            "Notion-Version": NOTION_VERSION
        }

        all_content = {}
        for page_id in PAGE_IDS_LIST:

            # ページのタイトルを取得
            title = get_page_title(page_id, headers)

            # ページのすべてのブロックを取得
            blocks = get_all_blocks(page_id, headers, NOTION_API_URL)

            # ブロックからテキストを抽出
            page_content = extract_text_from_blocks(blocks, headers, NOTION_API_URL)
            normalized_page_content = normalize_text_data(page_content)

            # コンテンツを1つの文字列に結合
            full_text = "\n".join(normalized_page_content)

            # タイトルをキーにして保存
            all_content[title] = full_text
            print(f"Retrieved the contents from {title}")

        # 取得結果をJSON形式で保存
        with open("notion_contents.json", "w", encoding="utf-8") as file:
            json.dump(all_content, file, ensure_ascii=False, indent=4)

        print("Content saved to notion_contents.json")

    except Exception as e:
        print(f"Error: {e}")

# run
if __name__ == "__main__":
    main()

書き込まれたjsonは {"ページタイトル": "コンテンツ", ...}で保存されています。

{
    "RAG-Test": "connectome.designでは以下のような人材を募集しています\nコンサルタント\n機械学習・ai・iot等最新技術と社会における課題、企業における課題の解決に結びつけることができるコンサルタントを募集しています。ai/iot/機械学習の知識。一人でお客様と交渉し、相談に応じて解決策が提示出来る。 実はプログラミングも出来る。\nリサーチャー\n機械学習・人工知能・人工意識・人工生命等の研究開発を実施出来る研究者を募集しています。英語論文の読解と実装・実験が出来るスキル。一流論文誌・国際会議で採択される論文執筆力– 三度の飯より研究が好き。どうしてもネコ型ロボットを作りたい。\nソフトウエアエンジニア\n実用に耐えうるソフトウエアを超効率的に開発出来るエンジニアを募集しています。プログラミングに対する哲学を持っている。機械学習・ai・iotの知識を持っている。新しいこと・刺激的なことしかやりたくない。キーボードは友達。\n技術営業\nai案件の管理を担って頂く営業を募集しています。人とのコミュニケーションが得意。スケジュール管理が得意。議論を仕切るのが得意。"
}

ベクトル化

次に、取得したコンテンツを、分割&ベクトル化し、その結果をベクトルデータベースに保存していきます。

イメージとしては、

文字列 ベクトル形式
connectome 0.015325653563, 0.00553566422...
design 0.021049759643, 0.00873948384...
0.083064738471, 0.00474757573...
0.08026833566, 0.00423543565...
.. ..

のようになります。

このように、ベクトル化(数値化)することによって、ユーザーからの質問との距離(類似度)を計算できるようになり、最も関連性の高いコンテンツを特定し、回答の生成に活用します。

ベクトル化には、all-mpnet-base-v2モデルを、ベクトルデータベースにはおChromaを今回は使用していきます。

両方とも、HuggingFaceのAPIを使用して引っ張ってきます。
Hugging FaceのAPIの取得に関しては、以下の記事で丁寧に解説されています。
https://zenn.dev/protoout/articles/73-hugging-face-setup

また、ベクトル化を行う前に、テキストをDocument型に変換します。これにより、テキストデータとメタデータ(例: コンテンツのキーやソース情報)を統一フォーマットで管理できるようになったり、他にもLangChainの恩恵を受けられるようになります。

ここでの全体の流れとしては、以下になります。

  1. JSONデータの読み込み
  2. Document形式への変換
  3. テキストの分割
  4. 埋め込み生成 (ベクトル化)
  5. データベースへの保存
embedding.py
import json
import os
from dotenv import load_dotenv

from langchain_text_splitters import CharacterTextSplitter
from langchain.schema import Document  
from langchain_chroma import Chroma 
from langchain_huggingface import HuggingFaceEmbeddings


def load_json_data(file_path):
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")
    with open(file_path, "r", encoding="utf-8") as f:
        return json.load(f)

def create_documents_from_json(json_data):
    documents = []
    for key, value in json_data.items():
        if isinstance(value, str):
            # オブジェクトを作成 & ページタイトルをメタデータとして追加 
            documents.append(Document(page_content=value, metadata={"key": key}))
        else:
            print(f"Unexpected data type for key {key}: {type(value)}")
    return documents

def clean_text(text):
    return " ".join(line.strip() for line in text.splitlines() if line.strip())

def split_and_clean_documents(documents, chunk_size=1000, chunk_overlap=10, separator='\n'):
    text_splitter = CharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    split_docs = text_splitter.split_documents(documents)
    cleaned_docs = [
        Document(
            page_content=clean_text(doc.page_content),
            metadata=doc.metadata  
        )
        for doc in split_docs
    ]
    return cleaned_docs

def create_and_persist_chroma_db(cleaned_docs, embedding_model_name, persist_directory):
    # 埋め込み関数を作成
    embedding_function = HuggingFaceEmbeddings(model_name=embedding_model_name)
    
    # ディレクトリを作成(存在しない場合)
    if not os.path.exists(persist_directory):
        os.makedirs(persist_directory)
    
    # Chroma データベースを作成
    db = Chroma.from_documents(
        cleaned_docs,
        embedding_function,
        persist_directory=persist_directory,
    )
    # db.persist()
    print(f"Database saved successfully to disk at {persist_directory}")


def main():
    try:
        # .env
        load_dotenv()
        HUGGING_FACE_API_KEY = os.getenv('HUGGING_FACE_API_KEY')

        if HUGGING_FACE_API_KEY:
            os.environ["HUGGINGFACEHUB_API_TOKEN"] = HUGGING_FACE_API_KEY
            print("Hugging Face API Key set successfully.")
        else:
            raise ValueError("HUGGING_FACE_API_KEY is not set in the .env file.")

        PERSIST_DIRECTORY_PATH = "./chroma_db"
        CONTENTS_JSON_PATH = "notion-api/notion_contents.json"

        # JSONファイルを読み込む
        json_data = load_json_data(CONTENTS_JSON_PATH)
        
        # JSONデータからDocumentを作成
        documents = create_documents_from_json(json_data)

        # DEBUG:
        # for doc in documents:
        #     print(f"Document content: {doc.page_content}")
        
        # ドキュメントを分割してクリーンアップ
        cleaned_docs = split_and_clean_documents(documents)
        
        # Chromaデータベースを作成して永続化
        create_and_persist_chroma_db(
            cleaned_docs,
            embedding_model_name="all-mpnet-base-v2",
            persist_directory=PERSIST_DIRECTORY_PATH,
        )
    except Exception as e:
        print(f"Error: {e}")


if __name__ == "__main__":
    main()

回答生成

ここでは、ユーザの質問を同じくベクトル化し、先ほど作成したベクトルデータベースの中から似ているデータを取得して、それを踏まえLLMで回答を生成していきます。
LLMもHuggingFaceからAPIで呼び出せるgoogle/gemma-2b-itモデルを使用していきます。

ここではいくつかのポイントに分けて解説していきます。

  • 関連データの取得
retriever = self.db.as_retriever(search_type="mmr", search_kwargs={'k': 4, 'fetch_k': 20})

fetch_k: 関連性の高い20件のページをまず検索し、その中から最も関連性の高い4件をユーザーに提示します。
k: 検索結果の上位4件を取得。
kの値が大きいと、関連性のないデータも含まれてきます。また、逆に少なすぎても関連性のあるデータを全て取得できないなどの問題があります。

  • temperature
self.repo_id = "google/gemma-2b-it"
self.llm = HuggingFaceEndpoint(
    repo_id=self.repo_id, max_length=1024, temperature=0.4, timeout=500
)

tempeature: モデルの生成するテキストの多様性や創造性を調整するパラメータ。
temperatureが0に近いほど、最も確率の高い単語のみを表示しますが、創造性に欠けます。また、1に近いほど確率の低い単語も表示し、創造的な回答になります。

  • プロンプト
# RAGプロンプト
self.prompt = PromptTemplate(
    template=(
        "You are a helpful assistant. Use the following context to answer the question.\n\n"
        "Context:\n{context}\n\n"
        "Question:\n{question}\n\n"
        "Answer with full detail and explanation:"
    )
)

# RAG対話型プロンプト
# self.prompt = ChatPromptTemplate.from_messages([
#     ("system", "You are a helpful assistant. Use the following context to answer the question."),
#     ("human", "Context: {context}. Question: {question}. Please answer with full detail and explanation:")
# ])

一貫した回答を生成するために、LangChainのPromptTemplateを使用して、LLMに渡す入力をテンプレート化します。
LangChainでは、複数のテンプレートが用意されていますが、ここでは主にPromptTemplatechatPromptTemplateの違いについてサラッと触れます。

PromptTemplate: 主に単純な質問応答や命令生成用。ロールは不要。
ChatPromptTemplate: 主に対話型アプリケーション用。system, human, assistantロールが必要。

回答生成のコード全体は以下になります。
また、のちに実装するUIからも呼び出せるようクラス化しておきます。

llm.py
import os
from dotenv import load_dotenv
from langchain_huggingface import HuggingFaceEndpoint, HuggingFaceEmbeddings
from langchain_chroma import Chroma 
from langchain.prompts import (
    PromptTemplate,
    ChatPromptTemplate,
)
from langchain_text_splitters import CharacterTextSplitter
from langchain_core.runnables import RunnablePassthrough


class RAGApp:
    def __init__(self, temperature=0.5):
        # env
        load_dotenv()
        self.HUGGING_FACE_API_KEY = os.getenv('HUGGING_FACE_API_KEY')

        if self.HUGGING_FACE_API_KEY:
            os.environ["HUGGINGFACEHUB_API_TOKEN"] = self.HUGGING_FACE_API_KEY
            print("Hugging Face API Key set successfully.")
        else:
            raise ValueError("HUGGING_FACE_API_KEY is not set in the .env file.")
        
        # HuggingFace Endpoint
        self.repo_id = "google/gemma-2b-it"
        self.llm = HuggingFaceEndpoint(
            repo_id=self.repo_id, max_length=1024, temperature=temperature, timeout=500
        )

        # Embedding function and database
        self.embedding_function = HuggingFaceEmbeddings(model_name="all-mpnet-base-v2")
        self.db = self.initialize_database()

        # RAGプロンプト
        self.prompt = PromptTemplate(
            template=(
                "You are a helpful assistant. Use the following context to answer the question.\n\n"
                "Context:\n{context}\n\n"
                "Question:\n{question}\n\n"
                "Answer with full detail and explanation:"
            )
        )

        # RAG対話型プロンプト
        # self.prompt = ChatPromptTemplate.from_messages([
        #     ("system", "You are a helpful assistant. Use the following context to answer the question."),
        #     ("human", "Context: {context}. Question: {question}. Please answer with full detail and explanation:")
        # ])
    
    def initialize_database(self, persist_directory="./chroma_db"):
        try:
            return Chroma(
                persist_directory=persist_directory,
                embedding_function=self.embedding_function,
            )
        except TypeError:
            print("Falling back to deprecated Chroma class.")
            from langchain_community.vectorstores import Chroma as DeprecatedChroma
            return DeprecatedChroma(
                persist_directory=persist_directory,
                embedding_function=self.embedding_function,
            )

    def __format_docs__(self, docs):
        if not docs:
            return "No relevant documents found."
        return "\n\n".join(doc.page_content for doc in docs)

    def __format_rag_input__(self, context, question):
        return self.prompt.format(context=context, question=question)

    def get_response(self, question):

        #-------- For testing (no RAG) ----
        # template = """Question: {question}
        # Answer: Answer with full detail and explanation"""
        # prompt = PromptTemplate.from_template(template)
        # llm_chain = (
        #     RunnablePassthrough() 
        #     | self.llm
        # )
        # result = llm_chain.invoke(question)
        #----------------------------------

        rag_chain = (
            RunnablePassthrough()  # Pass the formatted string directly
            | self.llm  # The HuggingFaceEndpoint processes the input string
        )

        # ドキュメントより関連データを取得
        retriever = self.db.as_retriever(search_type="mmr", search_kwargs={'k': 4, 'fetch_k': 20})
        docs = retriever.invoke(question)
        formatted_context = self.__format_docs__(docs)  # Format the documents into a context string
        
        # プロンプト生成
        formatted_input = self.__format_rag_input__(context=formatted_context, question=question)
        result = rag_chain.invoke(formatted_input)

        # 生成された回答から、不必要なコメントを削除
        if "The context explains that" in result:
            cleaned_response = result.replace("The context explains that ", "")
            return cleaned_response

        return result

def main():
    # インスタンス作成
    rag_app = RAGApp()

    # ユーザー入力
    question = input("Enter your question:")
    response = rag_app.get_response(question)
    print("Answer:")
    print(response)


if __name__ == "__main__":
    main()

上記のファイルをターミナルで実行し、質問を入力すると、結果が返ってきます。

python llm.py

Enter your question:connectome.design株式会社は、どんな人材を募集していますか
Number of requested results 20 is greater than number of elements in index 5, updating n_results = 5
Answer:

connectome.designは、コンサルタント、リサーチャー、ソフトウェアエンジニア、技術営業の4つの主要な人材を募集しています。

UI

せっかくなので、webアプリを作成できるstreamlitを使用してフロントエンドも整えます。

ui/main.py
import sys
import os
import streamlit as st
from streamlit_option_menu import option_menu  

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

from llm import RAGApp

# ユーザインターフェース
st.set_page_config(page_title="RAG Desktop App", layout="wide")

# サイドバー
with st.sidebar:
    selected = option_menu(
        menu_title="Menu",
        options=["Chat", "Settings"],
        icons=["chat", "gear"],  # Font Awesome
        menu_icon="menu-down",
        default_index=0,
    )

# Initialize session state
if "temperature" not in st.session_state:
    st.session_state.temperature = 0.3  # default temperature

# チャットレイアウト
if selected == "Chat":
    st.title("Chat with RAG Bot")
    
    # ユーザ入力
    st.markdown("### Enter your message:")
    user_input = st.text_area(
        label="",
        placeholder="Type your question here...",
        height=100,  
    )

    if st.button("Send"):
        if user_input.strip():
            # スピナー
            with st.spinner("Generating response..."):

                # DEBUG:
                print(f"##### Passing temp to RAG: {st.session_state.temperature} #####")

                # RAGApp
                rag_app = RAGApp(temperature=st.session_state.temperature)
                response = rag_app.get_response(user_input)

            st.markdown("### Bot's response:")
            st.markdown(response, unsafe_allow_html=True)  

# 設定レイアウト
elif selected == "Settings":
    st.title("Settings")
    st.write("Here you can configure the app settings.")
    st.write("#### Adjust the creativity of responses:")
    # Temperature スライダー
    st.session_state.temperature = st.slider(
        "Temperature (1-10):",  # ラベル
        min_value=0.1,          # 最小値
        max_value=1.0,          # 最大値
        value=st.session_state.temperature,  # デフォルト temperature
        step=0.1                # ステップ
    )

アプリを立ち上げる際は、streamlitのコマンドを使用します。

streamlit run ui/main.py

課題

一通りLangChainを使用したRAGアプリを開発してみましたが、LangChainには会話履歴を保存して応答生成に再利用できる機能もあることが分かりました。今後は、この機能も追加で実装していきたいです。
また、パラメータの変更等無しで、もっと手軽に利用したい場合はstreamlitより、slackからチャットボットを呼び出せるようにしたほうが合ってそうです。

CODブログ

Discussion