♻️

【例外処理】画像生成エラー時に代替モデルへフォールバック【429エラー対策】

2024/05/19に公開

はじめに

こんにちわ、こんばんわ。
平(@tairanobuhiko)と申します。
今回は個人開発アプリ【しりとり画像ジェネレーター】の工夫したポイントでもある429エラーのハンドリング(例外処理)についてとりあげたいと思います。
https://word-chain-image-generator.onrender.com/

開発環境

言語フレームワーク バージョン
Ruby on Rails 7.0.8.1

背景(課題)

無料の画像生成API(Stable Dffusion)を使用しているため、リクエスト制限エラー(429)が多発し、ペルソナユーザー(小学校低学年層)の興味を惹く機能である画像生成の精度が不安定になってしまうという課題があり、安定的なサービスの提供とUX向上を目的として対策を講じる必要がありました。

講じた対策

無料のAPIモデルは活かしつつ、リクエスト制限に達した場合のみ別の代替モデルへフォールバックを行い画像生成機能の精度を安定させる方針としました。
代替モデルはOpenAIが提供する有料モデルの DALL-E3 を採用しました。

記事の結論

行なっていることは至極単純で、APIロジックの中でリクエスト制限に達していた場合はレスポンスコード(HTTPステータスコード:429)をコントローラーへ返却しフォールバック用のif文が発動するという具合です。
数ある中の解の一つとして参考になると幸いです。

内容

コントローラー

まずはAPIを呼び込むコントローラー側の記述です。
処理としては、

  1. Stable Diffusionのモデルへ画像生成リクエスト
  2. リクエスト制限(429)に達していた場合はDALL-Eモデルへフォールバック

といったコードを書いています。

controller.rb
def generate_image(translated_words, generate_model = "Stable Diffusion")
  # 不適切な画像を生成しないようネガティブプロンプトを仕込んでいます
  negative_prompt = "EasyNegative, (worst quality, low quality:1.4), lowres, ugly, bad anatomy, nsfw((Not safe for work)), low quality, negative hand-neg, bad anatomy ,extra fingers, fewer fingers, missing fingers ,extra arms, fewer arms, missing arms, extra legs, fewer legs, extra legs ,text ,logo, watermark, text, word, monochrome, rainbow, wood"
  image_bytes, http_status = StableDiffusionService.query(translated_words, negative_prompt)

  if http_status == 429 # Stable Diffusionがレートエラー時はDALL-E3にフォールバック
    generate_model = "dall-e-3"
    image_bytes, http_status = DalleService.query(translated_words, generate_model)
  end
  if http_status == 429 # DALL-E3がレートエラー時はDALL-E2にフォールバック
    generate_model = "dall-e-2"
    image_bytes, http_status = DalleService.query(translated_words, generate_model)
  end
  return image_bytes, http_status, generate_model
end

API

Stable Diffusion

処理の内容としては、リクエスト制限以外のエラー(500や503)が生じた場合は指数バックオフ方式(2秒〜32秒、最大5回)でリトライを行うロジックにしています。
リクエスト制限エラー(429)が生じた場合は即時コントローラー側へレスポンスコード(429)を返却し、代替先モデルへフォールバックする仕様です。

stable_diffusion_service.rb
require 'httparty'
require 'base64'

