🙆‍♀️

画像生成AIの生成途中経過を可視化したい!【Diffusers】

2024/08/06に公開

はじめに

みなさん画像生成AIのStable Diffusionはご存知でしょうか。
私の過去の記事でも画像生成AIをGoogle Colabで簡単に使えるように解説しています。
https://zenn.dev/asap/articles/aa856819e2a722

Stable Diffusionといえば、WebUIが有名ですが、Diffusersというライブラリを利用することで、Pythonを利用して、AIをサービスに組み込んだりすることが可能です。
今回は、画像生成AIが絵を生成するまでの途中経過を可視化してみたいと思います。
(Web UIでは生成途中の画像を可視化する機能はありますが、diffusersではあまり解説されていないように思います)

生成途中の画像を可視化する方法は、複数パターンがあります、
今回は、下記の4パターンの方法で、生成途中を可視化します。
(一番下のものが、皆さんが馴染み深い画像かもしれません)

生成途中の潜在表現からVAEで再構成した画像の表示

生成途中の潜在表現を直接線形近似した画像を表示
(潜在表現の縦横サイズは、元の画像の1/8なので小さい画像です)

潜在表現に対する各Stepの更新幅を変更して、VAEで再構成した画像の表示
(これだけ事情があってSamplerが違うため、生成される画像が異なります)

Samplerが生成途中に推定したクリーン画像の表示

成果物

下記のリポジトリをご覧ください。
https://github.com/personabb/colab_AI_sample/tree/main/colab_SDXLControlNet_sample_forkDiffuser

今回の実験

下記に実施した実験の内容を記載します。実験結果については最後にご紹介しています。

  • 実験1
    • 生成途中の潜在表現からVAEで再構成した画像の表示
  • 実験2
    • 生成途中の潜在表現を直接線形近似した画像を表示
  • 実験3
    • 潜在表現に対する各Stepの更新幅を変更して、VAEで再構成した画像の表示
  • 実験4
    • Samplerが生成途中に推定したクリーン画像の表示

事前準備

利用するLoRAモデルを保存する

Part10の記事をご覧ください。

参照画像をダウンロードする

コントロールネットの入力に利用する画像を取得し、後述するフォルダの「inputs」フォルダに格納してください。


(今回の実験で利用している画像に関しては、AI生成画像ですが、版権画像なのでここで提示するのは、変換後の深度画像のみとします。)
元の画像はANIMAGINE XL 3.1の公式が提供しているチュートリアルで使われているプロンプトで作成した画像の一つです。(おそらくseed42-47のあたり)

解説

下記の通り、解説を行います。
まずは上記のリポジトリをcloneしてください。

./
git clone https://github.com/personabb/colab_AI_sample.git

その後、cloneしたフォルダ「colab_AI_sample」をマイドライブの適当な場所においてください。

ディレクトリ構造

Google Driveのディレクトリ構造は下記を想定します。

MyDrive/
    └ colab_AI_sample/
          └ colab_SDXLControlNet_sample_forkDiffuser/
                  ├ configs/
                  |    └ config.ini
                  ├ inputs/
                  |    | refer.webp
                  |    └ DreamyvibesartstyleSDXL.safetensors
                  ├ outputs/
                  ├ module/
                  |    └ module_sd3c.py
                  └ SDXLControlNet_sample.ipynb

  • colab_AI_sampleフォルダは適当です。なんでも良いです。1階層である必要はなく下記のように複数階層になっていても良いです。
    • MyDrive/hogehoge/spamspam/hogespam/colab_AI_sample
  • outputsフォルダには、生成後の画像が格納されます。最初は空です。
    • 連続して生成を行う場合、過去の生成内容を上書きするため、ダウンロードするか、名前を変えておくことをオススメします。
  • inputsフォルダには、ControlNetで利用する参照画像を格納しています。詳細は後述します。
    • 加えて、先ほどダウンロードしたLoRAモデルも格納します
      • 名前に空白が入っているのが気持ち悪かったのでリネームしています。

使い方解説

SDXLControlNet_sample.ipynbをGoogle Colabratoryアプリで開いてください。
ファイルを右クリックすると「アプリで開く」という項目が表示されるため、そこからGoogle Colabratoryアプリを選択してください。

もし、ない場合は、「アプリを追加」からアプリストアに行き、「Google Colabratory」で検索してインストールをしてください。

Google Colabratoryアプリで開いたら、SDXLControlNet_sample.ipynbのメモを参考にして、一番上のセルから順番に実行していけば、問題なく最後まで動作して、画像生成をすることができると思います。

また、最後まで実行後、パラメータを変更して再度実行する場合は、「ランタイム」→「セッションを再起動して全て実行する」をクリックしてください。

コードの解説を後回しに、とにかく実験をしたい方は、実験の章まで飛ばしてください

コード解説

主に、重要なSDXLControlNet_sample.ipynbmodule/module_sdc.pyについて解説します。

SDXLControlNet_sample.ipynb

該当のコードは下記になります。
https://github.com/personabb/colab_AI_sample/blob/main/colab_SDXLControlNet_sample_forkDiffuser/SDXLControlNet_sample.ipynb

基本的にはPart10と同じですが、パッケージインストール部分が異なります。
具体的には下記のようになっています。

1セル目

%rm -r /content/diffusers-preview_latents
%cd /content/
!git clone https://github.com/personabb/diffusers-preview_latents.git
import sys
sys.path
sys.path.append('/content/diffusers-preview_latents/src')

ここでは、「diffusers-preview_latents」というリポジトリをクローンしてきています。
本リポジトリは、通常のDiffusersライブラリをForkして作成しており、通常では取得できない値を取得できるように改造しています。
大元のDiffusersライブラリのver0.29.0をForkしています。
詳細は後述します。
その上で、クローンしたリポジトリを環境変数に追加することで、別のスクリプトからDiffusersモジュールとして読み込めるようにしています

また、2セル目のパッケージインストールに関しては、これまではDiffusersが入っていたと思いますが、今回はDiffusersを利用しないので、取り除いています。

加えて、5セル目の設定ファイルの部分も多少変わっています。変わっているのは下記の部分です。

use_dpm_solver = False
save_latent_simple = False
save_latent_overstep = False
save_latent_approximation = False

save_predict_skip_x0 = False

一つ目において、これまではSamplerに「DPMSolverMultistepScheduler」を利用していたのですが、こちらでは後述する実験において、不都合があったため、use_dpm_solver = Falseとして、「EulerAncestralDiscreteScheduler」が利用されるように変更しています。

2つ目においては、どの部分をTrueに変更するかで、保存される途中経過が変わります。実験の際に紹介します。

