👋

Amazon Bedrockを活用したRAGアプリケーションの構築

2024/08/30に公開

Amazon Bedrockを活用したRAGアプリケーションの構築手順

こんにちは。
Amazon Bedrockを使うと簡単にRAGアプリが作れるらしい...」ということで、以下のAmazon Bedrockの書籍をベースにして実際にRAGアプリを作ってみたので、内容を簡単に共有します。

Amazon Bedrock 生成AIアプリ開発入門

細かい説明は割愛していますので、気になる方は上記の書籍を参考にしてもらえるとよいと思います。

1. Amazon Bedrockとは?

Amazon Bedrockは、AWSが提供する最新の生成AI基盤サービスです。複数のファウンデーションモデルを簡単に利用できるようにし、AIアプリケーションの開発を支援します。
今回は、このBedrockの機能の一つである「Knowledge bases for Amazon Bedrock」を使用して、GUIの操作だけで簡単にRAGシステムを構築する方法を見ていきます。

最後に、Djangoを使ってRAGアプリを開発してみたので実際のデモもご覧ください。

2. RAGアプリケーションの構築手順

Amazon Bedrockの「Knowledge bases for Amazon Bedrock」を使うとGUIベースでポチポチするだけでRAG機能が簡単に作れます。

AmazonのS3にPDFファイルをアップロードして、そのデータをソースとしたRAGを構築する場合の構築ステップを図に表すと以下のような5ステップで簡単に作れます。

注意点として、デフォルト設定で作るとAmazon OpenSearch Serverless でベクトルデータベースが作られますが、これが1日5ドル以上かかるらしいので個人利用にはちょっとコストが高すぎるので注意が必要のようです。

2.1 S3バケットの作成とデータの準備

まず、RAGシステムのナレッジベースとなるデータを保存するためのS3バケットを作成します。

  1. AWSコンソールからS3サービスにアクセスし、新しいバケットを作成します。
  2. 作成したバケットに、ナレッジベースとして使用するPDFファイルをアップロードします。

2.2 ナレッジベースの作成

次に、Amazon Bedrockサービス内でナレッジベースを作成します。

  1. AWSコンソールからAmazon Bedrockサービスにアクセスします。

  2. 「ナレッジベース」セクションから「ナレッジベースを作成」を選択します。
    ナレッジベースの作成

  3. データソースの設定を行います。先ほど作成したS3バケットを指定します。
    データソースの設定

  4. 埋め込みモデルを選択し、ベクトルストアを設定します。
    埋め込みモデルとベクトルストアの設定

  5. ナレッジベースを作成し、S3のデータを同期します。

RAG基盤の構築自体はこれだけで簡単に作れてしまいます。
ホントにGUI操作でポチポチやるだけですね。

この状態でプレイグラウンド上ですぐにRAG機能を試すことができます(下図)。

2.3 フロントエンドの開発

参考書籍ではStreamlitでフロントエンドアプリを開発していますが、せっかくなので今回はDjangoを使用してチャットアプリを開発してみました。

以下に、主要なコードを掲載してありますので、Djangoに組み込めば同じアプリが作れると思います。

プロジェクト構成

C:.
│  db.sqlite3
│  manage.py
│
├─config
│  │  asgi.py
│  │  settings.py
│  │  urls.py
└─rag_app
    │  admin.py
    │  apps.py
    │  models.py
    │  tests.py
    │  utils.py
    │  views.py
    │  __init__.py
    ├─templates
    │  └─rag_app
    │          index.html

urls.py

from django.contrib import admin
from django.urls import path
from rag_app import views

urlpatterns = [
    path('admin/', admin.site.urls),
    path('', views.index, name='index'),
    path('ask-question/', views.ask_question, name='ask_question'),
]

views.py

from django.shortcuts import render
from django.http import StreamingHttpResponse
from django.http import JsonResponse
from django.views.decorators.csrf import csrf_exempt
from .utils import get_bedrock_streaming_response
import json

def index(request):
    return render(request, 'rag_app/index.html')

@csrf_exempt
def ask_question(request):
    if request.method == 'POST':
        data = json.loads(request.body)
        question = data.get('question', '')
        return StreamingHttpResponse(get_bedrock_streaming_response(question), content_type='text/event-stream')
    return StreamingHttpResponse('Invalid request method', status=400)

utils.py

from langchain_aws import ChatBedrock
from langchain_aws.retrievers import AmazonKnowledgeBasesRetriever
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
import boto3
import json
import chardet

def get_full_document_content(s3_uri):
    parts = s3_uri.replace("s3://", "").split("/")
    bucket = parts[0]
    key = "/".join(parts[1:])

    s3 = boto3.client('s3')

    try:
        response = s3.get_object(Bucket=bucket, Key=key)
        content = response['Body'].read()
        
        encoding = chardet.detect(content)['encoding']
        decoded_content = content.decode(encoding)
        return decoded_content
    except Exception as e:
        print(f"Error retrieving document from S3: {e}")
        return None

