⛓️

LangChainにコントリビュートした話

2024/09/01に公開

はじめに

ストリーツ株式会社の@hanamaです。
今回は、LangChainというOSSにコントリビュートした話を書きたいと思います。

LangChainとは

LangChainとは、LLM分野で広く利用されているライブラリです。LLMを絡めた複雑なワークフローの実装が可能で、ToolやRAGなどの追加機能も簡単に導入することができます。
執筆時点でGithubリポジトリのstar数は91.4k、毎週100~200コミットが行われている非常に活発なOSSです。

コントリビュートのきっかけ

弊社のプロダクトはメディア向けの生成AIサービスであり、ある程度まとまった長さの文章を生成させる必要がありました。多くのLLMでは、出力文章の長さが一定の値を超えると、そこで出力が強制的にストップしてしまいます。そのため、出力が停止した理由を確認し、長さが原因だった場合は続きを生成させる必要があります。このような仕組みを考える際に、merge_message_runs関数が複数メッセージの結合において改行を強制的に挟む仕様が問題となりました。

生成AIサービスの出力限界を超えて続きを生成させる処理については調べてもあまり先行事例が見つからなかったので、後ほど別記事にまとめたいと思います。

実際のコード
for msg in messages:
    curr = msg.copy(deep=True)
    last = merged.pop() if merged else None
    if not last:
        merged.append(curr)
    elif isinstance(curr, ToolMessage) or not isinstance(curr, last.__class__):
        merged.extend([last, curr])
    else:
        last_chunk = _msg_to_chunk(last)
        curr_chunk = _msg_to_chunk(curr)
        if (
            isinstance(last_chunk.content, str)
            and isinstance(curr_chunk.content, str)
            and last_chunk.content
            and curr_chunk.content
        ):
            last_chunk.content += "\n"  # ここで強制的に改行が挟まれる
        merged.append(_chunk_to_msg(last_chunk + curr_chunk))

解決方針

自分の環境では、改行が強制的に入ってしまうのが問題だったため、はじめは改行の挿入をオプトアウトする引数をmerge_message_runs関数に追加したPRを作成しました。

def merge_message_runs(
    messages: Union[Iterable[MessageLikeRepresentation], PromptValue],
    *,
    with_newline_separator: bool = True,  # 改行挿入をオプトアウトするための引数を追加
) -> List[BaseMessage]:

しかし、この変更に対して、「チャンクの区切り文字をユーザーが自由に指定できた方が良いのではないか?」というフィードバックをメンテナの@baskaryanさんから頂き、merge_message_runs関数にセパレーターを指定できる引数を追加することにしました。(comment)
自身に起こっている問題を解決するだけでなく、どんなユーザーにとっても使いやすい関数にするという観点でのフィードバックを頂き、大変勉強になりました。

実装の詳細

今回のコントリビュートでは、merge_message_runs関数がchunk_separator引数を受け取り、それを適切に処理する変更と、そのテストを追加しました。

merge_message_runs関数の変更

変更内容
def merge_message_runs(
    messages: Union[Iterable[MessageLikeRepresentation], PromptValue],
+    *,
+    chunk_separator: str = "\n",  # メッセージの区切り文字を自由に指定。デフォルトは "\n"
) -> List[BaseMessage]:

# 省略

for msg in messages:
    curr = msg.copy(deep=True)
    last = merged.pop() if merged else None
    if not last:
        merged.append(curr)
    elif isinstance(curr, ToolMessage) or not isinstance(curr, last.__class__):
        merged.extend([last, curr])
    else:
        last_chunk = _msg_to_chunk(last)
        curr_chunk = _msg_to_chunk(curr)
        if (
            isinstance(last_chunk.content, str)
            and isinstance(curr_chunk.content, str)
            and last_chunk.content
            and curr_chunk.content
        ):
-            last_chunk.content += "\n"
+            last_chunk.content += chunk_separator  # "\n" ではなく `chunk_separator`を挿入するように変更
        merged.append(_chunk_to_msg(last_chunk + curr_chunk))

