🤖

LangChain4jのStream処理をFs2.Streamで扱う

2024/12/23に公開

前書き

この記事は
Scala Advent Calendar 2024
asakatsu Advent Calendar 2024
の 24日目の記事です.

本編

Javaで時折このようなinterfaceの利用を強制されるときがあるので備忘録です。
reactive-streamsもこんな感じらしいです。

LangChain4j

LangChain4jはGPTなどLLM系のWrapperライブラリです
https://docs.langchain4j.dev/tutorials/response-streaming/

public interface StreamingResponseHandler<T> {
    void onNext(String token);
    default void onComplete(Response<T> response)
    void onError(Throwable error);
}
  • token毎にonNextが実行される
  • 全て処理されたときにonCompleteが実行される
  • エラーが発生したときにonErrorが実行される

といったシンプルなinterfaceです。

しかし、利用例を見ていると、
model.generateの戻り値もvoidで、StreamingResponseHandlerの各々の戻り値もvoidなのでどうしようと悩んでしまいます。

StreamingChatLanguageModel model = OpenAiStreamingChatModel.builder()
    .apiKey(System.getenv("OPENAI_API_KEY"))
    .modelName(GPT_4_O_MINI)
    .build();

String userMessage = "Tell me a joke";

model.generate(userMessage, new StreamingResponseHandler<AiMessage>() {

    @Override
    public void onNext(String token) {
        System.out.println("onNext: " + token);
    }

    @Override
    public void onComplete(Response<AiMessage> response) {
        System.out.println("onComplete: " + response);
    }

    @Override
    public void onError(Throwable error) {
        error.printStackTrace();
    }
});

それぞれ戻り値の型がvoidだと、工夫しない限りsub processで処理されてしまいます。(バックグラウンドで Stream が処理されて main processが終了してしまう)

main processでクライアントにStreamを返却してChatGPTのようにリアルタイムで文字の生成が行われるような処理を実装できるようにするのがこの記事のゴールです。

依存

  • scala 3.5.0
  • "co.fs2" %% "fs2-core" % "3.11.0"
  • "dev.langchain4j" % "langchain4j" % "0.34.0"
  • "dev.langchain4j" % "langchain4j-open-ai" % "0.34.0"

実装

CatsEffectのQueueを使って簡単に解決!
以下を見てください!みんなが思い描いているQueueです!
QueueはStreamに変換することができます!(Stream.fromQueueNoneTerminated(queue))
https://typelevel.org/cats-effect/docs/std/queue

実装は以下にしたんで解釈してください!

import cats.effect.*
import cats.effect.std.Queue
import dev.langchain4j.data.message.AiMessage
import dev.langchain4j.model.StreamingResponseHandler
import dev.langchain4j.model.openai.{
  OpenAiChatModelName,
  OpenAiStreamingChatModel
}
import dev.langchain4j.model.output.Response
import fs2.*

import cats.effect.unsafe.implicits.global

class GPTAdapterImplFs2Stream {
  val apiKey = ""

  def createStream(message: String): IO[Stream[IO, String]] =
    for {
      // StreamingResponseHandlerを実行する前にqueueを作成しておく
      queue <- Queue.unbounded[IO, Option[String]]

      // StreamingResponseHandlerの作成
      handler = new StreamingResponseHandler[AiMessage] {
        
        // onNextが呼ばれたらqueueに値を追加
        override def onNext(token: String): Unit = {
          queue.offer(Some(token)).unsafeRunSync()
        }

        // onCompleteが呼ばれたらqueueにNoneを追加
        override def onComplete(response: Response[AiMessage]): Unit = {
          queue.offer(None).unsafeRunSync()
        }

        // onErrorが呼ばれたらqueueにNoneを追加
        override def onError(error: Throwable): Unit = {
          queue.offer(None).unsafeRunSync()
        }
      }
      
      // queueを作成する際のIOと生成処理のIOを合成
      _ <- IO.delay {
        val model: OpenAiStreamingChatModel = OpenAiStreamingChatModel
          .builder()
          .apiKey(apiKey)
          .modelName(OpenAiChatModelName.GPT_4_O_MINI)
          .build()

        model.generate(message, handler)
      }

      // queueをNoneが来たら終了するStreamに変換。
    } yield Stream.fromQueueNoneTerminated(queue)
}

Discussion