💅

Replicate APIを利用して、サムネイルを自動生成

2024/04/24に公開

概要

前回の記事では、タイトルと文章を入力すると、AIがタイトルの内容に基づいて自動的に文章を補完してくれるブログ記事作成ツールを作成しました。今回は、そのツールにタイトルの内容に基づいたサムネイル画像を自動生成する機能を追加しました。
https://zenn.dev/atwillatwill/articles/6e852348010d59

全てのコードは以下のgithubにあります。
https://github.com/atwill1028/trycopilotkit

使い方

1.タイトルを入力します。

2.「サムネイル作成ボタン」を押します。

3.タイトルに基づいたサムネイルが自動生成されます。

使用するツール

Replicate

Replicateは、オープンソースのAIモデルにアクセスできるプラットフォームです。一定の利用量まで無料で使用可能です。今回は、Stable Diffusionモデルの一種である「stability-ai / sdxl」を利用します。
https://replicate.com/
https://replicate.com/stability-ai/sdxl

必要な準備

Replicate API TOKENを取得して設定

  • Replicate API TOKENを取得し、環境変数に設定します。
  • .env.localファイルを作成し、REPLICATE_API_TOKENを設定します。

パッケージをインストール

  • ターミナルで以下を実行します。
pnpm install replicate

next.config.mjsを修正

  • next.config.mjsファイルを以下のコードで置き換えます。

指定されたドメイン(replicate.comとreplicate.delivery)からの画像のみを許可するように設定します。

next.config.mjs
/** @type {import('next').NextConfig} */
const nextConfig = {
  reactStrictMode: true,
  images: {
    remotePatterns: [
      {
        protocol: "https",
        hostname: "replicate.com",
      },
      {
        protocol: "https",
        hostname: "replicate.delivery",
      },
    ],
  },
};

export default nextConfig;

実装

バックエンド

POST /api/predictions

  • フロントエンドから送信されたテキストプロンプトを受け取り、Replicate APIを使用してStable Diffusionモデルに画像生成リクエストを送信します。
  • リクエストが成功すると、予測IDを含むレスポンスを返します。

GET /api/predictions/:id

  • フロントエンドから送信された予測IDを受け取り、Replicate APIを使用して予測の状態を取得します。
  • 予測が完了していれば、生成された画像のURLを含むレスポンスを返します。
src/app/api/predictions/route.ts
import Replicate from "replicate";

const replicate = new Replicate({
  auth: process.env.REPLICATE_API_TOKEN,
});

export async function POST(req: Request) {
  const data = await req.formData();
  if (!process.env.REPLICATE_API_TOKEN) {
    throw new Error(
      "The REPLICATE_API_TOKEN environment variable is not set. See README.md for instructions on how to set it."
    );
  }

  const prediction = await replicate.predictions.create({
    // Pinned to a specific version of Stable Diffusion
    // See https://replicate.com/stability-ai/sdxl
    version: "8beff3369e81422112d93b89ca01426147de542cd4684c244b673b105188fe5f",

    // This is the text prompt that will be submitted by a form on the frontend
    input: { prompt: data.get("prompt") },
  });

  if (prediction?.error) {
    return new Response(JSON.stringify({ detail: prediction.error.detail }), {
      status: 500,
    });
  }

  return new Response(JSON.stringify(prediction), { status: 201 });
}
src/app/api/predictions/id/route.ts
import Replicate from "replicate";

const replicate = new Replicate({
  auth: process.env.REPLICATE_API_TOKEN,
});

export async function GET(
  request: Request,
  { params }: { params: { id: string } }
) {
  const prediction = await replicate.predictions.get(params.id);

  if (prediction?.error) {
    return new Response(JSON.stringify({ detail: prediction.error.detail }), {
      status: 500,
    });
  }

  return new Response(JSON.stringify(prediction), { status: 200 });
}

