🔁

[Kotlin]Amazon BedrockのInvoke Model APIをConverse APIへ移行

2024/07/23に公開


スマートラウンドでエンジニアをやっている福本です!

先日、AWSが主催する「Amazon Bedrock Prototyping Camp」というAmazon Bedrockの1dayワークショップに参加してきました👇

https://note.com/smartround/n/nb0ad875e2b02

スマートラウンドでは、社内でLLMに関する取り組みをすでに行っており、Amazon Bedrockを用いてAnthropic(Claude)のAPIを叩くことでLLMにアクセスしています。

既存のコードではInvoke Model APIを利用していたのですが、先ほどのワークショップ内で、新しいAPIであるConverse APIを教えて頂きました。

ワークショップの場にAWSのSAの方がたくさん居らっしゃったので、「新しいAPIに移行しておこう!」と思い、ワークショップの当日にサッと移行しました。SAの方とも相談しながら進めたので、移行の内容についてメモを書いておきます✍️

Converse APIとは??

いきなり実装の話に入る前に、Converse APIについて理解を深めておきます。


「Amazon Bedrock Prototyping Camp」で投影されたスライド(公開OKなのは確認済)

Converse APIは2024年5月に公開されたAPIのようで、概要は上記の資料、およびAWSのドキュメントが全てです...と言いたいのですが、咀嚼すると👇

  • モデルの切り替えがラクになるので、検証がしやすくなる
    • モデルごとにリクエスト・レスポンスの構造が変わらなくなるため
  • 画像のBase64エンコードが不要になる
  • Invoke Modelで提供される機能の多くはそのまま利用可能(ストリーミングやTool useなど)
    • 逆に、画像生成やEmbeddingは2024/07/18時点では未対応(以下参照)

The Converse API doesn't support any embedding models (such as Titan Embeddings G1 - Text) or image generation models (such as Stability AI).

https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html#conversation-inference-supported-models-features

  • [余談]Invoke APIよりもToo many requestsエラー(後述)が起きにくくなる(らしい)
    • (Converse APIにかかわらず)新しいAPIやモデルの方が、多くリクエストをさばけるように設計されているみたいです

移行した内容

スマートラウンドでは、バックエンドにKotlinを用いてプロダクト開発をしております。

その都合上、今回の内容についてはKotlinでAWS SDK for Javaを用いたコードを示しておりますが、言語ごとに対応するAWS SDKを利用すれば問題なく扱えると思います。

利用するライブラリ

AWS SDK for Javaでは2.25.63でAmazon Bedrock RuntimeにConverse APIが追加されたようなので、それ以降のバージョンを使ってください。

スマートラウンドではパッケージ管理にGradleを使っていますが、以下のように指定しています。

implementation group: 'software.amazon.awssdk', name: 'bedrockruntime', version: '2.26.16'

コード

以下のようなコードを書いて、Converse APIを叩いています。
モデルID(MODEL_ID)やリージョンなどのパラメータについては、使いたい任意のものを使用してください。

※実際にはファイルを分けてる箇所がありますが、記事だとわかりづらくなるので一緒に書いています
※実際に動いているコードから改変している箇所があります

/*
* AWS SDKを利用してAmazon Bedrock経由でClaude3にリクエストを送信する自前クライアント
*/
class AmazonBedrockClient {
  private val client = BedrockRuntimeAsyncClient.builder()
    .credentialsProvider(DefaultCredentialsProvider.create())
    .region(Region.US_WEST_2)
    .build()

  companion object {
    private const val MODEL_ID = "anthropic.claude-3-opus-20240229-v1:0"
    private const val MAX_RETRIES = 10
    private const val INITIAL_RETRY_INTERVAL_MS = 1000L
    private const val MAX_TOKENS = 4096
    private const val TEMPERATURE: Float = 0.0f
  }

  fun sendRequest(userPrompt: String, systemPrompt: String): String {
    val messages = listOf(AmazonBedrockMessage().createMessage(ConversationRole.USER, userPrompt),)

    repeat(MAX_RETRIES) { attempt ->
      try {
        return executeRequest(systemPrompt, messages)
      } catch (e: ThrottlingException) {
        handleRetry(attempt, e)
      }
    }
    throw RuntimeException("Amazon Bedrock: APIリクエストに $MAX_RETRIES 回失敗しました")
  }