def get_bedrock_streaming_response(question):
    session = boto3.Session(
        aws_access_key_id='AWSのアクセスキーを指定',
        aws_secret_access_key='アクセスキーのシークレットキーを指定',
        region_name='リージョンを指定'
    )

    retriever = AmazonKnowledgeBasesRetriever(
        knowledge_base_id="bedrockナレッジベースのIDをここに設定する‘‘",
        retrieval_config={"vectorSearchConfiguration": {"numberOfResults": 10}},
        aws_session=session
    )

    prompt = ChatPromptTemplate.from_template(
        "以下のcontextに基づいて回答してください: {context} / 質問: {question}"
    )

    model = ChatBedrock(
        model_id="anthropic.claude-3-sonnet-20240229-v1:0",
        model_kwargs={"max_tokens": 1000},
        streaming=True
    )

    chain = (
        {"context": retriever, "question": RunnablePassthrough()}
        | prompt
        | model
        | StrOutputParser()
    )

    context = retriever.invoke(question)

    debug_info = {
        "retrieved_documents": [
            {
                "content": doc.page_content,
                "metadata": doc.metadata
            } for doc in context
        ]
    }

    yield f"data: {json.dumps({'type': 'debug_info', 'content': debug_info})}\n\n"

    for chunk in chain.stream(question):
        yield f"data: {json.dumps({'type': 'content', 'text': chunk})}\n\n"

    yield "data: [DONE]\n\n"

index.html

<!DOCTYPE html>
<html lang="ja">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Bedrock RAG Assistant</title>
    <link href="https://cdnjs.cloudflare.com/ajax/libs/tailwindcss/2.2.19/tailwind.min.css" rel="stylesheet">
    <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.15.3/css/all.min.css">
    <style>
        :root {
            --accenture-purple: #A100FF;
            --accenture-black: #000000;
            --accenture-gray: #484B4D;
        }
        
        body {
            font-family: 'Graphik', 'Helvetica Neue', Arial, sans-serif;
            background-color: #F0F2F5;
            color: var(--accenture-black);
        }
        
        .container {
            max-width: 1000px;
        }
        
        #answer, #debugInfo, #loading {
            transition: all 0.3s ease-in-out;
        }

        #debugInfoContent {
            max-height: 600px;
            overflow-y: auto;
        }

        .document-item {
            background-color: #FFFFFF;
            border-radius: 12px;
            box-shadow: 0 4px 10px rgba(0,0,0,0.1);
            margin-bottom: 1.5rem;
            overflow: hidden;
            border: 2px solid var(--accenture-purple);
            transition: all 0.3s ease;
        }

        .document-item:hover {
            box-shadow: 0 6px 15px rgba(161, 0, 255, 0.2);
            transform: translateY(-2px);
        }

        .document-header {
            cursor: pointer;
            padding: 1.25rem;
            background-color: #F8F0FF;
            display: flex;
            justify-content: space-between;
            align-items: center;
            transition: background-color 0.3s ease;
        }

        .document-header:hover {
            background-color: #F0E0FF;
        }

        .document-content {
            display: none;
            padding: 1.25rem;
        }

        .document-content.show {
            display: block;
        }

        .full-text {
            max-height: 300px;
            overflow-y: auto;
            background-color: #2D2D2D;
            color: #FFFFFF;
            padding: 1.25rem;
            border-radius: 8px;
            margin-top: 0.75rem;
            font-size: 0.95rem;
            line-height: 1.7;
        }

        .metadata-item {
            display: inline-block;
            padding: 0.4rem 0.8rem;
            border-radius: 20px;
            font-size: 0.85rem;
            font-weight: 600;
            margin-right: 0.75rem;
            margin-bottom: 0.75rem;
        }

        .metadata-item-score {
            background-color: #FFE082;
            color: #000000;
        }

        .metadata-item-chunk {
            background-color: #81D4FA;
            color: #000000;
        }

        .accenture-container {
            background-color: #FFFFFF;
            box-shadow: 0 6px 15px rgba(0,0,0,0.1);
            border: none;
            border-radius: 16px;
            overflow: hidden;
            transition: all 0.3s ease;
        }

        .accenture-container:hover {
            box-shadow: 0 8px 20px rgba(161, 0, 255, 0.15);
        }

        .accenture-header {
            background-color: var(--accenture-purple);
            color: #FFFFFF;
            padding: 1.25rem 1.5rem;
            font-size: 1.4rem;
            font-weight: bold;
        }

        .btn-accenture {
            background-color: var(--accenture-purple);
            transition: all 0.3s ease;
            font-size: 1.1rem;
            font-weight: 600;
            letter-spacing: 0.5px;
            padding: 0.75rem 2rem;
            border-radius: 30px;
        }

        .btn-accenture:hover {
            background-color: #8200CC;
            transform: translateY(-2px);
            box-shadow: 0 4px 15px rgba(161, 0, 255, 0.3);
        }

        .accent-border {
            border-bottom: 3px solid var(--accenture-purple);
        }

        #question {
            border: 2px solid #E0E0E0;
            transition: all 0.3s ease;
        }

        #question:focus {
            border-color: var(--accenture-purple);
            box-shadow: 0 0 0 3px rgba(161, 0, 255, 0.2);
        }

        @keyframes spin {
            0% { transform: rotate(0deg); }
            100% { transform: rotate(360deg); }
        }

        .animate-spin {
            animation: spin 1s linear infinite;
        }

        .typing-indicator::after {
            content: '|';
            animation: blink 0.7s infinite;
        }

        @keyframes blink {
            0% { opacity: 0; }
            50% { opacity: 1; }
            100% { opacity: 0; }
        }

        #loading {
            position: fixed;
            top: 50%;
            left: 50%;
            transform: translate(-50%, -50%);
            z-index: 1000;
            text-align: center;
        }

        #loading p {
            margin-top: 1rem;
            font-weight: 600;
            color: var(--accenture-purple);
        }
    </style>