その他のセルは、Part10のものと同様です。

module/module_sdc.py

続いて、SDXLControlNet_sample.ipynbから読み込まれるモジュールの中身を説明します。

下記にコード全文を示します。

コード全文
./colab_AI_sample/colab_SDXLControlNet_sample_forkDiffuser/module/module_sdc.py

from diffusers import DiffusionPipeline, AutoencoderKL, StableDiffusionXLControlNetPipeline, ControlNetModel
from diffusers.utils import load_image
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
import torch
from diffusers.schedulers import DPMSolverMultistepScheduler, EulerAncestralDiscreteScheduler
from controlnet_aux.processor import Processor

import os
import configparser
# ファイルの存在チェック用モジュール
import errno
import cv2
from PIL import Image
import time
import numpy as np

class SDXLCconfig:
    def __init__(self, config_ini_path = './configs/config.ini'):
        # iniファイルの読み込み
        self.config_ini = configparser.ConfigParser()

        # 指定したiniファイルが存在しない場合、エラー発生
        if not os.path.exists(config_ini_path):
            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), config_ini_path)

        self.config_ini.read(config_ini_path, encoding='utf-8')
        SDXLC_items = self.config_ini.items('SDXLC')
        self.SDXLC_config_dict = dict(SDXLC_items)

class SDXLC:
    def __init__(self,device = None, config_ini_path = './configs/config.ini'):

        SDXLC_config = SDXLCconfig(config_ini_path = config_ini_path)
        config_dict = SDXLC_config.SDXLC_config_dict


        if device is not None:
            self.device = device
        else:
            device = config_dict["device"]

            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            if device != "auto":
                self.device = device

        self.last_latents = None
        self.last_step = -1
        self.last_timestep = 1000

        self.n_steps = int(config_dict["n_steps"])
        if not config_dict["high_noise_frac"] == "None":
          self.high_noise_frac = float(config_dict["high_noise_frac"])
        else:
          self.high_noise_frac = None
        self.seed = int(config_dict["seed"])
        self.generator = torch.Generator(device=self.device).manual_seed(self.seed)

        self.controlnet_path = config_dict["controlnet_path"]

        self.control_mode = config_dict["control_mode"]
        if self.control_mode == "None":
            self.control_mode = None

        self.vae_model_path = config_dict["vae_model_path"]
        self.VAE_FLAG = True
        if self.vae_model_path == "None":
            self.vae_model_path = None
            self.VAE_FLAG = False

        self.base_model_path = config_dict["base_model_path"]

        self.REFINER_FLAG = True
        self.refiner_model_path = config_dict["refiner_model_path"]
        if self.refiner_model_path == "None":
            self.refiner_model_path = None
            self.REFINER_FLAG = False


        self.LORA_FLAG = True
        self.lora_weight_path = config_dict["lora_weight_path"]
        if self.lora_weight_path == "None":
          self.lora_weight_path = None
          self.LORA_FLAG = False
        self.lora_scale = float(config_dict["lora_scale"])

        self.use_dpm_solver = config_dict["use_dpm_solver"]
        if self.use_dpm_solver == "True":
            self.use_dpm_solver = True
        else:
            self.use_dpm_solver = False

        self.use_karras_sigmas = config_dict["use_karras_sigmas"]
        if self.use_karras_sigmas == "True":
            self.use_karras_sigmas = True
        else:
            self.use_karras_sigmas = False
        self.scheduler_algorithm_type = config_dict["scheduler_algorithm_type"]
        if config_dict["solver_order"] != "None":
            self.solver_order = int(config_dict["solver_order"])
        else:
            self.solver_order = None

        self.cfg_scale = float(config_dict["cfg_scale"])
        self.width = int(config_dict["width"])
        self.height = int(config_dict["height"])
        self.output_type = config_dict["output_type"]
        self.aesthetic_score = float(config_dict["aesthetic_score"])
        self.negative_aesthetic_score = float(config_dict["negative_aesthetic_score"])

        self.save_latent_simple = config_dict["save_latent_simple"]
        if self.save_latent_simple == "True":
            self.save_latent_simple = True
            print("use callback save_latent_simple")
        else:
            self.save_latent_simple = False

        self.save_latent_overstep = config_dict["save_latent_overstep"]
        if self.save_latent_overstep == "True":
            self.save_latent_overstep = True
            print("use callback save_latent_overstep")
        else:
            self.save_latent_overstep = False

        self.save_latent_approximation = config_dict["save_latent_approximation"]
        if self.save_latent_approximation == "True":
            self.save_latent_approximation = True
            print("use callback save_latent_approximation")
        else:
            self.save_latent_approximation = False

        self.save_predict_skip_x0 = config_dict["save_predict_skip_x0"]
        if self.save_predict_skip_x0 == "True":
            self.save_predict_skip_x0 = True
            print("use callback save_predict_skip_x0")
        else:
            self.save_predict_skip_x0 = False

        self.use_callback = False
        if self.save_latent_simple or self.save_latent_overstep or self.save_latent_approximation or self.save_predict_skip_x0:
            self.use_callback = True

        if self.save_predict_skip_x0:
            if self.save_latent_simple or self.save_latent_overstep:
                raise ValueError("save_predict_skip_x0 and (save_latent_simple or save_latent_overstep) cannot be set at the same time")
            if self.use_dpm_solver:
                raise ValueError("save_predict_skip_x0 and use_dpm_solver cannot be set at the same time")
        else:
            if self.save_latent_simple and self.save_latent_overstep:
                raise ValueError("save_latent_simple and save_latent_overstep cannot be set at the same time")

        self.base , self.refiner = self.preprepare_model()


    def preprepare_model(self):
        controlnet = ControlNetModel.from_pretrained(
                self.controlnet_path,
                use_safetensors=True,
                torch_dtype=torch.float16)

        if self.VAE_FLAG:
            vae = AutoencoderKL.from_pretrained(
                self.vae_model_path,
                torch_dtype=torch.float16)

            base = StableDiffusionXLControlNetPipeline.from_pretrained(
                self.base_model_path,
                controlnet=controlnet,
                vae=vae,
                torch_dtype=torch.float16,
                variant="fp16",
                use_safetensors=True
            )
            base.to(self.device)

            if self.REFINER_FLAG:
                refiner = DiffusionPipeline.from_pretrained(
                    self.refiner_model_path,
                    text_encoder_2=base.text_encoder_2,
                    vae=vae,
                    requires_aesthetics_score=True,
                    torch_dtype=torch.float16,
                    variant="fp16",
                    use_safetensors=True
                )

                refiner.enable_model_cpu_offload()
            else:
                refiner = None

        else:
            base = StableDiffusionXLControlNetPipeline.from_pretrained(
                self.base_model_path,
                controlnet=controlnet,
                torch_dtype=torch.float16,
                variant="fp16",
                use_safetensors=True
            )
            base.to(self.device, torch.float16)

            if self.REFINER_FLAG:
                refiner = DiffusionPipeline.from_pretrained(
                    self.refiner_model_path,
                    text_encoder_2=base.text_encoder_2,
                    requires_aesthetics_score=True,
                    torch_dtype=torch.float16,
                    variant="fp16",
                    use_safetensors=True
                )

                refiner.enable_model_cpu_offload()
            else:
                refiner = None

        if self.LORA_FLAG:
            base.load_lora_weights(self.lora_weight_path)



        if self.use_dpm_solver:
            if self.solver_order is not None:
                base.scheduler = DPMSolverMultistepScheduler.from_config(
                        base.scheduler.config,
                        use_karras_sigmas=self.use_karras_sigmas,
                        Algorithm_type =self.scheduler_algorithm_type,
                        solver_order=self.solver_order,
                        )

            else:
                base.scheduler = DPMSolverMultistepScheduler.from_config(
                        base.scheduler.config,
                        use_karras_sigmas=self.use_karras_sigmas,
                        Algorithm_type =self.scheduler_algorithm_type,
                        )
        else:
            base.scheduler = EulerAncestralDiscreteScheduler.from_config(base.scheduler.config)

        return base, refiner

    def prepare_referimage(self,input_refer_image_path,output_refer_image_path, low_threshold = 100, high_threshold = 200):

        mode = None
        if self.control_mode is not None:
            mode = self.control_mode
        else:
            raise ValueError("control_mode is not set")

        def prepare_openpose(input_refer_image_path,output_refer_image_path, mode):

            # 初期画像の準備
            init_image = load_image(input_refer_image_path)
            init_image = init_image.resize((self.width, self.height))

            processor = Processor(mode)
            processed_image = processor(init_image, to_pil=True)

            processed_image.save(output_refer_image_path)




        def prepare_canny(input_refer_image_path,output_refer_image_path, low_threshold = 100, high_threshold = 200):
            init_image = load_image(input_refer_image_path)
            init_image = init_image.resize((self.width, self.height))

            # コントロールイメージを作成するメソッド
            def make_canny_condition(image, low_threshold = 100, high_threshold = 200):
                image = np.array(image)
                image = cv2.Canny(image, low_threshold, high_threshold)
                image = image[:, :, None]
                image = np.concatenate([image, image, image], axis=2)
                return Image.fromarray(image)

            control_image = make_canny_condition(init_image, low_threshold, high_threshold)
            control_image.save(output_refer_image_path)

        def prepare_depthmap(input_refer_image_path,output_refer_image_path):

            # 初期画像の準備
            init_image = load_image(input_refer_image_path)
            init_image = init_image.resize((self.width, self.height))
            processor = Processor("depth_midas")
            depth_image = processor(init_image, to_pil=True)
            depth_image.save(output_refer_image_path)

        def prepare_zoe_depthmap(input_refer_image_path,output_refer_image_path):

            torch.hub.help(
                "intel-isl/MiDaS",
                "DPT_BEiT_L_384",
                force_reload=True
                )
            model_zoe_n = torch.hub.load(
                "isl-org/ZoeDepth",
                "ZoeD_NK",
                pretrained=True
                ).to("cuda")

            init_image = load_image(input_refer_image_path)
            init_image = init_image.resize((self.width, self.height))

            depth_numpy = model_zoe_n.infer_pil(init_image)  # return: numpy.ndarray

            from zoedepth.utils.misc import colorize
            colored = colorize(depth_numpy) # numpy.ndarray => numpy.ndarray

            # gamma correction
            img = colored / 255
            img = np.power(img, 2.2)
            img = (img * 255).astype(np.uint8)

            Image.fromarray(img).save(output_refer_image_path)


        if "openpose" in mode:
            prepare_openpose(input_refer_image_path,output_refer_image_path, mode)
        elif mode == "canny":
            prepare_canny(input_refer_image_path,output_refer_image_path, low_threshold = low_threshold, high_threshold = high_threshold)
        elif mode == "depth":
            prepare_depthmap(input_refer_image_path,output_refer_image_path)
        elif mode == "zoe_depth":
            prepare_zoe_depthmap(input_refer_image_path,output_refer_image_path)
        elif mode == "tile" or mode == "scribble":
            init_image = load_image(input_refer_image_path)
            init_image.save(output_refer_image_path)
        else:
            raise ValueError("control_mode is not set")


    def generate_image(self, prompt, neg_prompt, image_path, seed = None, controlnet_conditioning_scale = 1.0):
        def decode_tensors(pipe, step, timestep, callback_kwargs):
            if self.save_latent_simple or self.save_predict_skip_x0:
                callback_kwargs = decode_tensors_simple(pipe, step, timestep, callback_kwargs)
            elif self.save_latent_overstep:
                callback_kwargs = decode_tensors_residual(pipe, step, timestep, callback_kwargs)
            else:
                raise ValueError("self.save_predict_skip_x0 or save_latent_simple or save_latent_overstep must be set or 'save_latent_approximation = False'")
            return callback_kwargs


        def decode_tensors_simple(pipe, step, timestep, callback_kwargs):
            latents = callback_kwargs["latents"]
            skip_x0 = callback_kwargs["skip_x0"]
            imege = None
            prefix = None
            if not self.save_predict_skip_x0:
                prefix = "latents"
                if self.save_latent_simple and not self.save_latent_approximation:
                    image = latents_to_rgb_vae(latents,pipe)
                elif self.save_latent_approximation:
                    image = latents_to_rgb_approximation(latents,pipe)
                else:
                    raise ValueError("save_latent_simple or save_latent_approximation is not set")
            else:
                prefix = "predicted_x0"
                image = latents_to_rgb_vae(skip_x0,pipe)
                
            gettime = time.time()
            formatted_time_human_readable = time.strftime("%Y%m%d_%H%M%S", time.localtime(gettime))
            image.save(f"./outputs/{prefix}_{formatted_time_human_readable}_{step}_{timestep}.png")

            return callback_kwargs

        def decode_tensors_residual(pipe, step, timestep, callback_kwargs):
            latents = callback_kwargs["latents"]
            if step > 0:
                residual = latents - self.last_latents
                goal = self.last_latents + residual * ((self.last_timestep) / (self.last_timestep - timestep))
                #print( ((self.last_timestep) / (self.last_timestep - timestep)))
            else:
                goal = latents

            if self.save_latent_overstep and not self.save_latent_approximation:
                image = latents_to_rgb_vae(goal,pipe)
            elif self.save_latent_approximation:
                image = latents_to_rgb_approximation(goal,pipe)
            else:
                raise ValueError("save_latent_simple or save_latent_approximation is not set")

            gettime = time.time()
            formatted_time_human_readable = time.strftime("%Y%m%d_%H%M%S", time.localtime(gettime))
            image.save(f"./outputs/latent_{formatted_time_human_readable}_{step}_{timestep}.png")

            self.last_latents = latents
            self.last_step = step
            self.last_timestep = timestep

            if timestep == 0:
                self.last_latents = None
                self.last_step = -1
                self.last_timestep = 100

            return callback_kwargs

        def latents_to_rgb_vae(latents,pipe):

            pipe.upcast_vae()
            latents = latents.to(next(iter(pipe.vae.post_quant_conv.parameters())).dtype)
            images = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]
            images = pipe.image_processor.postprocess(images, output_type='pil')
            pipe.vae.to(dtype=torch.float16)

            return StableDiffusionXLPipelineOutput(images=images).images[0]

        def latents_to_rgb_approximation(latents, pipe):
            weights = (
                (60, -60, 25, -70),
                (60,  -5, 15, -50),
                (60,  10, -5, -35)
            )

            weights_tensor = torch.t(torch.tensor(weights, dtype=latents.dtype).to(latents.device))
            biases_tensor = torch.tensor((150, 140, 130), dtype=latents.dtype).to(latents.device)
            rgb_tensor = torch.einsum("...lxy,lr -> ...rxy", latents, weights_tensor) + biases_tensor.unsqueeze(-1).unsqueeze(-1)
            image_array = rgb_tensor.clamp(0, 255)[0].byte().cpu().numpy()
            image_array = image_array.transpose(1, 2, 0)  # Change the order of dimensions

            return Image.fromarray(image_array)

        if seed is not None:
            self.generator = torch.Generator(device=self.device).manual_seed(seed)

        control_image = load_image(image_path)

        image = None
        if self.use_callback:
            if self.LORA_FLAG:
                if self.REFINER_FLAG:
                    image = self.base(
                        prompt=prompt,
                        negative_prompt=neg_prompt,
                        image=control_image,
                        cfg_scale=self.cfg_scale,
                        controlnet_conditioning_scale=controlnet_conditioning_scale,
                        num_inference_steps=self.n_steps,
                        denoising_end=self.high_noise_frac,
                        output_type="latent",
                        width = self.width,
                        height = self.height,
                        generator=self.generator,
                        cross_attention_kwargs={"scale": self.lora_scale},
                        callback_on_step_end=decode_tensors,
                        callback_on_step_end_tensor_inputs=["latents", "skip_x0"],
                        ).images[0]
                    image = self.refiner(
                        prompt=prompt,
                        negative_prompt=neg_prompt,
                        cfg_scale=self.cfg_scale,
                        aesthetic_score = self.aesthetic_score,
                        negative_aesthetic_score = self.negative_aesthetic_score,
                        num_inference_steps=self.n_steps,
                        denoising_start=self.high_noise_frac,
                        callback_on_step_end=decode_tensors,
                        callback_on_step_end_tensor_inputs=["latents", "skip_x0"],
                        image=image[None, :]
                        ).images[0]
                #refiner を利用しない場合
                else:
                    image = self.base(
                        prompt=prompt,
                        negative_prompt=neg_prompt,
                        image=control_image,
                        cfg_scale=self.cfg_scale,
                        controlnet_conditioning_scale=controlnet_conditioning_scale,
                        num_inference_steps=self.n_steps,
                        denoising_end=self.high_noise_frac,
                        output_type=self.output_type,
                        width = self.width,
                        height = self.height,
                        generator=self.generator,
                        callback_on_step_end=decode_tensors,
                        callback_on_step_end_tensor_inputs=["latents", "skip_x0"],
                        cross_attention_kwargs={"scale": self.lora_scale},
                        ).images[0]
            #LORAを利用しない場合
            else:
                if self.REFINER_FLAG:
                    image = self.base(
                        prompt=prompt,
                        negative_prompt=neg_prompt,
                        image=control_image,
                        cfg_scale=self.cfg_scale,
                        controlnet_conditioning_scale=controlnet_conditioning_scale,
                        num_inference_steps=self.n_steps,
                        denoising_end=self.high_noise_frac,
                        output_type="latent",
                        width = self.width,
                        height = self.height,
                        callback_on_step_end=decode_tensors,
                        callback_on_step_end_tensor_inputs=["latents", "skip_x0"],
                        generator=self.generator
                        ).images[0]
                    image = self.refiner(
                        prompt=prompt,
                        negative_prompt=neg_prompt,
                        cfg_scale=self.cfg_scale,
                        aesthetic_score = self.aesthetic_score,
                        negative_aesthetic_score = self.negative_aesthetic_score,
                        num_inference_steps=self.n_steps,
                        denoising_start=self.high_noise_frac,
                        callback_on_step_end=decode_tensors,
                        callback_on_step_end_tensor_inputs=["latents", "skip_x0"],
                        image=image[None, :]
                        ).images[0]
                #refiner を利用しない場合
                else:
                    image = self.base(
                        prompt=prompt,
                        negative_prompt=neg_prompt,
                        image=control_image,
                        cfg_scale=self.cfg_scale,
                        controlnet_conditioning_scale=controlnet_conditioning_scale,
                        num_inference_steps=self.n_steps,
                        denoising_end=self.high_noise_frac,
                        output_type=self.output_type,
                        width = self.width,
                        height = self.height,
                        callback_on_step_end=decode_tensors,
                        callback_on_step_end_tensor_inputs=["latents", "skip_x0"],
                        generator=self.generator
                        ).images[0]
        #latentを保存しない場合
        else:
            if self.LORA_FLAG:
                if self.REFINER_FLAG:
                    image = self.base(
                        prompt=prompt,
                        negative_prompt=neg_prompt,
                        image=control_image,
                        cfg_scale=self.cfg_scale,
                        controlnet_conditioning_scale=controlnet_conditioning_scale,
                        num_inference_steps=self.n_steps,
                        denoising_end=self.high_noise_frac,
                        output_type="latent",
                        width = self.width,
                        height = self.height,
                        generator=self.generator,
                        cross_attention_kwargs={"scale": self.lora_scale},
                        ).images[0]
                    image = self.refiner(
                        prompt=prompt,
                        negative_prompt=neg_prompt,
                        cfg_scale=self.cfg_scale,
                        aesthetic_score = self.aesthetic_score,
                        negative_aesthetic_score = self.negative_aesthetic_score,
                        num_inference_steps=self.n_steps,
                        denoising_start=self.high_noise_frac,
                        image=image[None, :]
                        ).images[0]
                # refiner を利用しない場合
                else:
                    image = self.base(
                        prompt=prompt,
                        negative_prompt=neg_prompt,
                        image=control_image,
                        cfg_scale=self.cfg_scale,
                        controlnet_conditioning_scale=controlnet_conditioning_scale,
                        num_inference_steps=self.n_steps,
                        denoising_end=self.high_noise_frac,
                        output_type=self.output_type,
                        width = self.width,
                        height = self.height,
                        generator=self.generator,
                        cross_attention_kwargs={"scale": self.lora_scale},
                        ).images[0]
            # LORAを利用しない場合
            else:
                if self.REFINER_FLAG:
                    image = self.base(
                        prompt=prompt,
                        negative_prompt=neg_prompt,
                        image=control_image,
                        cfg_scale=self.cfg_scale,
                        controlnet_conditioning_scale=controlnet_conditioning_scale,
                        num_inference_steps=self.n_steps,
                        denoising_end=self.high_noise_frac,
                        output_type="latent",
                        width = self.width,
                        height = self.height,
                        generator=self.generator
                        ).images[0]
                    image = self.refiner(
                        prompt=prompt,
                        negative_prompt=neg_prompt,
                        cfg_scale=self.cfg_scale,
                        aesthetic_score = self.aesthetic_score,
                        negative_aesthetic_score = self.negative_aesthetic_score,
                        num_inference_steps=self.n_steps,
                        denoising_start=self.high_noise_frac,
                        image=image[None, :]
                        ).images[0]
                # refiner を利用しない場合
                else:
                    image = self.base(
                        prompt=prompt,
                        negative_prompt=neg_prompt,
                        image=control_image,
                        cfg_scale=self.cfg_scale,
                        controlnet_conditioning_scale=controlnet_conditioning_scale,
                        num_inference_steps=self.n_steps,
                        denoising_end=self.high_noise_frac,
                        output_type=self.output_type,
                        width = self.width,
                        height = self.height,
                        generator=self.generator
                        ).images[0]

        return image

