📚

Stable Diffusion を AWS でサクッと動かすために Sagemaker JumpStart を使ってみる

2023/03/04に公開

はじめに

今回は、今話題の Stable Diffusion をサクッと使いたい方向けに、SageMaker JumpStart を使った環境構築の方法を試してみます。お手元の環境に GPU が無くても、AWS のリソースを使うことで簡単に検証環境を作ることができます!

SageMaker JumpStart

  • SageMaker Stduio の画面を開く
  • サイドバーのアイコンを押し、[SageMaker JumpStart] 画面を表示する
  • 右上の検索欄から [Naclbit Trinart Stable Diffusion V2]を検索し、選択する

    ※ Trinart モデルは 2 次元のイラストなどに強いモデルです

Notebook を開く

  • [Open Notebook] を押す

Notebook を実行する

Notebook が開くので、1. Set Up から順番に実行していきます。
特にコードを買い換える必要はなく、ただ実行するだけです!

3. Retrieve Artifacts & Deploy an Endpoint

[3. Retrieve Artifacts & Deploy an Endpoint] について少し解説します。
この部分のコードによって、事前学習済みのモデルを使って SageMaker 上に推論用のエンドポイントを起動しています。

inference_instance_type で、推論用のインスタンスタイプを定義しており、デフォルトではml.p3.2xlargeが使用されています。

from sagemaker import image_uris, model_uris, script_uris, hyperparameters
from sagemaker.model import Model
from sagemaker.predictor import Predictor
from sagemaker.utils import name_from_base

endpoint_name = name_from_base(f"jumpstart-example-{model_id}")

inference_instance_type = "ml.p3.2xlarge"

# Retrieve the inference docker container uri. This is the base HuggingFace container image for the default model above.
deploy_image_uri = image_uris.retrieve(
    region=None,
    framework=None,  # automatically inferred from model_id
    image_scope="inference",
    model_id=model_id,
    model_version=model_version,
    instance_type=inference_instance_type,
)

# Retrieve the inference script uri. This includes all dependencies and scripts for model loading, inference handling etc.
deploy_source_uri = script_uris.retrieve(
    model_id=model_id, model_version=model_version, script_scope="inference"
)


# Retrieve the model uri. This includes the pre-trained nvidia-ssd model and parameters.
model_uri = model_uris.retrieve(
    model_id=model_id, model_version=model_version, model_scope="inference"
)


# Create the SageMaker model instance
model = Model(
    image_uri=deploy_image_uri,
    source_dir=deploy_source_uri,
    model_data=model_uri,
    entry_point="inference.py",  # entry point file in source_dir and present in deploy_source_uri
    role=aws_role,
    predictor_cls=Predictor,
    name=endpoint_name,
)

# deploy the Model. Note that we need to pass Predictor class when we deploy model through Model class,
# for being able to run inference through the sagemaker API.
model_predictor = model.deploy(
    initial_instance_count=1,
    instance_type=inference_instance_type,
    predictor_cls=Predictor,
    endpoint_name=endpoint_name,
)

画像を生成してみる

これで準備は完了です。
以下のコードの text 部分の文字列が画像生成用のプロンプトを指定する部分です。
実行してみると、指定したテキストから画像が生成されました!

text = "cottage in impressionist style"
query_response = query(model_predictor, text)
img, prmpt = parse_response(query_response)
display_img_and_prompt(img, prmpt)

詳細なパラメータを設定する

また、詳細なパラメータを指定して、生成される画像をチューニングすることも可能です。

import json

payload = {
    "prompt": "astronaut on a horse",
    "width": 400,
    "height": 400,
    "num_images_per_prompt": 2,
    "num_inference_steps": 50,
    "guidance_scale": 7.5,
}


def query_endpoint_with_json_payload(model_predictor, payload):
    """Query the model predictor with json payload."""

    encoded_payload = json.dumps(payload).encode("utf-8")

    query_response = model_predictor.predict(
        encoded_payload,
        {
            "ContentType": "application/json",
            "Accept": "application/json",
        },
    )
    return query_response


def parse_response_multiple_images(query_response):
    """Parse response and return generated image and the prompt"""

    response_dict = json.loads(query_response)
    return response_dict["generated_images"], response_dict["prompt"]


query_response = query_endpoint_with_json_payload(model_predictor, payload)
generated_images, prompt = parse_response_multiple_images(query_response)

for img in generated_images:
    display_img_and_prompt(img, prompt)

おわりに

SageMaker Jumpstart を使うことで、自身でコードを書く事なく Stable Diffusion を動かすことができました!
今回は ml.p3.2xlarge インスタンスを使っているため、推論エンドポイント起動中は 1 時間あたり 5.242 USD の課金が発生していますので、検証が終わったら 6. Clean up the endpoint のコードを実行して、エンドポイントを削除してください。

Discussion