class StableDiffusionService
  # カスタム例外クラスの定義
  class RetryableError < StandardError; end

  class ServiceError < StandardError
    attr_reader :status_code

    def initialize(message, status_code = nil)
      super(message)
      @status_code = status_code
    end
  end

  API_URL = "https://api-inference.huggingface.co/models/stablediffusionapi/breakdomainxl-v6"
  HEADERS = {
    "Authorization" => ENV['STABLE_DIFFUSION_API_KEY']
  }
  TIMEOUT_SECONDS = 60 # タイムアウト時間の設定
  MAX_RETRY_ATTEMPTS = 4 # 最大リトライ回数

  def self.query(prompt, negative_prompt = nil, start_time = Time.now, retry_count = 0)
    payload = if negative_prompt
                { inputs: prompt, parameters: { negative_prompt: negative_prompt } }
              else
                { inputs: prompt }
              end

    begin
      response = HTTParty.post(API_URL, body: payload.to_json, headers: HEADERS, timeout: TIMEOUT_SECONDS)
      if response.code == 200
        return response.body, response.code
      elsif response.code == 429
        puts "Stable Diffusionのレート制限に達しました。DALL-Eモデルへフォールバックします。"
        return "", response.code
      else
        raise RetryableError, "API response code: #{response.code}"
      end
    rescue Net::ReadTimeout, HTTParty::Error, RetryableError => e
      if retry_count <= MAX_RETRY_ATTEMPTS
        retry_count += 1
        sleep_time = 2 ** retry_count # 指数バックオフでの待機時間を設定
        puts "リトライします。エラー: #{e.message}、リトライ回数: #{retry_count}、次のリトライまでの待機時間: #{sleep_time}秒"
        sleep sleep_time
        query(prompt, negative_prompt, start_time, retry_count)
      else
        return "", response&.code || 500
      end
    rescue StandardError => e
      puts "予期しないエラーが発生しました: #{e.message}"
      return "", response&.code || 500
    end
  end
end

DALL-E3(DALL-E2)

公式ドキュメント
https://platform.openai.com/docs/guides/images

Stable Diffusionと違い、DALL-Eモデルはサーバーエラーが生じにくいので指数バックオフ方式でのリトライは講じていません。
また、モデルの選択をAPIメソッドの呼び出し時に変数で値渡しすることで(DALL-E3でもDALL-E2でも)ロジックの再利用ができるようにしています。

dalle_service.rb
require 'httparty'
require 'base64'

class Dalle3Service
  API_URL = "https://api.openai.com/v1/images/generations"
  HEADERS = {
    "Authorization" => "Bearer #{ENV['OPENAI_API_KEY']}",
    "Content-Type" => "application/json"
  }
  TIMEOUT_SECONDS = 60

  def self.query(prompt, model)
    payload = {
      model: model,
      prompt: prompt,
      n: 1,
      size: "1024x1024",
      quality: "standard"
    }

    begin
      response = HTTParty.post(API_URL, body: payload.to_json, headers: HEADERS, timeout: TIMEOUT_SECONDS)
      if response.code == 200
        image_url = response.parsed_response["data"][0]["url"]
        image_response = HTTParty.get(image_url)
        return image_response.body, response.code
      elsif response.code == 429
        puts "DALL-E3のレート制限に達しました。DALL-E2モデルへフォールバックします。"
        ['x-ratelimit-limit-requests', 'x-ratelimit-remaining-requests', 'x-ratelimit-reset-requests', 'x-ratelimit-limit-tokens', 'x-ratelimit-remaining-tokens', 'x-ratelimit-reset-tokens'].each do |header|
          puts "#{header}: #{response.headers[header]}" if response.headers[header]
        end
        return "", response.code
      else
        raise "API response code: #{response.code}"
      end
    rescue Net::ReadTimeout, HTTParty::Error => e
      puts "APIリクエスト中にエラーが発生しました: #{e.message}"
      return "", response&.code || 500
    rescue StandardError => e
      puts "予期しないエラーが発生しました: #{e.message}"
      return "", response&.code || 500
    end
  end
end

さいごに

対策を二重に重ねて生成モデルを3つ用意しましたが、今日現在でいまのところ3つ目のDALL-E2までフォールバックした事例はありませんでした(笑)
ただ、エラーハンドリングを行うことでユーザーの興味が削がれる事象を排除することができるのでユーザースケールしていくにあたってはこだわって良かったポイントだとは思います。

少しでも誰かしらの参考になると嬉しいです。

X(旧Twitter)もやっているのでよかったらフォローお願いします🔥
https://x.com/tairanobuhiko

Discussion