基本的にはpart10の内容と同一ではあるが、下記の部分が異なるため説明します。

        def generate_image(self, prompt, neg_prompt, image_path, seed = None, controlnet_conditioning_scale = 1.0):
            ・・・・・・・
                    image = self.base(
                        prompt=prompt,
                        negative_prompt=neg_prompt,
                        image=control_image,
                        cfg_scale=self.cfg_scale,
                        controlnet_conditioning_scale=controlnet_conditioning_scale,
                        num_inference_steps=self.n_steps,
                        denoising_end=self.high_noise_frac,
                        output_type=self.output_type,
                        width = self.width,
                        height = self.height,
                        generator=self.generator,
                        callback_on_step_end=decode_tensors,
                        callback_on_step_end_tensor_inputs=["latents", "skip_x0"],
                        cross_attention_kwargs={"scale": self.lora_scale},
                        ).images[0]
            ・・・・・・・

画像を生成する際に、上記のような形で生成しますが、特に、callback_on_step_end_tensor_inputs=["latents", "skip_x0"],において"skip_x0"が追加されています。

これは、Samplerが各Stepごとに予測するクリーン画像をコールバック関数で受け取るための引数になります。しかしながら、通常のDiffusersモジュールでは、この潜在表現は受け取るとができないので、Forkして改造しています。