docstringの変更も含めても、このファイルの変更は十数行で済みました。これだけ簡単な変更でも、ユーザー視点から見ればmerge_message_runs関数の自由度が上がり、手前味噌ながら便利になったなと感じます。
1点だけポイントを挙げると、chunk_separator引数の前に*を付けている部分です。これにより、chunk_separator引数はキーワード引数として明示的に指定しないと関数に渡せないようになっています。普通にpythonを書いている時にわざわざこの表記を用いることは少ないですが、OSSで提供する関数の場合は、オプショナルな設定項目はこのような形で提供することが多いと思っています。実際、今回私がコントリビュートしたlangchain_coreの関数の多くも、このような形でオプショナルな引数を提供していたため、それに合わせる形で実装しました。

テストの追加

今回の変更では、merge_message_runs関数に引数を追加したので、その引数が正しく機能するかを確認するテストを追加しました。

追加したテストケースは以下の通りです。

  • chunk_separator引数に"<sep>"を指定した場合、メッセージの区切り文字が"<sep>"になることを確認するテスト
  • chunk_separator引数に""を指定した場合、メッセージの区切り文字が空文字になることを確認するテスト
テストの追加
+@pytest.mark.parametrize("msg_cls", [HumanMessage, AIMessage, SystemMessage])
+def test_merge_message_runs_str_with_specified_separator(
+    msg_cls: Type[BaseMessage],
+) -> None:
+    messages = [msg_cls("foo"), msg_cls("bar"), msg_cls("baz")]
+    messages_copy = [m.copy(deep=True) for m in messages]
+    expected = [msg_cls("foo<sep>bar<sep>baz")]
+    actual = merge_message_runs(messages, chunk_separator="<sep>")
+    assert actual == expected
+    assert messages == messages_copy
+
+
+@pytest.mark.parametrize("msg_cls", [HumanMessage, AIMessage, SystemMessage])
+def test_merge_message_runs_str_without_separator(
+    msg_cls: Type[BaseMessage],
+) -> None:
+    messages = [msg_cls("foo"), msg_cls("bar"), msg_cls("baz")]
+    messages_copy = [m.copy(deep=True) for m in messages]
+    expected = [msg_cls("foobarbaz")]
+    actual = merge_message_runs(messages, chunk_separator="")
+    assert actual == expected
+    assert messages == messages_copy

LangChainのテストでは、pytest.mark.parametrizeが活用されていました。BaseMessageクラスを継承している複数のクラスに対して同じテストケースを適用でき、非常に便利だったので、弊社のプロダクトでも活用していきたいと思いました。

感想

私の出したPRは無事マージされ、日本時間の8/26にリリースされたlangchain-core v0.2.35からmerge_message_runs関数でchunk_separator引数を利用できるようになりました。個人としてOSSにコントリビュートするのは初めてではなかったものの、結構久しぶりの経験だったので改めてよかったことと困ったことを書いておこうと思います。

よかったこと

OSS運営にも生成AIが取り込まれている

LangChainのリポジトリには、DosuBotが入っていました。多くのユーザーが日々作成するたくさんのissueやPull Requestに対してこのBotがトリアージを行ったり適切なラベルを付けてくれるので、メンテナの負荷軽減に寄与していると思いました。また、このリポジトリでは、ディスカッションも活発に動いているのですが、Q&Aのトピックについては、一次解答をDosuBotが行っているようでした。

Github Actionsを活用したCI/CD

LangChainのリポジトリのCI/CDにはGithub Actionsが活用されていました。弊社のリポジトリでもGithub Actionsを使っているのですが、matrix機能は使ったことがありませんでした。様々なバージョン、プラットフォームにおいて正常に動作することを確認する際にすごく便利だと思い、早速弊社のリポジトリにも導入しました。

困ったこと

今回のコントリビュートで少し困ったことが一つだけありました。それはレビュワーとの連絡が一時的に取れなくなったことです。
これだけ巨大なOSSのメンテナなので、日々大量のメンションが飛んでおり、それぞれに優先度をつけて対応されていると思います。これに関しては私の力でどうにかなる問題ではないのですが、数日おきにリマインドを送るなどして出来るだけ気づいてもらえるように工夫しました。結果的にはPRを作成して約3週間後にマージされました。

まとめ

今回はLangChainにコントリビュートした話ということで、コントリビュートのきっかけと変更内容、感想をまとめてみました。
LangChainは非常に活発なOSSであり、日々新しいissueが立てられ、コミュニティでも活発な議論が行われているので、また時間があればコントリビュートしたいと思っています。

ストリーツ株式会社

Discussion