  /*
   Converse APIを用いてBedrock経由でClaudeのAPIリクエストを送る
  */
  private fun executeRequest(systemPrompt: String, messages: List<Message>): String {
    val future = CompletableFuture<String>()
    val request = converseRequest(messages, systemPrompt)

    request.whenComplete { response: ConverseResponse?, error: Throwable? ->
      if (error == null && response != null) {
        future.complete(parseResponse(response))
      } else {
        future.completeExceptionally(error)
      }
    }
    return future.getOrThrow()
  }

  /*
    Converse APIのリクエストを作って送信する
  */
  private fun converseRequest(messages: List<Message>, systemPrompt: String): CompletableFuture<ConverseResponse> {
    return client.converse(
      ConverseRequest.builder()
        .messages(messages)
        .modelId(MODEL_ID)
        .system(SystemContentBlock.fromText(systemPrompt))
        .inferenceConfig(
          InferenceConfiguration.builder()
            .maxTokens(MAX_TOKENS)
            .temperature(TEMPERATURE)
            .build()
        )
        .build()
    )
  }

  private fun CompletableFuture<String>.getOrThrow(): String {
    return try {
      this.get()
    } catch (e: Exception) {
      throw e
    }
  }

  private fun parseResponse(response: ConverseResponse): String {
    return response.output().message().content()[0].text()
  }

  private fun handleRetry(attempt: Int, e: ThrottlingException) {
    if (attempt >= MAX_RETRIES - 1) throw e
    val sleepTime = calculateExponentialBackoff(attempt)
    Thread.sleep(sleepTime)
  }

  // Too many requestsのエラーが起きるので、リトライ時に指数バックオフを使用して待つ
  private fun calculateExponentialBackoff(attempt: Int): Long {
    return INITIAL_RETRY_INTERVAL_MS * 2.0.pow(attempt).toLong()
  }
}

class AmazonBedrockMessage {
  fun createMessage(role: ConversationRole, promptContent: String): Message {
    return Message.builder()
      .content(ContentBlock.fromText(promptContent.trimEnd())) // プロンプトの末尾に空白や改行があるとエラーになるのでtrimする
      .role(role)
      .build()
  }
}

上記のクラスを利用してAPIリクエストを送信するときは、以下のようにします。

val userPrompt = "hogehoge"
val systemPrompt = "hogehoge"
val result = AmazonBedrockClient().sendRequest(userPrompt, systemPrompt)

以下で個別に解説していきます。

Converse APIのリクエストを送信する部分

今回メインとなるConverse APIのリクエストを送信するコードは、以下の部分です👇

AWS SDKの公式ドキュメントにサンプルコードがありますので、多くの部分でそちらを参考にしています。AWSのリソースに感謝!

具体的には、BedrockRuntimeAsyncClient.converseメソッドを使うことで、APIリクエストを送信できます。

/*
Converse APIのリクエストを作って送信する
*/
private fun converseRequest(messages: List<Message>, systemPrompt: String): CompletableFuture<ConverseResponse> {
  return client.converse(
    ConverseRequest.builder()
      .messages(messages)
      .modelId(MODEL_ID)
      .system(SystemContentBlock.fromText(systemPrompt))
      .inferenceConfig(
        InferenceConfiguration.builder()
          .maxTokens(MAX_TOKENS)
          .temperature(TEMPERATURE)
          .build()
        )
    .build()
  )
}

BedrockRuntimeAsyncClientのインスタンスは以下で定義しています👇
設定するリージョンは、使用したいモデルIDに対応しているものかを注意して指定して下さい。

private val client = BedrockRuntimeAsyncClient.builder()
    .credentialsProvider(DefaultCredentialsProvider.create())
    .region(Region.US_WEST_2)
    .build()

秘匿情報はDefaultCredentialsProvider.create()でいい感じに設定されます。スマートラウンドでは、このコードAWS Lambdaで実行されるのですが、LambdaにAmazon BedrockのRoleを割り当てることで実行できるようにしています。

リクエストを受け取る

上記のリクエストを送信したリクエストをあーだこーだする箇所は以下です👇
CompletableFutureを使って、APIの実行結果を待って受け取っています。

/*
Converse APIを用いてBedrock経由でClaudeのAPIリクエストを送る
*/
private fun executeRequest(systemPrompt: String, messages: List<Message>): String {
  val future = CompletableFuture<String>()
  val request = converseRequest(messages, systemPrompt)

  request.whenComplete { response: ConverseResponse?, error: Throwable? ->
    if (error == null && response != null) {
      future.complete(parseResponse(response))
    } else {
      future.completeExceptionally(error)
    }
  }
  return future.getOrThrow()
}

private fun CompletableFuture<String>.getOrThrow(): String {
  return try {
    this.get()
  } catch (e: Exception) {
    throw e
  }
}