それが下記のリポジトリです。
https://github.com/personabb/diffusers-preview_latents

変更箇所は下記を見ればわかります。わかりやすさ重視のため最低限の変更に抑えています。
https://github.com/huggingface/diffusers/compare/main...personabb:diffusers-preview_latents:main

まず変更点について解説し、その後なぜそこを変更したのかについて説明します。
まず一個目の変更点は下記です。

diffusers-preview_latents/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py

    _callback_tensor_inputs = [
        "latents",
        "prompt_embeds",
        "negative_prompt_embeds",
        "add_text_embeds",
        "add_time_ids",
        "negative_pooled_prompt_embeds",
        "negative_add_time_ids",
+       "skip_x0",
    ]

上記はcallback関数に渡すことが可能な変数のリストになります。ここに書かれていない変数をcallback関数で受け取ろうとすると、StableDiffusionXLControlNetPipelineクラスの__call__メソッド内で呼ばれるcheck_inputsメソッドでエラーになります。

続いての変更点は下記です。

diffusers-preview_latents/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+ latent_all = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=True)
+ latents = latent_all.prev_sample
+ skip_x0 = None
+ if hasattr(latent_all, 'pred_original_sample'):
+     skip_x0 = latent_all.pred_original_sample

self.scheduler.stepは本パイプラインにて利用しているSamplerクラスにて定義されているstepメソッドを読んでいます。
stepメソッドはreturn_dict引数の真偽により、出力が変わります。
元のDiffuserの通り、return_dict=Falseとすると、各ステップごとのノイズ混じりの潜在表現のみを取得できます。
一方で、return_dict=Trueとすると、各ステップごとのノイズ混じりの潜在表現に加えて、各ステップにてSamplerが予測したクリーン画像(x_0)も取得することができます。ただし、Samplerの種類によっては取得できないSamplerもあります。例えば、よく利用されているDPMSolverMultistepSchedulerは取得できません。一方で、SamplerとしてEulerAncestralDiscreteSchedulerを利用している場合は、取得可能なので今回はこちらを利用します。