</head>
<body class="bg-gray-50">
    <div class="container mx-auto px-4 py-16">
        <h1 class="text-5xl font-bold text-gray-900 mb-16 text-center accent-border pb-6">Bedrock RAG Assistant</h1>
        <div class="accenture-container mb-12">
            <div class="accenture-header">質問入力</div>
            <div class="p-8">
                <textarea id="question" rows="4" class="w-full px-5 py-4 border rounded-xl focus:outline-none transition duration-300 resize-none text-lg"></textarea>
                <div class="flex justify-center mt-6">
                    <button id="askButton" class="btn-accenture text-white shadow-lg">
                        <i class="fas fa-question-circle mr-2"></i>質問する
                    </button>
                </div>
            </div>
        </div>
        <div id="answer" class="accenture-container mb-12 hidden">
            <div class="accenture-header">回答</div>
            <div class="p-8">
                <p id="answerText" class="text-gray-700 leading-relaxed text-lg"></p>
            </div>
        </div>
        <div id="debugInfo" class="accenture-container hidden">
            <div class="accenture-header">参照元文書情報</div>
            <div id="debugInfoContent" class="p-8"></div>
        </div>
    </div>
    <div id="loading" class="hidden">
        <div class="inline-block animate-spin rounded-full h-20 w-20 border-t-4 border-b-4 border-purple-500"></div>
        <p class="mt-4 text-xl font-semibold text-purple-600">回答を生成中です...</p>
    </div>
    <script src="https://cdnjs.cloudflare.com/ajax/libs/axios/0.21.1/axios.min.js"></script>
    <script>
        document.addEventListener('DOMContentLoaded', () => {
            const questionInput = document.getElementById('question');
            const askButton = document.getElementById('askButton');
            const answerDiv = document.getElementById('answer');
            const answerText = document.getElementById('answerText');
            const debugInfoDiv = document.getElementById('debugInfo');
            const debugInfoContent = document.getElementById('debugInfoContent');
            const loadingDiv = document.getElementById('loading');

            function createDocumentElement(doc, index) {
                const docElement = document.createElement('div');
                docElement.className = 'document-item';
                docElement.innerHTML = `
                    <div class="document-header">
                        <h3 class="text-xl font-semibold text-gray-800">
                            <i class="fas fa-file-alt mr-3 text-purple-600"></i>参照元 ${index + 1}
                        </h3>
                        <i class="fas fa-chevron-down text-gray-600"></i>
                    </div>
                    <div class="document-content">
                        <div class="mb-4">
                            <span class="metadata-item metadata-item-score"><strong>関連性スコア:</strong> ${doc.metadata.score.toFixed(4)}</span>
                            <span class="metadata-item metadata-item-chunk"><strong>チャンクID:</strong> ${doc.metadata.source_metadata['x-amz-bedrock-kb-chunk-id']}</span>
                        </div>
                        <p class="text-sm text-gray-700 mb-4"><strong>ソース:</strong> ${doc.metadata.source_metadata['x-amz-bedrock-kb-source-uri']}</p>
                        <div>
                            <h4 class="text-lg font-semibold text-gray-800 mb-3">参照テキスト全文:</h4>
                            <div class="full-text">
                                ${doc.content}
                            </div>
                        </div>
                    </div>
                `;
                return docElement;
            }

            function toggleDocumentContent(header) {
                const content = header.nextElementSibling;
                content.classList.toggle('show');
                const icon = header.querySelector('i.fas');
                icon.classList.toggle('fa-chevron-down');
                icon.classList.toggle('fa-chevron-up');
            }

            function formatAnswer(text) {
                const sentences = text.match(/[^。!?]+[。!?]/g) || [];
                let formattedText = '';
                let lineCount = 0;

                for (let i = 0; i < sentences.length; i++) {
                    formattedText += sentences[i];
                    lineCount++;

                    if (lineCount === 5 || (lineCount === 1 && sentences[i].length > 50)) {
                        formattedText += '<br><br>';
                        lineCount = 0;
                    }
                }

                return formattedText.trim();
            }

            askButton.addEventListener('click', async () => {
                const question = questionInput.value.trim();
                if (question) {
                    try {
                        answerDiv.classList.add('hidden');
                        debugInfoDiv.classList.add('hidden');
                        loadingDiv.classList.remove('hidden');
                        answerText.innerHTML = '';
                        answerText.classList.add('typing-indicator');
                        
                        const response = await fetch('/ask-question/', {
                            method: 'POST',
                            headers: {
                                'Content-Type': 'application/json',
                            },
                            body: JSON.stringify({ question: question }),
                        });

                        const reader = response.body.getReader();
                        const decoder = new TextDecoder();
                        let isFirstContent = true;
                        let fullAnswer = '';

                        while (true) {
                            const { done, value } = await reader.read();
                            if (done) break;

                            const chunk = decoder.decode(value);
                            const lines = chunk.split('\n');

                            for (const line of lines) {
                                if (line.startsWith('data: ')) {
                                    const data = line.slice(6);
                                    if (data === '[DONE]') {
                                        break;
                                    }

                                    try {
                                        const parsedData = JSON.parse(data);
                                        if (parsedData.type === 'debug_info') {
                                            debugInfoContent.innerHTML = '';
                                            parsedData.content.retrieved_documents.forEach((doc, index) => {
                                                const docElement = createDocumentElement(doc, index);
                                                debugInfoContent.appendChild(docElement);
                                            });
                                            debugInfoDiv.classList.remove('hidden');

                                            document.querySelectorAll('.document-header').forEach(header => {
                                                header.addEventListener('click', () => toggleDocumentContent(header));
                                            });
                                        } else if (parsedData.type === 'content') {
                                            if (isFirstContent) {
                                                loadingDiv.classList.add('hidden');
                                                answerDiv.classList.remove('hidden');
                                                isFirstContent = false;
                                            }
                                            fullAnswer += parsedData.text;
                                            answerText.innerHTML = formatAnswer(fullAnswer);
                                        }
                                    } catch (e) {
                                        console.error('Error parsing chunk:', e);
                                    }
                                }
                            }
                        }
                    } catch (error) {
                        console.error('Error:', error);
                        alert('エラーが発生しました。もう一度お試しください。');
                    } finally {
                        loadingDiv.classList.add('hidden');
                        answerText.classList.remove('typing-indicator');
                    }
                }
            });

            questionInput.addEventListener('keypress', (event) => {
                if (event.key === 'Enter' && !event.shiftKey) {
                    event.preventDefault();
                    askButton.click();
                }
            });
        });
    </script>