private fun parseResponse(response: ConverseResponse): String {
  return response.output().message().content()[0].text()
}

(おまけ)リトライ

当初Invoke Model APIを利用していたときの実装ですが、自分なりに調査して対応したのでメモとして残しておきます。

Amazon Bedrockを実践投入し運用したところ、 Too many requestsのメッセージで ThrottlingExceptionがたまに出る現象に見舞われました👇

com.amazonaws.services.bedrockruntime.model.ThrottlingException: Too many requests, please wait before trying again. You have sent too many requests.  Wait before trying again. (Service: AmazonBedrockRuntime; Status Code: 429; Error Code: ThrottlingException; Request ID: xxxx; Proxy: null)

どうも調べていると、色々な方が同じ現象に遭っているみたいですね。

サッと調べてみたところ、Claude 2系では「トークン数(max_tokens_to_sample)が原因」という記載が多かったです。しかし、トークン数の上限は既に指定している & 上限に引っかかるようなトークン数でリクエスト↔レスポンスはしていないはずなので、こちらが原因ではなさそう...と考えました👇

https://dev.classmethod.jp/articles/tsnote-amazon-bedrock-claude-2-1-how-to-solve-throttling-exception-error/

https://repost.aws/ja/questions/QUC82MTlWlQNagsqEG2Hbxlw/aws-bedrock-throttlingexception-occurs-randomly-for-claude-2-1-runtime?sc_ichannel=ha&sc_ilang=en&sc_isite=repost&sc_iplace=hp&sc_icontent=QUC82MTlWlQNagsqEG2Hbxlw&sc_ipos=13

そんな中、以下の発表資料でClaude 3を使ったAPIリクエストにおいて、「リトライの実装をすることで Too many requests を解消した」という知見がありました(ありがとうございます 🙏 )

https://speakerdeck.com/sonoda_mj/bedrocknotoo-many-requestjie-jue-sitemita?slide=17

そもそもToo many requestsの話は抜きにしても、リトライ機構自体は当然あったほうが良いので、この機会に実装してしまおう!と考えました。

というわけで、以下の箇所でリトライを行っております。指数バックオフも入れてみています。

ユースケース的にAPIリクエストが完了するまでの速度がさほど重視されていない関係上、設定しているリトライ回数やインターバル時間はあまりチューニングしていません。Amazon SQSを用いてAWS Lambdaでコードを実行する場合は、SQS可視性タイムアウトLambda関数のタイムアウトの設定が、リトライの実行で想定される時間と矛盾しないか注意してください。

fun sendRequest(userPrompt: String): String {
  val systemPrompt = "hogehoge"
  val messages = listOf(AmazonBedrockMessage().createMessage(ConversationRole.USER, userPrompt),)

  repeat(MAX_RETRIES) { attempt ->
    try {
      return executeRequest(systemPrompt, messages)
    } catch (e: ThrottlingException) {
       handleRetry(attempt, e)
    }
  }
  throw RuntimeException("Amazon Bedrock: APIリクエストに $MAX_RETRIES 回失敗しました")
}

private fun handleRetry(attempt: Int, e: ThrottlingException) {
  if (attempt >= MAX_RETRIES - 1) throw e
  val sleepTime = calculateExponentialBackoff(attempt)
  Thread.sleep(sleepTime)
}

// Too many requestsのエラーが起きるので、リトライ時に指数バックオフを使用して待つ
private fun calculateExponentialBackoff(attempt: Int): Long {
  return INITIAL_RETRY_INTERVAL_MS * 2.0.pow(attempt).toLong()
}

このリトライの対応とConverse APIへの移行が終わってからは、今のところ Too many requestsに見舞われておりません。お疲れ様でした!

※実行している環境や送信するリクエストなどに依存すると思うので、Too many requestsが解消することを保証するものではありません

さいごに

以上、これでConverse APIでAmazon Bedrockを使えます🙆やったね!

Invoke APIのコードと比較すると、build()するクラスが増えて少しモッサリした感じを個人的には受けますが、これでモデルの切り替えがラクになる等のメリットを享受できるので、やっておいて損はないと思います。

世間的にもLLMを活用した実践的な取り組みやその知見が色々と世に出てきだしていますが、スマートラウンドとしては知見がまだ少ない領域で、開発を手探りで進めている部分もあります。

この辺りの取り組みをされている方、もしくは興味があるよという方は、ぜひ私とカジュアルトークしましょう💬 最後まで読んでいただきありがとうございました!

https://jobs.smartround.com/

スマートラウンド テックブログ

Discussion