では、実際にstepメソッドについて中身を見ていきます。

該当部分は下記です

if not return_dict:
            return (prev_sample,)

        return EulerAncestralDiscreteSchedulerOutput(
            prev_sample=prev_sample, pred_original_sample=pred_original_sample
        )

上記で解説した通り、引数のreturn_dictがTrueの場合は、returnとしてEulerAncestralDiscreteSchedulerOutputが返されます。

https://huggingface.co/docs/diffusers/en/api/schedulers/euler_ancestral
上記の公式のドキュメントの通り、EulerAncestralDiscreteSchedulerOutputは二つのパラメータを持ちます。
deeplにて翻訳した文章が下記です。

prev_sample (torch.Tensor of shape (batch_size, num_channels, height, width) for images) - 直前のタイムステップで計算されたサンプル(x_{t-1}).
pred_original_sample (torch.Tensor of shape (batch_size, num_channels, height, width) for images) - 現在のタイムステップのモデル出力に基づく,ノイズ除去予測サンプル(x_{0}).

上記の通り、EulerAncestralDiscreteSchedulerOutputクラスのpred_original_sample属性を取得することができれば、クリーン画像を取得することができます。
そのために下記のようにDiffuserに変更を加えたわけです。

diffusers-preview_latents/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+ latent_all = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=True)
+ latents = latent_all.prev_sample
+ skip_x0 = None
+ if hasattr(latent_all, 'pred_original_sample'):
+     skip_x0 = latent_all.pred_original_sample
DPMSolverMultistepScheduler