</body>
</html>

utils.pyがamazon bedrockのRAG基盤にリクエストを投げている部分です。
書籍のコードをベースに少しカスタマイズしてありますが、このコードの主な流れは以下の通りです。

  1. S3からドキュメントの内容を取得する関数 get_full_document_content を定義。
  2. Bedrockを使用してストリーミングレスポンスを生成する関数 get_bedrock_streaming_response を定義。
  3. AmazonKnowledgeBasesRetrieverを使用して、質問に関連する文脈を取得。
  4. 取得した文脈と質問を組み合わせてプロンプトを作成。
  5. Claude 3 Sonnetモデルを使用して回答を生成。
  6. 生成された回答をストリーミング形式で返却。

以下のようなDjango&Amazon BedrockのRAGアプリが完成しました。
ちなみに、このDjangoアプリ自体はClaudeを活用してコードはほぼ自動生成してもらい1時間ちょっとで完成しました。

3. まとめ

Amazon Bedrockを使用することで、比較的簡単にRAGアプリケーションを構築できることがわかりました。RAG機能自体はGUI操作で簡単に作れるので、実際に動くRAGをサクッと作ってプレイグラウンド上で動作を検証してみたいという場合にはものすごくとっつき易いかなと思います。

今回の様に、スクラッチでWEBアプリ化するにはある程度開発スキルが必要ですが、Amazon Q Businessを利用すると生成AI搭載のチャットボットがすうステップで簡単に作れるようです。

Amazon Bedrockは比較的新しいサービスですが、今後ますます発展していくことが予想されます。AIアプリケーションの開発に興味がある方は、ぜひ一度試してみてください。

Accenture Japan (有志)

Discussion