フロントエンド

  • "サムネイル作成"ボタンをクリックすると、handleClick関数が呼び出されます。
  • fetchを使用して、バックエンドの/api/predictionsエンドポイントにPOSTリクエストを送信し、タイトルをリクエストボディに含めます。
  • レスポンスが成功した場合、予測の状態をpredictionに設定し、画像生成の進行状況を監視します。
  • 生成された画像のURLを取得し、Imageコンポーネントを使用して表示します。
  • エラーが発生した場合は、エラーメッセージを表示します。
src/app/components/Article.tsx
"use client";
import { Input } from "@/components/ui/input";
import { CopilotTextarea } from "@copilotkit/react-textarea";
import { useState } from "react";
import Image from "next/image";
import { RotateCcw } from "lucide-react";
import { Prediction } from "replicate";
import { Button } from "@/components/ui/button";
const sleep = (ms: number) => new Promise((r) => setTimeout(r, ms));

export function Article() {
  const [title, setTitle] = useState("");
  const [text, setText] = useState("");
  const [prediction, setPrediction] = useState<Prediction | null>(null);
  const [error, setError] = useState(null);

  const handleClick = async () => {
    const response = await fetch("/api/predictions", {
      method: "POST",
      body: JSON.stringify(title),
    });

    let prediction = await response.json();
    if (response.status !== 201) {
      setError(prediction.detail);
      return;
    }
    setPrediction(prediction);

    while (
      prediction.status !== "succeeded" &&
      prediction.status !== "failed"
    ) {
      await sleep(1000);
      const response = await fetch("/api/predictions/" + prediction.id, {
        cache: "no-store",
      });
      prediction = await response.json();
      if (response.status !== 200) {
        setError(prediction.detail);
        return;
      }
      console.log({ prediction });
      setPrediction(prediction);
    }
  };

  return (
    <div className="px-8 py-8">
      <div className="mt-4 flex flex-col items-center justify-center w-full">
        {error && <div className="mt-4 text-red-500">{error}</div>}
        {prediction && prediction.output ? (
          <div className="flex flex-col items-center justify-center w-full">
            <Image
              src={prediction.output[prediction.output.length - 1]}
              alt="output"
              width={350}
              height={350}
              className="object-cover rounded-md border-gray-300"
            />
          </div>
        ) : (
          <div className="flex flex-col items-center justify-center w-full">
            <Image
              src="/images/sample.png"
              alt="output"
              width={350}
              height={350}
              className="object-cover rounded-md border-gray-300"
            />
          </div>
        )}
      </div>
      <div className="flex space-x-2">
        <Input
          className="mb-8"
          value={title}
          onChange={(e) => setTitle(e.target.value)}
          placeholder="タイトルを書こう"
        />
        {prediction ? (
          prediction.output ? (
            <Button onClick={handleClick} className="w-1/3">
              別のサムネイル作成
            </Button>
          ) : (
            <Button disabled className="w-1/3">
              <RotateCcw className="mr-2 h-4 w-4 animate-spin" />
              作成中
            </Button>
          )
        ) : (
          <Button onClick={handleClick} className="w-1/3">
            サムネイル作成
          </Button>
        )}
      </div>
      <CopilotTextarea
        className="px-4 py-4 text-lg border-2 h-48 border-gray-300 rounded-lg focus:outline-none"
        value={text}
        onValueChange={(value: string) => setText(value)}
        placeholder="本文を書こう"
        autosuggestionsConfig={{
          // タイトルの内容に基づいて文章が補完されるように設定
          textareaPurpose: `research a blog article topic on ${title}`,
          chatApiConfigs: {
            suggestionsApiConfig: {
              forwardedParams: {
                max_tokens: 20,
                stop: ["\n", ".", ",", "?", "!"],
              },
            },
          },
          debounceTime: 250,
        }}
      />
    </div>
  );
}

参考

https://replicate.com/docs/get-started/nextjs

Discussion