ちなみにですがDPMSolverMultistepScheduler公式のドキュメントを見ても分かるとおり、DPMSolverMultistepSchedulerの出力で利用されているSchedulerOutputクラスにはprev_sampleの属性はあるが、pred_original_sampleの属性はないため、各ステップごとのクリーン画像を取得することができないため、設定でuse_dpm_solver = Falseを用意しています。

逆に言えば、出力のクラスを確認して、pred_original_sample属性が存在するSamplerであれば、EulerAncestralDiscreteSchedulerでなくても、同様に各ステップごとのクリーン画像を取得できます。

今回、EulerAncestralDiscreteSchedulerを利用したのは、使っているモデルanimagine-xl-3.1の「Recommended settings」の章にて、使用を推奨されていたSamplerだからです。

it’s recommended to use a lower classifier-free guidance (CFG Scale) of around 5-7, sampling steps below 30, and to use Euler Ancestral (Euler a) as a sampler.

https://huggingface.co/cagliostrolab/animagine-xl-3.1

続いて、潜在表現を画像化する部分に関して解説します。

まずは該当部分の全文を表示します。

潜在表現から画像を再構成するコールバック関数
        def decode_tensors(pipe, step, timestep, callback_kwargs):
            if self.save_latent_simple or self.save_predict_skip_x0:
                callback_kwargs = decode_tensors_simple(pipe, step, timestep, callback_kwargs)
            elif self.save_latent_overstep:
                callback_kwargs = decode_tensors_residual(pipe, step, timestep, callback_kwargs)
            else:
                raise ValueError("self.save_predict_skip_x0 or save_latent_simple or save_latent_overstep must be set or 'save_latent_approximation = False'")
            return callback_kwargs


        def decode_tensors_simple(pipe, step, timestep, callback_kwargs):
            latents = callback_kwargs["latents"]
            skip_x0 = callback_kwargs["skip_x0"]
            imege = None
            prefix = None
            if not self.save_predict_skip_x0:
                prefix = "latents"
                if self.save_latent_simple and not self.save_latent_approximation:
                    image = latents_to_rgb_vae(latents,pipe)
                elif self.save_latent_approximation:
                    image = latents_to_rgb_approximation(latents,pipe)
                else:
                    raise ValueError("save_latent_simple or save_latent_approximation is not set")
            else:
                prefix = "predicted_x0"
                image = latents_to_rgb_vae(skip_x0,pipe)
                
            gettime = time.time()
            formatted_time_human_readable = time.strftime("%Y%m%d_%H%M%S", time.localtime(gettime))
            image.save(f"./outputs/{prefix}_{formatted_time_human_readable}_{step}_{timestep}.png")

            return callback_kwargs

        def decode_tensors_residual(pipe, step, timestep, callback_kwargs):
            latents = callback_kwargs["latents"]
            if step > 0:
                residual = latents - self.last_latents
                goal = self.last_latents + residual * ((self.last_timestep) / (self.last_timestep - timestep))
                #print( ((self.last_timestep) / (self.last_timestep - timestep)))
            else:
                goal = latents

            if self.save_latent_overstep and not self.save_latent_approximation:
                image = latents_to_rgb_vae(goal,pipe)
            elif self.save_latent_approximation:
                image = latents_to_rgb_approximation(goal,pipe)
            else:
                raise ValueError("save_latent_simple or save_latent_approximation is not set")

            gettime = time.time()
            formatted_time_human_readable = time.strftime("%Y%m%d_%H%M%S", time.localtime(gettime))
            image.save(f"./outputs/latent_{formatted_time_human_readable}_{step}_{timestep}.png")

            self.last_latents = latents
            self.last_step = step
            self.last_timestep = timestep

            if timestep == 0:
                self.last_latents = None
                self.last_step = -1
                self.last_timestep = 100

            return callback_kwargs

        def latents_to_rgb_vae(latents,pipe):

            pipe.upcast_vae()
            latents = latents.to(next(iter(pipe.vae.post_quant_conv.parameters())).dtype)
            images = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]
            images = pipe.image_processor.postprocess(images, output_type='pil')
            pipe.vae.to(dtype=torch.float16)

            return StableDiffusionXLPipelineOutput(images=images).images[0]

        def latents_to_rgb_approximation(latents, pipe):
            weights = (
                (60, -60, 25, -70),
                (60,  -5, 15, -50),
                (60,  10, -5, -35)
            )

            weights_tensor = torch.t(torch.tensor(weights, dtype=latents.dtype).to(latents.device))
            biases_tensor = torch.tensor((150, 140, 130), dtype=latents.dtype).to(latents.device)
            rgb_tensor = torch.einsum("...lxy,lr -> ...rxy", latents, weights_tensor) + biases_tensor.unsqueeze(-1).unsqueeze(-1)
            image_array = rgb_tensor.clamp(0, 255)[0].byte().cpu().numpy()
            image_array = image_array.transpose(1, 2, 0)  # Change the order of dimensions

            return Image.fromarray(image_array)

一つずつ説明します。
まず、各ステップごとにコールバック関数として呼ばれる関数は、decode_tensors関数になります。この関数は、パイプラインの__call__メソッドの引数のcallback_on_step_end=decode_tensors,として指定しています。

decode_tensors関数は下記のような関数です。

def decode_tensors(pipe, step, timestep, callback_kwargs):
            if self.save_latent_simple or self.save_predict_skip_x0:
                callback_kwargs = decode_tensors_simple(pipe, step, timestep, callback_kwargs)
            elif self.save_latent_overstep:
                callback_kwargs = decode_tensors_residual(pipe, step, timestep, callback_kwargs)
            else:
                raise ValueError("self.save_predict_skip_x0 or save_latent_simple or save_latent_overstep must be set or 'save_latent_approximation = False'")
            return callback_kwargs

コールバック関数は(pipe, step, timestep, callback_kwargs)を引数とし、returnとしてcallback_kwargsを返す必要がある。このcallback_kwargsはコールバック関数の引数と同じ形である必要がある。具体的には辞書として、下記の名前がキーとなっているものが必要です。

    _callback_tensor_inputs = [
        "latents",
        "prompt_embeds",
        "negative_prompt_embeds",
        "add_text_embeds",
        "add_time_ids",
        "negative_pooled_prompt_embeds",
        "negative_add_time_ids",
        "skip_x0",
    ]

