🌿

Spring AIでLangChainのクックブックにあるRAGをなぞる

2024/03/10に公開

概要

LangCainのクックブックにRAGがあるので、その内容をSpring AIでなぞってみました。

前提

検索から生成の流れ

LangChainではLCELというDSLを使って処理の流れを次のように書けます。

chain = (
    {"context": retriever, "question": RunnablePassthrough()}
    | prompt
    | model
    | StrOutputParser()
)

当初、処理の流れはわかるけれどパイプ演算が何をやっているのかがわからずモヤモヤしてたので、まずはそこを理解することにしました。

Pythonは__or__メソッドと__ror__メソッドを定義することでパイプ演算の動作を定義できます(__ror__メソッドは右辺がレシーバーになります)。

LCELはRunnableというクラスが処理の単位になりますが、このRunnableクラスに__or__メソッドと__ror__メソッドが定義されているため、パイプ演算で処理の流れを構築できます。

つまり、前述のLCELで書かれたコードは次のコードと等価です。

chain = (
    prompt
        .__ror__({"context": retriever, "question": RunnablePassthrough()})
        .__or__(model)
        .__or__(StrOutputParser())
)

なお、__or__メソッドと__ror__メソッドはRunnable以外にも引数で受け取れる型があり、内部で適用されているcoerce_to_runnable関数によって次のように変換されます。

引数の型 変換後の型
ジェネレーター関数 RunnableGenerator
関数(呼び出し可能オブジェクト) RunnableLambda
辞書 RunnableParallel

そのため__ror__メソッドに渡されている{"context": retriever, "question": RunnablePassthrough()}は内部でRunnableParallelへ変換されます。

ここまで理解してようやくスッキリしました。

Javaは独自にパイプ演算を定義できないので、愚直にコードを書いていきます。

ベクトルストアの構築

LangChainのクックブックではFaissというMeta社製のベクトル検索が行えるライブラリーを使用してベクトルストアを構築しています。
OpenAIを使用して"harrison worked at kensho"という内容のドキュメントをベクトル化してストアへ持たせています。

vectorstore = FAISS.from_texts(
    ["harrison worked at kensho"], embedding=OpenAIEmbeddings()
)

これとほぼ同様のことをSpring AIで行うのが次のコードです。

@Bean
SimpleVectorStore simpleVectorStore(EmbeddingClient embeddingClient) {
    SimpleVectorStore vectorStore = new SimpleVectorStore(embeddingClient);
    vectorStore.add(List.of(new Document("harrison worked at kensho")));
    return vectorStore;
}

SimpleVectorStoreは外部ライブラリーを用いずスクラッチで書かれたシンプルなベクトルストアです[1]

SimpleVectorStoreを構築するにはEmbeddingClientが必要ですが、依存関係にspring-ai-openai-spring-boot-starterを追加しているのでOpenAiEmbeddingClientがインジェクションされます。

検索

検索はVectorStoresimilaritySearchメソッドで行います。

List<Document> docs = vectorStore.similaritySearch(question);
Iterator<Document> iter = docs.iterator();
if (!iter.hasNext()) {
    // 検索に何もヒットしなかった
    return "I do not know.";
}

similaritySearchメソッドはデフォルトだと類似度の閾値は設定されず、上位4件までドキュメントを返します。
この辺りのパラメーターはSearchRequestクラスで細かく設定できます。

なお、docsIteratorにしているのは、あとで1件目の結果が欲しいからです[2]

生成

検索結果をコンテキストとし、質問と合わせてプロンプトを構築してChatClientcallメソッドへ渡しています。
callメソッドからは生成されたテキストが返されます。

String context = iter.next().getContent();
String prompt = """
        Answer the question based only on the following context:
        %2$s

        Question: %1$s
        """.formatted(question, context);
String answer = chatClient.call(prompt);

検索のときと同様に生成も細かくパラメーター設定が可能です。
生成の場合はPromptクラスが持つChatOptionsで設定します[3]

なお、ChatClientの実装クラスはOpenAiChatClientです。

まとめ

以上でLangChainのクックブックにあるRAGをSpring AIでなぞれました。
最もシンプルな例ではありますが、Spring AIを使うことで簡単にRAGできることがわかりました。
Spring AIの今後の機能拡充も楽しみです。

ソースコード

Spring Web MVCでHTTPエンドポイントを作成し、curlで動作確認できるようにしたものです。

https://github.com/backpaper0/spring-ai-example/tree/blog/rag

脚注
  1. とりあえず動かしたいときにこういう前準備なくすぐに使える部品があると嬉しくて、さすがSpringだなと思いました。 ↩︎

  2. 1件目を取得するためにdocs.get(0)するのではなくIteratorを使用するのは私の好みです。 ↩︎

  3. OpenAI用のChatOptions実装クラスはOpenAiChatOptionsです。 ↩︎

Discussion