画像生成AIの生成途中経過を可視化したい!【Diffusers】
はじめに
みなさん画像生成AIのStable Diffusionはご存知でしょうか。
私の過去の記事でも画像生成AIをGoogle Colabで簡単に使えるように解説しています。
Stable Diffusionといえば、WebUIが有名ですが、Diffusersというライブラリを利用することで、Pythonを利用して、AIをサービスに組み込んだりすることが可能です。
今回は、画像生成AIが絵を生成するまでの途中経過を可視化してみたいと思います。
(Web UIでは生成途中の画像を可視化する機能はありますが、diffusersではあまり解説されていないように思います)
生成途中の画像を可視化する方法は、複数パターンがあります、
今回は、下記の4パターンの方法で、生成途中を可視化します。
(一番下のものが、皆さんが馴染み深い画像かもしれません)
生成途中の潜在表現からVAEで再構成した画像の表示
生成途中の潜在表現を直接線形近似した画像を表示
(潜在表現の縦横サイズは、元の画像の1/8なので小さい画像です)
潜在表現に対する各Stepの更新幅を変更して、VAEで再構成した画像の表示
(これだけ事情があってSamplerが違うため、生成される画像が異なります)
Samplerが生成途中に推定したクリーン画像の表示
成果物
下記のリポジトリをご覧ください。
今回の実験
下記に実施した実験の内容を記載します。実験結果については最後にご紹介しています。
-
実験1
- 生成途中の潜在表現からVAEで再構成した画像の表示
-
実験2
- 生成途中の潜在表現を直接線形近似した画像を表示
-
実験3
- 潜在表現に対する各Stepの更新幅を変更して、VAEで再構成した画像の表示
-
実験4
- Samplerが生成途中に推定したクリーン画像の表示
事前準備
利用するLoRAモデルを保存する
Part10の記事をご覧ください。
参照画像をダウンロードする
コントロールネットの入力に利用する画像を取得し、後述するフォルダの「inputs」フォルダに格納してください。
(今回の実験で利用している画像に関しては、AI生成画像ですが、版権画像なのでここで提示するのは、変換後の深度画像のみとします。)
元の画像はANIMAGINE XL 3.1の公式が提供しているチュートリアルで使われているプロンプトで作成した画像の一つです。(おそらくseed42-47のあたり)
解説
下記の通り、解説を行います。
まずは上記のリポジトリをcloneしてください。
git clone https://github.com/personabb/colab_AI_sample.git
その後、cloneしたフォルダ「colab_AI_sample」をマイドライブの適当な場所においてください。
ディレクトリ構造
Google Driveのディレクトリ構造は下記を想定します。
MyDrive/
└ colab_AI_sample/
└ colab_SDXLControlNet_sample_forkDiffuser/
├ configs/
| └ config.ini
├ inputs/
| | refer.webp
| └ DreamyvibesartstyleSDXL.safetensors
├ outputs/
├ module/
| └ module_sd3c.py
└ SDXLControlNet_sample.ipynb
-
colab_AI_sample
フォルダは適当です。なんでも良いです。1階層である必要はなく下記のように複数階層になっていても良いです。MyDrive/hogehoge/spamspam/hogespam/colab_AI_sample
-
outputs
フォルダには、生成後の画像が格納されます。最初は空です。- 連続して生成を行う場合、過去の生成内容を上書きするため、ダウンロードするか、名前を変えておくことをオススメします。
-
inputs
フォルダには、ControlNetで利用する参照画像を格納しています。詳細は後述します。- 加えて、先ほどダウンロードしたLoRAモデルも格納します
- 名前に空白が入っているのが気持ち悪かったのでリネームしています。
- 加えて、先ほどダウンロードしたLoRAモデルも格納します
使い方解説
SDXLControlNet_sample.ipynb
をGoogle Colabratoryアプリで開いてください。
ファイルを右クリックすると「アプリで開く」という項目が表示されるため、そこからGoogle Colabratoryアプリを選択してください。
もし、ない場合は、「アプリを追加」からアプリストアに行き、「Google Colabratory」で検索してインストールをしてください。
Google Colabratoryアプリで開いたら、SDXLControlNet_sample.ipynb
のメモを参考にして、一番上のセルから順番に実行していけば、問題なく最後まで動作して、画像生成をすることができると思います。
また、最後まで実行後、パラメータを変更して再度実行する場合は、「ランタイム」→「セッションを再起動して全て実行する」をクリックしてください。
コードの解説を後回しに、とにかく実験をしたい方は、実験の章まで飛ばしてください
コード解説
主に、重要なSDXLControlNet_sample.ipynb
とmodule/module_sdc.py
について解説します。
SDXLControlNet_sample.ipynb
該当のコードは下記になります。
基本的にはPart10と同じですが、パッケージインストール部分が異なります。
具体的には下記のようになっています。
1セル目
%rm -r /content/diffusers-preview_latents
%cd /content/
!git clone https://github.com/personabb/diffusers-preview_latents.git
import sys
sys.path
sys.path.append('/content/diffusers-preview_latents/src')
ここでは、「diffusers-preview_latents」というリポジトリをクローンしてきています。
本リポジトリは、通常のDiffusersライブラリをForkして作成しており、通常では取得できない値を取得できるように改造しています。
大元のDiffusersライブラリのver0.29.0をForkしています。
詳細は後述します。
その上で、クローンしたリポジトリを環境変数に追加することで、別のスクリプトからDiffusersモジュールとして読み込めるようにしています
また、2セル目のパッケージインストールに関しては、これまではDiffusersが入っていたと思いますが、今回はDiffusersを利用しないので、取り除いています。
加えて、5セル目の設定ファイルの部分も多少変わっています。変わっているのは下記の部分です。
use_dpm_solver = False
save_latent_simple = False
save_latent_overstep = False
save_latent_approximation = False
save_predict_skip_x0 = False
一つ目において、これまではSamplerに「DPMSolverMultistepScheduler」を利用していたのですが、こちらでは後述する実験において、不都合があったため、use_dpm_solver = False
として、「EulerAncestralDiscreteScheduler」が利用されるように変更しています。
2つ目においては、どの部分をTrueに変更するかで、保存される途中経過が変わります。実験の際に紹介します。
その他のセルは、Part10のものと同様です。
module/module_sdc.py
続いて、SDXLControlNet_sample.ipynb
から読み込まれるモジュールの中身を説明します。
下記にコード全文を示します。
コード全文
from diffusers import DiffusionPipeline, AutoencoderKL, StableDiffusionXLControlNetPipeline, ControlNetModel
from diffusers.utils import load_image
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
import torch
from diffusers.schedulers import DPMSolverMultistepScheduler, EulerAncestralDiscreteScheduler
from controlnet_aux.processor import Processor
import os
import configparser
# ファイルの存在チェック用モジュール
import errno
import cv2
from PIL import Image
import time
import numpy as np
class SDXLCconfig:
def __init__(self, config_ini_path = './configs/config.ini'):
# iniファイルの読み込み
self.config_ini = configparser.ConfigParser()
# 指定したiniファイルが存在しない場合、エラー発生
if not os.path.exists(config_ini_path):
raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), config_ini_path)
self.config_ini.read(config_ini_path, encoding='utf-8')
SDXLC_items = self.config_ini.items('SDXLC')
self.SDXLC_config_dict = dict(SDXLC_items)
class SDXLC:
def __init__(self,device = None, config_ini_path = './configs/config.ini'):
SDXLC_config = SDXLCconfig(config_ini_path = config_ini_path)
config_dict = SDXLC_config.SDXLC_config_dict
if device is not None:
self.device = device
else:
device = config_dict["device"]
self.device = "cuda" if torch.cuda.is_available() else "cpu"
if device != "auto":
self.device = device
self.last_latents = None
self.last_step = -1
self.last_timestep = 1000
self.n_steps = int(config_dict["n_steps"])
if not config_dict["high_noise_frac"] == "None":
self.high_noise_frac = float(config_dict["high_noise_frac"])
else:
self.high_noise_frac = None
self.seed = int(config_dict["seed"])
self.generator = torch.Generator(device=self.device).manual_seed(self.seed)
self.controlnet_path = config_dict["controlnet_path"]
self.control_mode = config_dict["control_mode"]
if self.control_mode == "None":
self.control_mode = None
self.vae_model_path = config_dict["vae_model_path"]
self.VAE_FLAG = True
if self.vae_model_path == "None":
self.vae_model_path = None
self.VAE_FLAG = False
self.base_model_path = config_dict["base_model_path"]
self.REFINER_FLAG = True
self.refiner_model_path = config_dict["refiner_model_path"]
if self.refiner_model_path == "None":
self.refiner_model_path = None
self.REFINER_FLAG = False
self.LORA_FLAG = True
self.lora_weight_path = config_dict["lora_weight_path"]
if self.lora_weight_path == "None":
self.lora_weight_path = None
self.LORA_FLAG = False
self.lora_scale = float(config_dict["lora_scale"])
self.use_dpm_solver = config_dict["use_dpm_solver"]
if self.use_dpm_solver == "True":
self.use_dpm_solver = True
else:
self.use_dpm_solver = False
self.use_karras_sigmas = config_dict["use_karras_sigmas"]
if self.use_karras_sigmas == "True":
self.use_karras_sigmas = True
else:
self.use_karras_sigmas = False
self.scheduler_algorithm_type = config_dict["scheduler_algorithm_type"]
if config_dict["solver_order"] != "None":
self.solver_order = int(config_dict["solver_order"])
else:
self.solver_order = None
self.cfg_scale = float(config_dict["cfg_scale"])
self.width = int(config_dict["width"])
self.height = int(config_dict["height"])
self.output_type = config_dict["output_type"]
self.aesthetic_score = float(config_dict["aesthetic_score"])
self.negative_aesthetic_score = float(config_dict["negative_aesthetic_score"])
self.save_latent_simple = config_dict["save_latent_simple"]
if self.save_latent_simple == "True":
self.save_latent_simple = True
print("use callback save_latent_simple")
else:
self.save_latent_simple = False
self.save_latent_overstep = config_dict["save_latent_overstep"]
if self.save_latent_overstep == "True":
self.save_latent_overstep = True
print("use callback save_latent_overstep")
else:
self.save_latent_overstep = False
self.save_latent_approximation = config_dict["save_latent_approximation"]
if self.save_latent_approximation == "True":
self.save_latent_approximation = True
print("use callback save_latent_approximation")
else:
self.save_latent_approximation = False
self.save_predict_skip_x0 = config_dict["save_predict_skip_x0"]
if self.save_predict_skip_x0 == "True":
self.save_predict_skip_x0 = True
print("use callback save_predict_skip_x0")
else:
self.save_predict_skip_x0 = False
self.use_callback = False
if self.save_latent_simple or self.save_latent_overstep or self.save_latent_approximation or self.save_predict_skip_x0:
self.use_callback = True
if self.save_predict_skip_x0:
if self.save_latent_simple or self.save_latent_overstep:
raise ValueError("save_predict_skip_x0 and (save_latent_simple or save_latent_overstep) cannot be set at the same time")
if self.use_dpm_solver:
raise ValueError("save_predict_skip_x0 and use_dpm_solver cannot be set at the same time")
else:
if self.save_latent_simple and self.save_latent_overstep:
raise ValueError("save_latent_simple and save_latent_overstep cannot be set at the same time")
self.base , self.refiner = self.preprepare_model()
def preprepare_model(self):
controlnet = ControlNetModel.from_pretrained(
self.controlnet_path,
use_safetensors=True,
torch_dtype=torch.float16)
if self.VAE_FLAG:
vae = AutoencoderKL.from_pretrained(
self.vae_model_path,
torch_dtype=torch.float16)
base = StableDiffusionXLControlNetPipeline.from_pretrained(
self.base_model_path,
controlnet=controlnet,
vae=vae,
torch_dtype=torch.float16,
variant="fp16",
use_safetensors=True
)
base.to(self.device)
if self.REFINER_FLAG:
refiner = DiffusionPipeline.from_pretrained(
self.refiner_model_path,
text_encoder_2=base.text_encoder_2,
vae=vae,
requires_aesthetics_score=True,
torch_dtype=torch.float16,
variant="fp16",
use_safetensors=True
)
refiner.enable_model_cpu_offload()
else:
refiner = None
else:
base = StableDiffusionXLControlNetPipeline.from_pretrained(
self.base_model_path,
controlnet=controlnet,
torch_dtype=torch.float16,
variant="fp16",
use_safetensors=True
)
base.to(self.device, torch.float16)
if self.REFINER_FLAG:
refiner = DiffusionPipeline.from_pretrained(
self.refiner_model_path,
text_encoder_2=base.text_encoder_2,
requires_aesthetics_score=True,
torch_dtype=torch.float16,
variant="fp16",
use_safetensors=True
)
refiner.enable_model_cpu_offload()
else:
refiner = None
if self.LORA_FLAG:
base.load_lora_weights(self.lora_weight_path)
if self.use_dpm_solver:
if self.solver_order is not None:
base.scheduler = DPMSolverMultistepScheduler.from_config(
base.scheduler.config,
use_karras_sigmas=self.use_karras_sigmas,
Algorithm_type =self.scheduler_algorithm_type,
solver_order=self.solver_order,
)
else:
base.scheduler = DPMSolverMultistepScheduler.from_config(
base.scheduler.config,
use_karras_sigmas=self.use_karras_sigmas,
Algorithm_type =self.scheduler_algorithm_type,
)
else:
base.scheduler = EulerAncestralDiscreteScheduler.from_config(base.scheduler.config)
return base, refiner
def prepare_referimage(self,input_refer_image_path,output_refer_image_path, low_threshold = 100, high_threshold = 200):
mode = None
if self.control_mode is not None:
mode = self.control_mode
else:
raise ValueError("control_mode is not set")
def prepare_openpose(input_refer_image_path,output_refer_image_path, mode):
# 初期画像の準備
init_image = load_image(input_refer_image_path)
init_image = init_image.resize((self.width, self.height))
processor = Processor(mode)
processed_image = processor(init_image, to_pil=True)
processed_image.save(output_refer_image_path)
def prepare_canny(input_refer_image_path,output_refer_image_path, low_threshold = 100, high_threshold = 200):
init_image = load_image(input_refer_image_path)
init_image = init_image.resize((self.width, self.height))
# コントロールイメージを作成するメソッド
def make_canny_condition(image, low_threshold = 100, high_threshold = 200):
image = np.array(image)
image = cv2.Canny(image, low_threshold, high_threshold)
image = image[:, :, None]
image = np.concatenate([image, image, image], axis=2)
return Image.fromarray(image)
control_image = make_canny_condition(init_image, low_threshold, high_threshold)
control_image.save(output_refer_image_path)
def prepare_depthmap(input_refer_image_path,output_refer_image_path):
# 初期画像の準備
init_image = load_image(input_refer_image_path)
init_image = init_image.resize((self.width, self.height))
processor = Processor("depth_midas")
depth_image = processor(init_image, to_pil=True)
depth_image.save(output_refer_image_path)
def prepare_zoe_depthmap(input_refer_image_path,output_refer_image_path):
torch.hub.help(
"intel-isl/MiDaS",
"DPT_BEiT_L_384",
force_reload=True
)
model_zoe_n = torch.hub.load(
"isl-org/ZoeDepth",
"ZoeD_NK",
pretrained=True
).to("cuda")
init_image = load_image(input_refer_image_path)
init_image = init_image.resize((self.width, self.height))
depth_numpy = model_zoe_n.infer_pil(init_image) # return: numpy.ndarray
from zoedepth.utils.misc import colorize
colored = colorize(depth_numpy) # numpy.ndarray => numpy.ndarray
# gamma correction
img = colored / 255
img = np.power(img, 2.2)
img = (img * 255).astype(np.uint8)
Image.fromarray(img).save(output_refer_image_path)
if "openpose" in mode:
prepare_openpose(input_refer_image_path,output_refer_image_path, mode)
elif mode == "canny":
prepare_canny(input_refer_image_path,output_refer_image_path, low_threshold = low_threshold, high_threshold = high_threshold)
elif mode == "depth":
prepare_depthmap(input_refer_image_path,output_refer_image_path)
elif mode == "zoe_depth":
prepare_zoe_depthmap(input_refer_image_path,output_refer_image_path)
elif mode == "tile" or mode == "scribble":
init_image = load_image(input_refer_image_path)
init_image.save(output_refer_image_path)
else:
raise ValueError("control_mode is not set")
def generate_image(self, prompt, neg_prompt, image_path, seed = None, controlnet_conditioning_scale = 1.0):
def decode_tensors(pipe, step, timestep, callback_kwargs):
if self.save_latent_simple or self.save_predict_skip_x0:
callback_kwargs = decode_tensors_simple(pipe, step, timestep, callback_kwargs)
elif self.save_latent_overstep:
callback_kwargs = decode_tensors_residual(pipe, step, timestep, callback_kwargs)
else:
raise ValueError("self.save_predict_skip_x0 or save_latent_simple or save_latent_overstep must be set or 'save_latent_approximation = False'")
return callback_kwargs
def decode_tensors_simple(pipe, step, timestep, callback_kwargs):
latents = callback_kwargs["latents"]
skip_x0 = callback_kwargs["skip_x0"]
imege = None
prefix = None
if not self.save_predict_skip_x0:
prefix = "latents"
if self.save_latent_simple and not self.save_latent_approximation:
image = latents_to_rgb_vae(latents,pipe)
elif self.save_latent_approximation:
image = latents_to_rgb_approximation(latents,pipe)
else:
raise ValueError("save_latent_simple or save_latent_approximation is not set")
else:
prefix = "predicted_x0"
image = latents_to_rgb_vae(skip_x0,pipe)
gettime = time.time()
formatted_time_human_readable = time.strftime("%Y%m%d_%H%M%S", time.localtime(gettime))
image.save(f"./outputs/{prefix}_{formatted_time_human_readable}_{step}_{timestep}.png")
return callback_kwargs
def decode_tensors_residual(pipe, step, timestep, callback_kwargs):
latents = callback_kwargs["latents"]
if step > 0:
residual = latents - self.last_latents
goal = self.last_latents + residual * ((self.last_timestep) / (self.last_timestep - timestep))
#print( ((self.last_timestep) / (self.last_timestep - timestep)))
else:
goal = latents
if self.save_latent_overstep and not self.save_latent_approximation:
image = latents_to_rgb_vae(goal,pipe)
elif self.save_latent_approximation:
image = latents_to_rgb_approximation(goal,pipe)
else:
raise ValueError("save_latent_simple or save_latent_approximation is not set")
gettime = time.time()
formatted_time_human_readable = time.strftime("%Y%m%d_%H%M%S", time.localtime(gettime))
image.save(f"./outputs/latent_{formatted_time_human_readable}_{step}_{timestep}.png")
self.last_latents = latents
self.last_step = step
self.last_timestep = timestep
if timestep == 0:
self.last_latents = None
self.last_step = -1
self.last_timestep = 100
return callback_kwargs
def latents_to_rgb_vae(latents,pipe):
pipe.upcast_vae()
latents = latents.to(next(iter(pipe.vae.post_quant_conv.parameters())).dtype)
images = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]
images = pipe.image_processor.postprocess(images, output_type='pil')
pipe.vae.to(dtype=torch.float16)
return StableDiffusionXLPipelineOutput(images=images).images[0]
def latents_to_rgb_approximation(latents, pipe):
weights = (
(60, -60, 25, -70),
(60, -5, 15, -50),
(60, 10, -5, -35)
)
weights_tensor = torch.t(torch.tensor(weights, dtype=latents.dtype).to(latents.device))
biases_tensor = torch.tensor((150, 140, 130), dtype=latents.dtype).to(latents.device)
rgb_tensor = torch.einsum("...lxy,lr -> ...rxy", latents, weights_tensor) + biases_tensor.unsqueeze(-1).unsqueeze(-1)
image_array = rgb_tensor.clamp(0, 255)[0].byte().cpu().numpy()
image_array = image_array.transpose(1, 2, 0) # Change the order of dimensions
return Image.fromarray(image_array)
if seed is not None:
self.generator = torch.Generator(device=self.device).manual_seed(seed)
control_image = load_image(image_path)
image = None
if self.use_callback:
if self.LORA_FLAG:
if self.REFINER_FLAG:
image = self.base(
prompt=prompt,
negative_prompt=neg_prompt,
image=control_image,
cfg_scale=self.cfg_scale,
controlnet_conditioning_scale=controlnet_conditioning_scale,
num_inference_steps=self.n_steps,
denoising_end=self.high_noise_frac,
output_type="latent",
width = self.width,
height = self.height,
generator=self.generator,
cross_attention_kwargs={"scale": self.lora_scale},
callback_on_step_end=decode_tensors,
callback_on_step_end_tensor_inputs=["latents", "skip_x0"],
).images[0]
image = self.refiner(
prompt=prompt,
negative_prompt=neg_prompt,
cfg_scale=self.cfg_scale,
aesthetic_score = self.aesthetic_score,
negative_aesthetic_score = self.negative_aesthetic_score,
num_inference_steps=self.n_steps,
denoising_start=self.high_noise_frac,
callback_on_step_end=decode_tensors,
callback_on_step_end_tensor_inputs=["latents", "skip_x0"],
image=image[None, :]
).images[0]
#refiner を利用しない場合
else:
image = self.base(
prompt=prompt,
negative_prompt=neg_prompt,
image=control_image,
cfg_scale=self.cfg_scale,
controlnet_conditioning_scale=controlnet_conditioning_scale,
num_inference_steps=self.n_steps,
denoising_end=self.high_noise_frac,
output_type=self.output_type,
width = self.width,
height = self.height,
generator=self.generator,
callback_on_step_end=decode_tensors,
callback_on_step_end_tensor_inputs=["latents", "skip_x0"],
cross_attention_kwargs={"scale": self.lora_scale},
).images[0]
#LORAを利用しない場合
else:
if self.REFINER_FLAG:
image = self.base(
prompt=prompt,
negative_prompt=neg_prompt,
image=control_image,
cfg_scale=self.cfg_scale,
controlnet_conditioning_scale=controlnet_conditioning_scale,
num_inference_steps=self.n_steps,
denoising_end=self.high_noise_frac,
output_type="latent",
width = self.width,
height = self.height,
callback_on_step_end=decode_tensors,
callback_on_step_end_tensor_inputs=["latents", "skip_x0"],
generator=self.generator
).images[0]
image = self.refiner(
prompt=prompt,
negative_prompt=neg_prompt,
cfg_scale=self.cfg_scale,
aesthetic_score = self.aesthetic_score,
negative_aesthetic_score = self.negative_aesthetic_score,
num_inference_steps=self.n_steps,
denoising_start=self.high_noise_frac,
callback_on_step_end=decode_tensors,
callback_on_step_end_tensor_inputs=["latents", "skip_x0"],
image=image[None, :]
).images[0]
#refiner を利用しない場合
else:
image = self.base(
prompt=prompt,
negative_prompt=neg_prompt,
image=control_image,
cfg_scale=self.cfg_scale,
controlnet_conditioning_scale=controlnet_conditioning_scale,
num_inference_steps=self.n_steps,
denoising_end=self.high_noise_frac,
output_type=self.output_type,
width = self.width,
height = self.height,
callback_on_step_end=decode_tensors,
callback_on_step_end_tensor_inputs=["latents", "skip_x0"],
generator=self.generator
).images[0]
#latentを保存しない場合
else:
if self.LORA_FLAG:
if self.REFINER_FLAG:
image = self.base(
prompt=prompt,
negative_prompt=neg_prompt,
image=control_image,
cfg_scale=self.cfg_scale,
controlnet_conditioning_scale=controlnet_conditioning_scale,
num_inference_steps=self.n_steps,
denoising_end=self.high_noise_frac,
output_type="latent",
width = self.width,
height = self.height,
generator=self.generator,
cross_attention_kwargs={"scale": self.lora_scale},
).images[0]
image = self.refiner(
prompt=prompt,
negative_prompt=neg_prompt,
cfg_scale=self.cfg_scale,
aesthetic_score = self.aesthetic_score,
negative_aesthetic_score = self.negative_aesthetic_score,
num_inference_steps=self.n_steps,
denoising_start=self.high_noise_frac,
image=image[None, :]
).images[0]
# refiner を利用しない場合
else:
image = self.base(
prompt=prompt,
negative_prompt=neg_prompt,
image=control_image,
cfg_scale=self.cfg_scale,
controlnet_conditioning_scale=controlnet_conditioning_scale,
num_inference_steps=self.n_steps,
denoising_end=self.high_noise_frac,
output_type=self.output_type,
width = self.width,
height = self.height,
generator=self.generator,
cross_attention_kwargs={"scale": self.lora_scale},
).images[0]
# LORAを利用しない場合
else:
if self.REFINER_FLAG:
image = self.base(
prompt=prompt,
negative_prompt=neg_prompt,
image=control_image,
cfg_scale=self.cfg_scale,
controlnet_conditioning_scale=controlnet_conditioning_scale,
num_inference_steps=self.n_steps,
denoising_end=self.high_noise_frac,
output_type="latent",
width = self.width,
height = self.height,
generator=self.generator
).images[0]
image = self.refiner(
prompt=prompt,
negative_prompt=neg_prompt,
cfg_scale=self.cfg_scale,
aesthetic_score = self.aesthetic_score,
negative_aesthetic_score = self.negative_aesthetic_score,
num_inference_steps=self.n_steps,
denoising_start=self.high_noise_frac,
image=image[None, :]
).images[0]
# refiner を利用しない場合
else:
image = self.base(
prompt=prompt,
negative_prompt=neg_prompt,
image=control_image,
cfg_scale=self.cfg_scale,
controlnet_conditioning_scale=controlnet_conditioning_scale,
num_inference_steps=self.n_steps,
denoising_end=self.high_noise_frac,
output_type=self.output_type,
width = self.width,
height = self.height,
generator=self.generator
).images[0]
return image
基本的にはpart10の内容と同一ではあるが、下記の部分が異なるため説明します。
def generate_image(self, prompt, neg_prompt, image_path, seed = None, controlnet_conditioning_scale = 1.0):
・・・・・・・
image = self.base(
prompt=prompt,
negative_prompt=neg_prompt,
image=control_image,
cfg_scale=self.cfg_scale,
controlnet_conditioning_scale=controlnet_conditioning_scale,
num_inference_steps=self.n_steps,
denoising_end=self.high_noise_frac,
output_type=self.output_type,
width = self.width,
height = self.height,
generator=self.generator,
callback_on_step_end=decode_tensors,
callback_on_step_end_tensor_inputs=["latents", "skip_x0"],
cross_attention_kwargs={"scale": self.lora_scale},
).images[0]
・・・・・・・
画像を生成する際に、上記のような形で生成しますが、特に、callback_on_step_end_tensor_inputs=["latents", "skip_x0"],
において"skip_x0"
が追加されています。
これは、Samplerが各Stepごとに予測するクリーン画像をコールバック関数で受け取るための引数になります。しかしながら、通常のDiffusersモジュールでは、この潜在表現は受け取るとができないので、Forkして改造しています。
それが下記のリポジトリです。
変更箇所は下記を見ればわかります。わかりやすさ重視のため最低限の変更に抑えています。
まず変更点について解説し、その後なぜそこを変更したのかについて説明します。
まず一個目の変更点は下記です。
_callback_tensor_inputs = [
"latents",
"prompt_embeds",
"negative_prompt_embeds",
"add_text_embeds",
"add_time_ids",
"negative_pooled_prompt_embeds",
"negative_add_time_ids",
+ "skip_x0",
]
上記はcallback関数に渡すことが可能な変数のリストになります。ここに書かれていない変数をcallback関数で受け取ろうとすると、StableDiffusionXLControlNetPipeline
クラスの__call__
メソッド内で呼ばれるcheck_inputs
メソッドでエラーになります。
続いての変更点は下記です。
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+ latent_all = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=True)
+ latents = latent_all.prev_sample
+ skip_x0 = None
+ if hasattr(latent_all, 'pred_original_sample'):
+ skip_x0 = latent_all.pred_original_sample
self.scheduler.step
は本パイプラインにて利用しているSamplerクラスにて定義されているstep
メソッドを読んでいます。
step
メソッドはreturn_dict
引数の真偽により、出力が変わります。
元のDiffuserの通り、return_dict=False
とすると、各ステップごとのノイズ混じりの潜在表現のみを取得できます。
一方で、return_dict=True
とすると、各ステップごとのノイズ混じりの潜在表現に加えて、各ステップにてSamplerが予測したクリーン画像(DPMSolverMultistepScheduler
は取得できません。一方で、SamplerとしてEulerAncestralDiscreteScheduler
を利用している場合は、取得可能なので今回はこちらを利用します。
では、実際にstep
メソッドについて中身を見ていきます。
該当部分は下記です
if not return_dict:
return (prev_sample,)
return EulerAncestralDiscreteSchedulerOutput(
prev_sample=prev_sample, pred_original_sample=pred_original_sample
)
上記で解説した通り、引数のreturn_dictがTrueの場合は、returnとしてEulerAncestralDiscreteSchedulerOutput
が返されます。
EulerAncestralDiscreteSchedulerOutput
は二つのパラメータを持ちます。
deeplにて翻訳した文章が下記です。
prev_sample (torch.Tensor of shape (batch_size, num_channels, height, width) for images) - 直前のタイムステップで計算されたサンプル(x_{t-1}).
pred_original_sample (torch.Tensor of shape (batch_size, num_channels, height, width) for images) - 現在のタイムステップのモデル出力に基づく,ノイズ除去予測サンプル(x_{0}).
上記の通り、EulerAncestralDiscreteSchedulerOutput
クラスのpred_original_sample
属性を取得することができれば、クリーン画像を取得することができます。
そのために下記のようにDiffuserに変更を加えたわけです。
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+ latent_all = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=True)
+ latents = latent_all.prev_sample
+ skip_x0 = None
+ if hasattr(latent_all, 'pred_original_sample'):
+ skip_x0 = latent_all.pred_original_sample
DPMSolverMultistepScheduler
ちなみにですがDPMSolverMultistepScheduler
の公式のドキュメントを見ても分かるとおり、DPMSolverMultistepScheduler
の出力で利用されているSchedulerOutput
クラスにはprev_sample
の属性はあるが、pred_original_sample
の属性はないため、各ステップごとのクリーン画像を取得することができないため、設定でuse_dpm_solver = False
を用意しています。
逆に言えば、出力のクラスを確認して、pred_original_sample
属性が存在するSamplerであれば、EulerAncestralDiscreteScheduler
でなくても、同様に各ステップごとのクリーン画像を取得できます。
今回、EulerAncestralDiscreteScheduler
を利用したのは、使っているモデルanimagine-xl-3.1の「Recommended settings」の章にて、使用を推奨されていたSamplerだからです。
it’s recommended to use a lower classifier-free guidance (CFG Scale) of around 5-7, sampling steps below 30, and to use Euler Ancestral (Euler a) as a sampler.
続いて、潜在表現を画像化する部分に関して解説します。
まずは該当部分の全文を表示します。
潜在表現から画像を再構成するコールバック関数
def decode_tensors(pipe, step, timestep, callback_kwargs):
if self.save_latent_simple or self.save_predict_skip_x0:
callback_kwargs = decode_tensors_simple(pipe, step, timestep, callback_kwargs)
elif self.save_latent_overstep:
callback_kwargs = decode_tensors_residual(pipe, step, timestep, callback_kwargs)
else:
raise ValueError("self.save_predict_skip_x0 or save_latent_simple or save_latent_overstep must be set or 'save_latent_approximation = False'")
return callback_kwargs
def decode_tensors_simple(pipe, step, timestep, callback_kwargs):
latents = callback_kwargs["latents"]
skip_x0 = callback_kwargs["skip_x0"]
imege = None
prefix = None
if not self.save_predict_skip_x0:
prefix = "latents"
if self.save_latent_simple and not self.save_latent_approximation:
image = latents_to_rgb_vae(latents,pipe)
elif self.save_latent_approximation:
image = latents_to_rgb_approximation(latents,pipe)
else:
raise ValueError("save_latent_simple or save_latent_approximation is not set")
else:
prefix = "predicted_x0"
image = latents_to_rgb_vae(skip_x0,pipe)
gettime = time.time()
formatted_time_human_readable = time.strftime("%Y%m%d_%H%M%S", time.localtime(gettime))
image.save(f"./outputs/{prefix}_{formatted_time_human_readable}_{step}_{timestep}.png")
return callback_kwargs
def decode_tensors_residual(pipe, step, timestep, callback_kwargs):
latents = callback_kwargs["latents"]
if step > 0:
residual = latents - self.last_latents
goal = self.last_latents + residual * ((self.last_timestep) / (self.last_timestep - timestep))
#print( ((self.last_timestep) / (self.last_timestep - timestep)))
else:
goal = latents
if self.save_latent_overstep and not self.save_latent_approximation:
image = latents_to_rgb_vae(goal,pipe)
elif self.save_latent_approximation:
image = latents_to_rgb_approximation(goal,pipe)
else:
raise ValueError("save_latent_simple or save_latent_approximation is not set")
gettime = time.time()
formatted_time_human_readable = time.strftime("%Y%m%d_%H%M%S", time.localtime(gettime))
image.save(f"./outputs/latent_{formatted_time_human_readable}_{step}_{timestep}.png")
self.last_latents = latents
self.last_step = step
self.last_timestep = timestep
if timestep == 0:
self.last_latents = None
self.last_step = -1
self.last_timestep = 100
return callback_kwargs
def latents_to_rgb_vae(latents,pipe):
pipe.upcast_vae()
latents = latents.to(next(iter(pipe.vae.post_quant_conv.parameters())).dtype)
images = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]
images = pipe.image_processor.postprocess(images, output_type='pil')
pipe.vae.to(dtype=torch.float16)
return StableDiffusionXLPipelineOutput(images=images).images[0]
def latents_to_rgb_approximation(latents, pipe):
weights = (
(60, -60, 25, -70),
(60, -5, 15, -50),
(60, 10, -5, -35)
)
weights_tensor = torch.t(torch.tensor(weights, dtype=latents.dtype).to(latents.device))
biases_tensor = torch.tensor((150, 140, 130), dtype=latents.dtype).to(latents.device)
rgb_tensor = torch.einsum("...lxy,lr -> ...rxy", latents, weights_tensor) + biases_tensor.unsqueeze(-1).unsqueeze(-1)
image_array = rgb_tensor.clamp(0, 255)[0].byte().cpu().numpy()
image_array = image_array.transpose(1, 2, 0) # Change the order of dimensions
return Image.fromarray(image_array)
一つずつ説明します。
まず、各ステップごとにコールバック関数として呼ばれる関数は、decode_tensors
関数になります。この関数は、パイプラインの__call__
メソッドの引数のcallback_on_step_end=decode_tensors,
として指定しています。
decode_tensors
関数は下記のような関数です。
def decode_tensors(pipe, step, timestep, callback_kwargs):
if self.save_latent_simple or self.save_predict_skip_x0:
callback_kwargs = decode_tensors_simple(pipe, step, timestep, callback_kwargs)
elif self.save_latent_overstep:
callback_kwargs = decode_tensors_residual(pipe, step, timestep, callback_kwargs)
else:
raise ValueError("self.save_predict_skip_x0 or save_latent_simple or save_latent_overstep must be set or 'save_latent_approximation = False'")
return callback_kwargs
コールバック関数は(pipe, step, timestep, callback_kwargs)
を引数とし、returnとしてcallback_kwargs
を返す必要がある。このcallback_kwargs
はコールバック関数の引数と同じ形である必要がある。具体的には辞書として、下記の名前がキーとなっているものが必要です。
_callback_tensor_inputs = [
"latents",
"prompt_embeds",
"negative_prompt_embeds",
"add_text_embeds",
"add_time_ids",
"negative_pooled_prompt_embeds",
"negative_add_time_ids",
"skip_x0",
]
さらに例えば、callback_kwargs["latents"] = 2 * latents
のように、値を変更して、returnすることで、元の画像生成の処理で利用しているlatentsの値も操作することができます。
証拠
callback関数に関連する箇所はここ
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
negative_pooled_prompt_embeds = callback_outputs.pop(
"negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
)
add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids)
実際にコールバック関数が動くのは
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
であり、その出力をpopする形で、latents
やprompt_embeds
などを各ステップごとに変更することができるように設計されているため、コールバック関数を駆使することで色々なことができそうですね。
また、_callback_tensor_inputs
に追加することで、使われているどんな変数でも取得することができます。なぜなら下記で、__call__
メソッドのスコープ内の変数を全て取得できるようになっているからです。
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
では、元のコールバック関数の説明に戻ります。
コールバック関数では、設定ファイルの設定に応じて、二つの関数のうちどちらかが起動するようになっています。
if self.save_latent_simple or self.save_predict_skip_x0:
callback_kwargs = decode_tensors_simple(pipe, step, timestep, callback_kwargs)
elif self.save_latent_overstep:
callback_kwargs = decode_tensors_residual(pipe, step, timestep, callback_kwargs)
すなわち、save_latent_simple
かsave_predict_skip_x0
がTrueの場合はdecode_tensors_simple
関数が起動し、save_latent_overstep
がTrueの場合はdecode_tensors_residual
が起動します。
decode_tensors_simple
関数は下記のように定義されます。
def decode_tensors_simple(pipe, step, timestep, callback_kwargs):
latents = callback_kwargs["latents"]
skip_x0 = callback_kwargs["skip_x0"]
imege = None
prefix = None
if not self.save_predict_skip_x0:
prefix = "latents"
if self.save_latent_simple and not self.save_latent_approximation:
image = latents_to_rgb_vae(latents,pipe)
elif self.save_latent_approximation:
image = latents_to_rgb_approximation(latents,pipe)
else:
raise ValueError("save_latent_simple or save_latent_approximation is not set")
else:
prefix = "predicted_x0"
image = latents_to_rgb_vae(skip_x0,pipe)
gettime = time.time()
formatted_time_human_readable = time.strftime("%Y%m%d_%H%M%S", time.localtime(gettime))
image.save(f"./outputs/{prefix}_{formatted_time_human_readable}_{step}_{timestep}.png")
return callback_kwargs
この中でも同様に設定ファイルの設定に応じて、使う関数を変更させて潜在変数の処理を行い、その上で、生成された画像を「outputs」フォルダに保存しています。
重要なのは、
image = latents_to_rgb_vae(latents,pipe)
と
image = latents_to_rgb_approximation(latents,pipe)
と
image = latents_to_rgb_vae(skip_x0,pipe)
です。
latents_to_rgb_vae(latents,pipe)
は各ステップのノイズ混じりの潜在表現を受けとり、VAEで画像を再構成しています。
latents_to_rgb_vae(skip_x0,pipe)
も入力が、各ステップでの予測されたクリーン画像になっているだけで本質は同じです。
latents_to_rgb_vae
関数は下記のように定義されており、基本的にはDiffuserモジュールの書き方を踏襲して作っています。
def latents_to_rgb_vae(latents,pipe):
pipe.upcast_vae()
latents = latents.to(next(iter(pipe.vae.post_quant_conv.parameters())).dtype)
images = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]
images = pipe.image_processor.postprocess(images, output_type='pil')
pipe.vae.to(dtype=torch.float16)
return StableDiffusionXLPipelineOutput(images=images).images[0]
違うのはpipe.vae.to(dtype=torch.float16)
の部分です。
VAEに潜在表現を通す前にupcastしてfloat16からfloat32に型変換をする必要があります。
VAEはコールバック関数で利用するものも、パイプラインの__call__
メソッド内で利用するものも同じものを利用するため、コールバック関数内で変更した型は元に戻す必要があるので、加えています。
ここで戻さないとlatentsの型とVAEの型が合わないため、最終ステップでの画像を生成する際にエラーが発生します。
エラーの理由
発生するエラーは下記になります。
RuntimeError: Input type (c10::Half) and bias type (float) should be the same
VAEの型がfloat32になっていると下記の部分でneeds_upcasting
がFalseになってしまいます。
if not output_type == "latent":
# make sure the VAE is in float32 mode, as it overflows in float16
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
その場合、次の部分の処理が行われないため、VAEはfloat32なのに、latentsはfloat16のまま処理を行うことになります。
if needs_upcasting:
self.upcast_vae()
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
結果、型が合わないエラーが発生します。
続いて、latents_to_rgb_approximation(latents,pipe)
です。
こちらも各ステップのノイズ混じり潜在表現から画像を再構成する関数ですが、再構成の方法が異なります。
これまでの方法はVAEのDecoderに通すことで再構成していますが、今回は潜在表現に対して線形処理を行うことで、画像を線形近似して表示しています。
わかっているのはこの処理のように線形近似すると、潜在表現から画像っぽいものが再構成されるということです。この調査をされた方は本当にすごいですね・・・
以上で、一旦decode_tensors_simple
関数の説明は終わります。
次に。decode_tensors_residual
関数についてです。
def decode_tensors_residual(pipe, step, timestep, callback_kwargs):
latents = callback_kwargs["latents"]
if step > 0:
residual = latents - self.last_latents
goal = self.last_latents + residual * ((self.last_timestep) / (self.last_timestep - timestep))
#print( ((self.last_timestep) / (self.last_timestep - timestep)))
else:
goal = latents
if self.save_latent_overstep and not self.save_latent_approximation:
image = latents_to_rgb_vae(goal,pipe)
elif self.save_latent_approximation:
image = latents_to_rgb_approximation(goal,pipe)
else:
raise ValueError("save_latent_simple or save_latent_approximation is not set")
gettime = time.time()
formatted_time_human_readable = time.strftime("%Y%m%d_%H%M%S", time.localtime(gettime))
image.save(f"./outputs/latent_{formatted_time_human_readable}_{step}_{timestep}.png")
self.last_latents = latents
self.last_step = step
self.last_timestep = timestep
if timestep == 0:
self.last_latents = None
self.last_step = -1
self.last_timestep = 100
return callback_kwargs
こちらに関しても基本的にはdecode_tensors_simple
関数と同じですが、下記部分だけ異なります。
if step > 0:
residual = latents - self.last_latents
goal = self.last_latents + residual * ((self.last_timestep) / (self.last_timestep - timestep))
else:
goal = latents
ここで実施したいことは、今回の潜在表現と前回の潜在表現の差分residual
には、今回のstepにて得られた入力に対する勾配によって更新が行われた量が格納されることになります。
その変化量をより大きくして更新することで、早い段階から生成画像を確認できるのではないかと考えた次第です。
どの程度変化量を大きくするかというのは下記の式です
((self.last_timestep) / (self.last_timestep - timestep))
これは今回の更新で進んだtimestep数 (self.last_timestep - timestep)
での更新量を残りのtimestep数(self.last_timestep)
(最大1000)倍しています。
以上がmodule/module_sdc.py
におけるpart10の記事との違いになります。
実験結果
ここからは、上記のコードによってGoogle Colabでパラメータを変更して、様々な実験を実施したため、その詳細を記載します。
前提条件
前提として下記の設定を継承します。後述する実験において特に記載がない場合は、この設定が継承されていると考えてください。
5セル目
config_text = """
[SDXLC]
device = auto
n_steps=28
high_noise_frac=None
seed=42
vae_model_path = None
base_model_path = Asahina2K/Animagine-xl-3.1-diffuser-variant-fp16
refiner_model_path = None
controlnet_path = diffusers/controlnet-depth-sdxl-1.0
control_mode = depth
lora_weight_path = ./inputs/DreamyvibesartstyleSDXL.safetensors
lora_scale = 1.0
use_dpm_solver = False
use_karras_sigmas = True
scheduler_algorithm_type = dpmsolver++
solver_order = 2
cfg_scale = 7.0
width = 832
height = 1216
output_type = pil
aesthetic_score = 6
negative_aesthetic_score = 2.5
save_latent_simple = False
save_latent_overstep = False
save_latent_approximation = False
save_predict_skip_x0 = False
"""
with open("configs/config.ini", "w", encoding="utf-8") as f:
f.write(config_text)
上記の設定の通り、LoRAとControlNetのDepthを利用します
6セル目
main_prompt = """
1 girl ,Yellowish-white hair ,short hair ,red small ribbon,red eyes,red hat ,school uniform ,solo ,smile ,upper body ,Anime ,Japanese,best quality,high quality,ultra highres,ultra quality
"""
use_lora = True
if use_lora:
main_prompt += ", Dreamyvibes Artstyle"
negative_prompt="""
nsfw, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]
"""
input_refer_image_path = "./inputs/refer.webp"
output_refer_image_path = "./inputs/refer.png"
8セル目
controlnet_conditioning_scale = 0.7
ControNetに入力する画像
参照画像の変換前
(使った画像がバリバリ版権画像だったため、depth画像だけで失礼します。元の画像はANIMAGINE XL 3.1の公式が提供しているチュートリアルで使われているプロンプトで作成した画像の一つです。(おそらくseed42-47のあたり)
参照画像の変換後
ちなみに、上記の深度マップを"./inputs/refer.png"
として保存して、7セル目のsd.prepare_referimage
メソッドを実行せずにコメントアウトすることでも、同様の実験を行うことが可能です。
sd = SDXLC()
#sd.prepare_referimage(input_refer_image_path = input_refer_image_path, output_refer_image_path = output_refer_image_path, low_threshold = 100, high_threshold = 200)
実験1
「生成途中の潜在表現からVAEで再構成した画像の表示」を行います。
設定
前提から、下記部分だけ変更する
save_latent_simple = True
結果
大量の画像が保存されるので、途中画像はgifに変換して表示します。
(3MBの制限に抑えるために、画像はかなり劣化しています。ごめんなさい)
実験2
「生成途中の潜在表現を直接線形近似した画像を表示」を行います。
設定
前提から、下記部分だけ変更する
save_latent_simple = True
save_latent_approximation = True
結果
大量の画像が保存されるので、途中画像はgifに変換して表示します。
SDXLの場合、潜在表現は通常の画像サイズの1/8のサイズになります。VAEを通さずに線形近似をしているだけなので、画像サイズは小さくなります。
実験3
「潜在表現に対する各Stepの更新幅を変更して、VAEで再構成した画像の表示」を行います。
設定
前提から、下記部分だけ変更する
use_dpm_solver = True
save_latent_overstep = True
今回の実験では、なぜかEulerAncestralDiscreteScheduler
のSamplerではうまく機能しなかったので。DPMSolverMultistepScheduler
を利用しました。
結果
大量の画像が保存されるので、途中画像はgifに変換して表示します。
(3MBの制限に抑えるために、画像はかなり劣化しています。ごめんなさい)
実験4
「Samplerが生成途中に推定したクリーン画像の表示」を行います。
設定
前提から、下記部分だけ変更する
save_predict_skip_x0 = True
結果
大量の画像が保存されるので、途中画像はgifに変換して表示します。
(3MBの制限に抑えるために、画像はかなり劣化しています。ごめんなさい)
まとめ
以上、ここまででDiffersを利用して、生成途中の潜在表現から画像を再構成してみました。
みたかぎり、実験4の可視化が一番人間にはみやすい可視化だったかなと思います。
実験4を見ると、最初の大まかな画像が決まっていき、細かい箇所は後から少しずつ決まっていくような動きをすることがわかりました。
また、ControlNetに関連する範囲に関しては早いstepから画像が確定しており、背景はそれに合わせて後から生成されるような動きをしていることがわかりました。
一方で、実験4の可視化手法はDPMSolverMultistepScheduler
などの一部のSamplerでは利用することができません。(利用する方法はあるかもですが、現時点の私の理解力では難しいので、詳しいかたいらっしゃれば教えていただきたいです)
その場合は、実験3の可視化手法も試してみていただけますと嬉しいです。実験4には劣りますが、途中の画像の推移などが見える形になっているかなと思います。
Samplerに応じて、使う可視化手法を変更するのが良さそうかなと思いました。
以上で、終わりです。
ここまで読んでくださり、ありがとうございました。
Discussion