さらに例えば、callback_kwargs["latents"] = 2 * latentsのように、値を変更して、returnすることで、元の画像生成の処理で利用しているlatentsの値も操作することができます。

証拠

callback関数に関連する箇所はここ

diffusers-preview_latents/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
if callback_on_step_end is not None:
    callback_kwargs = {}
    for k in callback_on_step_end_tensor_inputs:
        callback_kwargs[k] = locals()[k]
    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

    latents = callback_outputs.pop("latents", latents)
    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
    add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
    negative_pooled_prompt_embeds = callback_outputs.pop(
        "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
    )
    add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
    negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids)

実際にコールバック関数が動くのは

diffusers-preview_latents/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

であり、その出力をpopする形で、latentsprompt_embedsなどを各ステップごとに変更することができるように設計されているため、コールバック関数を駆使することで色々なことができそうですね。

また、_callback_tensor_inputsに追加することで、使われているどんな変数でも取得することができます。なぜなら下記で、__call__メソッドのスコープ内の変数を全て取得できるようになっているからです。

diffusers-preview_latents/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
for k in callback_on_step_end_tensor_inputs:
    callback_kwargs[k] = locals()[k]

では、元のコールバック関数の説明に戻ります。

コールバック関数では、設定ファイルの設定に応じて、二つの関数のうちどちらかが起動するようになっています。

if self.save_latent_simple or self.save_predict_skip_x0:
    callback_kwargs = decode_tensors_simple(pipe, step, timestep, callback_kwargs)
elif self.save_latent_overstep:
    callback_kwargs = decode_tensors_residual(pipe, step, timestep, callback_kwargs)

すなわち、save_latent_simplesave_predict_skip_x0がTrueの場合はdecode_tensors_simple関数が起動し、save_latent_overstepがTrueの場合はdecode_tensors_residualが起動します。

decode_tensors_simple関数は下記のように定義されます。

def decode_tensors_simple(pipe, step, timestep, callback_kwargs):
    latents = callback_kwargs["latents"]
    skip_x0 = callback_kwargs["skip_x0"]
    imege = None
    prefix = None
    if not self.save_predict_skip_x0:
        prefix = "latents"
        if self.save_latent_simple and not self.save_latent_approximation:
            image = latents_to_rgb_vae(latents,pipe)
        elif self.save_latent_approximation:
            image = latents_to_rgb_approximation(latents,pipe)
        else:
            raise ValueError("save_latent_simple or save_latent_approximation is not set")
    else:
        prefix = "predicted_x0"
        image = latents_to_rgb_vae(skip_x0,pipe)
        
    gettime = time.time()
    formatted_time_human_readable = time.strftime("%Y%m%d_%H%M%S", time.localtime(gettime))
    image.save(f"./outputs/{prefix}_{formatted_time_human_readable}_{step}_{timestep}.png")

    return callback_kwargs

この中でも同様に設定ファイルの設定に応じて、使う関数を変更させて潜在変数の処理を行い、その上で、生成された画像を「outputs」フォルダに保存しています。

重要なのは、

image = latents_to_rgb_vae(latents,pipe)

image = latents_to_rgb_approximation(latents,pipe)

image = latents_to_rgb_vae(skip_x0,pipe)

です。

latents_to_rgb_vae(latents,pipe)は各ステップのノイズ混じりの潜在表現を受けとり、VAEで画像を再構成しています。
latents_to_rgb_vae(skip_x0,pipe)も入力が、各ステップでの予測されたクリーン画像になっているだけで本質は同じです。

latents_to_rgb_vae関数は下記のように定義されており、基本的にはDiffuserモジュールの書き方を踏襲して作っています。

def latents_to_rgb_vae(latents,pipe):

    pipe.upcast_vae()
    latents = latents.to(next(iter(pipe.vae.post_quant_conv.parameters())).dtype)
    images = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]
    images = pipe.image_processor.postprocess(images, output_type='pil')
    pipe.vae.to(dtype=torch.float16)

    return StableDiffusionXLPipelineOutput(images=images).images[0]

違うのはpipe.vae.to(dtype=torch.float16)の部分です。
VAEに潜在表現を通す前にupcastしてfloat16からfloat32に型変換をする必要があります。
VAEはコールバック関数で利用するものも、パイプラインの__call__メソッド内で利用するものも同じものを利用するため、コールバック関数内で変更した型は元に戻す必要があるので、加えています。

ここで戻さないとlatentsの型とVAEの型が合わないため、最終ステップでの画像を生成する際にエラーが発生します。

エラーの理由

発生するエラーは下記になります。

RuntimeError: Input type (c10::Half) and bias type (float) should be the same

VAEの型がfloat32になっていると下記の部分needs_upcastingがFalseになってしまいます。

diffusers-preview_latents/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
if not output_type == "latent":
    # make sure the VAE is in float32 mode, as it overflows in float16
    needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast

その場合、次の部分の処理が行われないため、VAEはfloat32なのに、latentsはfloat16のまま処理を行うことになります。

diffusers-preview_latents/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
if needs_upcasting:
    self.upcast_vae()
    latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)

結果、型が合わないエラーが発生します。

続いて、latents_to_rgb_approximation(latents,pipe)です。
こちらも各ステップのノイズ混じり潜在表現から画像を再構成する関数ですが、再構成の方法が異なります。
これまでの方法はVAEのDecoderに通すことで再構成していますが、今回は潜在表現に対して線形処理を行うことで、画像を線形近似して表示しています。

https://huggingface.co/blog/TimothyAlexisVass/explaining-the-sdxl-latent-space
上記の記事に記載されているものを、そのまま利用して実装していますが、中身についてはちゃんと読んでいないので理解できていません。
わかっているのはこの処理のように線形近似すると、潜在表現から画像っぽいものが再構成されるということです。この調査をされた方は本当にすごいですね・・・

以上で、一旦decode_tensors_simple関数の説明は終わります。
次に。decode_tensors_residual関数についてです。

def decode_tensors_residual(pipe, step, timestep, callback_kwargs):
    latents = callback_kwargs["latents"]
    if step > 0:
        residual = latents - self.last_latents
        goal = self.last_latents + residual * ((self.last_timestep) / (self.last_timestep - timestep))
        #print( ((self.last_timestep) / (self.last_timestep - timestep)))
    else:
        goal = latents

    if self.save_latent_overstep and not self.save_latent_approximation:
        image = latents_to_rgb_vae(goal,pipe)
    elif self.save_latent_approximation:
        image = latents_to_rgb_approximation(goal,pipe)
    else:
        raise ValueError("save_latent_simple or save_latent_approximation is not set")

    gettime = time.time()
    formatted_time_human_readable = time.strftime("%Y%m%d_%H%M%S", time.localtime(gettime))
    image.save(f"./outputs/latent_{formatted_time_human_readable}_{step}_{timestep}.png")

    self.last_latents = latents
    self.last_step = step
    self.last_timestep = timestep

    if timestep == 0:
        self.last_latents = None
        self.last_step = -1
        self.last_timestep = 100

    return callback_kwargs

こちらに関しても基本的にはdecode_tensors_simple関数と同じですが、下記部分だけ異なります。

if step > 0:
    residual = latents - self.last_latents
    goal = self.last_latents + residual * ((self.last_timestep) / (self.last_timestep - timestep))
else:
    goal = latents

ここで実施したいことは、今回の潜在表現と前回の潜在表現の差分residualには、今回のstepにて得られた入力に対する勾配によって更新が行われた量が格納されることになります。

その変化量をより大きくして更新することで、早い段階から生成画像を確認できるのではないかと考えた次第です。
どの程度変化量を大きくするかというのは下記の式です

((self.last_timestep) / (self.last_timestep - timestep))

これは今回の更新で進んだtimestep数 (self.last_timestep - timestep)での更新量を残りのtimestep数(self.last_timestep) (最大1000)倍しています。

以上がmodule/module_sdc.pyにおけるpart10の記事との違いになります。

実験結果

ここからは、上記のコードによってGoogle Colabでパラメータを変更して、様々な実験を実施したため、その詳細を記載します。

前提条件

前提として下記の設定を継承します。後述する実験において特に記載がない場合は、この設定が継承されていると考えてください。

5セル目

config_text = """
[SDXLC]
device = auto
n_steps=28
high_noise_frac=None
seed=42

vae_model_path = None
base_model_path = Asahina2K/Animagine-xl-3.1-diffuser-variant-fp16
refiner_model_path = None

controlnet_path = diffusers/controlnet-depth-sdxl-1.0

control_mode = depth

lora_weight_path = ./inputs/DreamyvibesartstyleSDXL.safetensors
lora_scale = 1.0

use_dpm_solver = False
use_karras_sigmas = True
scheduler_algorithm_type = dpmsolver++
solver_order = 2

cfg_scale = 7.0

width = 832
height = 1216
output_type = pil
aesthetic_score = 6
negative_aesthetic_score = 2.5

save_latent_simple = False
save_latent_overstep = False
save_latent_approximation = False

save_predict_skip_x0 = False

"""

with open("configs/config.ini", "w", encoding="utf-8") as f:
  f.write(config_text)

上記の設定の通り、LoRAとControlNetのDepthを利用します

6セル目

main_prompt = """
1 girl ,Yellowish-white hair ,short hair ,red small ribbon,red eyes,red hat ,school uniform ,solo ,smile ,upper body ,Anime ,Japanese,best quality,high quality,ultra highres,ultra quality
"""

use_lora = True
if use_lora:
  main_prompt += ", Dreamyvibes Artstyle"

negative_prompt="""
nsfw, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]
"""

input_refer_image_path = "./inputs/refer.webp"
output_refer_image_path = "./inputs/refer.png"

8セル目

controlnet_conditioning_scale = 0.7

ControNetに入力する画像

参照画像の変換前

(使った画像がバリバリ版権画像だったため、depth画像だけで失礼します。元の画像はANIMAGINE XL 3.1の公式が提供しているチュートリアルで使われているプロンプトで作成した画像の一つです。(おそらくseed42-47のあたり)

参照画像の変換後

ちなみに、上記の深度マップを"./inputs/refer.png"として保存して、7セル目のsd.prepare_referimageメソッドを実行せずにコメントアウトすることでも、同様の実験を行うことが可能です。

sd = SDXLC()
#sd.prepare_referimage(input_refer_image_path = input_refer_image_path, output_refer_image_path = output_refer_image_path, low_threshold = 100, high_threshold = 200)

実験1

「生成途中の潜在表現からVAEで再構成した画像の表示」を行います。

設定

前提から、下記部分だけ変更する

save_latent_simple = True

結果

大量の画像が保存されるので、途中画像はgifに変換して表示します。
(3MBの制限に抑えるために、画像はかなり劣化しています。ごめんなさい)



実験2

「生成途中の潜在表現を直接線形近似した画像を表示」を行います。

設定

前提から、下記部分だけ変更する

save_latent_simple = True
save_latent_approximation = True

結果

大量の画像が保存されるので、途中画像はgifに変換して表示します。

SDXLの場合、潜在表現は通常の画像サイズの1/8のサイズになります。VAEを通さずに線形近似をしているだけなので、画像サイズは小さくなります。


実験3

「潜在表現に対する各Stepの更新幅を変更して、VAEで再構成した画像の表示」を行います。

設定

前提から、下記部分だけ変更する

use_dpm_solver = True
save_latent_overstep = True

今回の実験では、なぜかEulerAncestralDiscreteSchedulerのSamplerではうまく機能しなかったので。DPMSolverMultistepSchedulerを利用しました。

結果

大量の画像が保存されるので、途中画像はgifに変換して表示します。
(3MBの制限に抑えるために、画像はかなり劣化しています。ごめんなさい)


実験4

「Samplerが生成途中に推定したクリーン画像の表示」を行います。

設定

前提から、下記部分だけ変更する

save_predict_skip_x0 = True

結果

大量の画像が保存されるので、途中画像はgifに変換して表示します。
(3MBの制限に抑えるために、画像はかなり劣化しています。ごめんなさい)


まとめ

以上、ここまででDiffersを利用して、生成途中の潜在表現から画像を再構成してみました。

みたかぎり、実験4の可視化が一番人間にはみやすい可視化だったかなと思います。
実験4を見ると、最初の大まかな画像が決まっていき、細かい箇所は後から少しずつ決まっていくような動きをすることがわかりました。
また、ControlNetに関連する範囲に関しては早いstepから画像が確定しており、背景はそれに合わせて後から生成されるような動きをしていることがわかりました。

一方で、実験4の可視化手法はDPMSolverMultistepSchedulerなどの一部のSamplerでは利用することができません。(利用する方法はあるかもですが、現時点の私の理解力では難しいので、詳しいかたいらっしゃれば教えていただきたいです)
その場合は、実験3の可視化手法も試してみていただけますと嬉しいです。実験4には劣りますが、途中の画像の推移などが見える形になっているかなと思います。

Samplerに応じて、使う可視化手法を変更するのが良さそうかなと思いました。

以上で、終わりです。
ここまで読んでくださり、ありがとうございました。

Discussion