🌍

FLUX.1-devでControlNetを利用したら画像の質が下がったのでImg2Imgを使って改善してみた【Diffusers】

2024/09/13に公開

はじめに

FLUX.1とは、画像生成AIであるかの有名な「Stable Diffusion」の開発に携わったAI研究者が立ちあげた新しいAI開発企業である「Black Forest Labs」により生み出された高性能な画像生成AIモデルです。

先日、Google Colabにて「FLUX.1-dev + ControlNet + LoRA」をDiffusersライブラリで動作させる記事を記載しました。(私は宗教上の理由からWebUIは使えないので、diffusersライブラリにこだわっています)

https://zenn.dev/asap/articles/e4c199cecf1836

しかしながら、上記の記事の生成画像例からわかるように、ControlNetを利用しない形で生成した画像とControlNetを利用して生成した画像を比較すると、生成される画像のイメージがかなり異なることがわかると思います。

今回は、ControlNetを利用して生成された画像を、再びFLUX.1-devのImg2Imgパイプラインに入力して、構図を維持しながらより質の高い画像にする実験をしてみます。

生成画像例

ControlNetのみの画像

こう見ると、すごい変な画像というわけではないですが、例えば、真ん中の画像とかは、リボンの部分などがボロボロになっていたり、手の指の本数が6本になっていたりします。

3枚目の画像をImg2Img Pipelineに入力し高品質化した画像

上記の画像は、ControlNetのみの画像のうち、3枚目(真ん中)の画像を参照画像として、Img2Img Pipelineにて画像を生成しています。
画像を見る限り、入力画像の構図を維持しながら、より高品質な画像に再構成できていることがわかります。

(参考)ControlNetを使わないで生成した画像

こう見ると、ControlNetを利用すると、画像の質が下がるという意味がわかると思います。
ただし、構図などはバラバラに生成されるので、所望の画像を生成するには、大量の画像を生成して、最も良いものを探したり、プロンプトを工夫する必要があります。

ControlNetを利用することで、簡単に構図を固定できることは、大きなメリットになるため、ControlNetで構図を指定した上で、質の高い画像を生成することを目標にします。

成果物

下記のリポジトリをご覧ください
前回の記事のリポジトリにImage to Imageの機能を追加したものになります)

https://github.com/personabb/colab_AI_sample/tree/main/colab_fluxdev_sample2

事前準備

Hugging Faceのlogin tokenの取得と登録

前回の記事に記載したので隠します

Hugging Faceのlogin tokenの取得と登録

SD3のモデルをローカルで利用可能にするために、Huggingfaceからログイン用のtokenを取得する必要があります。
ログインtokenの取得方法は下記の記事を参考にしてください。
https://zenn.dev/protoout/articles/73-hugging-face-setup

また、取得したログインtokenをGoogle Colabに登録する必要があります。
下記の記事を参考に登録してください。
https://note.com/npaka/n/n79bb63e17685

HF_LOGINという名前で登録してください。

FLUX.1-dev重みへのアクセス準備

前回の記事に記載したので隠します

FLUX.1-dev重みへのアクセス準備

続いて、HuggingfaceでFLUX.1-devの重みにアクセスできるようにします。

https://huggingface.co/black-forest-labs/FLUX.1-dev
上記のURLにアクセスしてください。

初めてアクセスした場合、上記のような画面になると思うので、作成したアカウントでログインしてください。

その後、入力フォームが表示されるかと思います。
そのフォームをすべて埋めて、提出することで、モデル重みにアクセスできるようになります。

リポジトリのクローン

まずは上記のリポジトリをcloneしてください。

./
git clone https://github.com/personabb/colab_AI_sample.git

その後、cloneしたフォルダの中である「colab_AI_sample/colab_fluxdev_sample2」をマイドライブの適当な場所においてください。

Img2Img Pipeline用の参照画像の準備

今回は、前回の記事で生成した画像を流用します。下記の画像になります。

上記の画像をinputs/imagetoimage/inputs1.pngとして格納します

想定されるディレクトリ構造

Google Driveのディレクトリ構造は下記を想定します。

MyDrive/
    └ colab_AI_sample/
          └ colab_fluxdev_sample2/
                  ├ configs/
                  |    └ config.ini
                  ├ inputs/
                  |    ├ imagetoimage/
                  |    |    └ inputs1.png
                  |    ├ controlnet/
                  |    |    └ refer/ (Text2ImageのControlNetで利用。今回不要)
                  |    |    └ refer_prepared/ (Text2ImageのControlNetで利用。今回不要)
                  ├ outputs/
                  ├ module/
                  |    └ module_flux.py
                  └ FluxDev_sample2.ipynb

使い方

実行方法

FluxDev_sample2.ipynbをGoogle Colabratoryアプリで開いて、下記で記載する「パラメータの設定」や「プロンプトの変更」を適切に実施した後、一番上のセルから順番に一番下まで実行すると、画像が5枚「outputs」フォルダに生成されます。

また、最後まで実行後、パラメータを変更して再度実行する場合は、「ランタイム」→「セッションを再起動して全て実行する」をクリックしてください。

パラメータの変更

FluxDev_sample2.ipynbの5セル目と8セル目が該当します。

パラメータの設定 5セル目
#モデルの設定を行う。

config_text = """
[FLUX]
device = auto
n_steps=24
seed=42

;model_mode = t2i
model_mode = i2i

strength = 0.6

;from_single_file = False
;base_model_path = black-forest-labs/FLUX.1-schnell
base_model_path = black-forest-labs/FLUX.1-dev

use_controlnet = False
controlnet_path = Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro

;control_mode0 = canny
;control_mode1 = tile
control_mode2 = depth
;control_mode3 = blur
;control_mode4 = pose
;control_mode5 = gray
;control_mode6 = lq

;lora_weight_repo0 = alfredplpl/flux.1-dev-modern-anime-lora
;lora_weight_path0 = modern-anime-lora-2.safetensors
;lora_weight_repo1 = aleksa-codes/flux-ghibsky-illustration
;lora_weight_path1 = lora.safetensors
;lora_scale0 = 1.0
;lora_scale1 = 1.0
;trigger_word0 = "modern anime style"
;trigger_word1 = "GHIBSKY style painting"
;lora_weight_repo2 = None
;lora_weight_path2 = ./inputs/QTCanimation_lora_v1-PAseer.safetensors
;lora_scale2 = 1.0
;trigger_word2 = ""

;select_solver = LCM
;select_solver = DPM
;select_solver = Eulera
select_solver = FMEuler

use_karras_sigmas = True
scheduler_algorithm_type = dpmsolver++
solver_order = 2

cfg_scale = 3.5
;cfg_scale = 1.0
width = 832
height = 1216
output_type = pil
aesthetic_score = 6
negative_aesthetic_score = 2.5

save_latent_simple = False
save_latent_overstep = False


"""

with open("configs/config.ini", "w", encoding="utf-8") as f:
  f.write(config_text)


  • n_steps
    • 20step以上から綺麗な画像が生成される。公式では50stepを推奨しているが時間がかかる
  • 【新規追加】model_mode
    • t2ii2iを設定します
    • t2iはプロンプトから画像を生成するパイプラインです
    • i2iは画像とプロンプトから画像を生成するパイプラインです。
  • 【新規追加】strength
    • model_mode = i2iを選択した際に、どの程度元の画像にノイズを付与するかを設定する
    • 1の場合はt2iと一致します
  • base_model_path
    • 利用するモデルの指定。FLUX.1-schnellなども指定できる。
    • 今後、高性能はFineTuningモデルが登場したら、そちらを利用することもできる
      • そのモデルが単一のsafetensorモデルの場合はfrom_single_file = Trueを設定する
  • use_controlnet
    • ControlNetを利用するかどうかのFlag
  • controlnet_path
    • 利用するControlNetのモデルパス
  • control_mode0からcontrol_mode6
    • 利用するControlNetのモード。複数記載すれば多重に適用可能
  • lora_weight_repo0からlora_weight_repo9
    • 利用するLoRAのリポジトリ指定。複数記載すれば多重に適用可能
    • HuggingFaceにあるLoRA利用する場合は、ここにリポジトリ名も指定する
  • lora_weight_path0からlora_weight_path9
    • 利用するLoRAのパスを指定する。複数記載すれば多重に適用可能
    • HuggingFaceにあるLoRA利用する場合は、ここにはLoRAで利用するファイル名のみを指定する
    • CivitaiなどのLoRAファイルをダウンロードして利用するファイルは、ファイルを保存した場所の相対パスを記述する('lora_weight_repo`はNoneに設定する)
  • lora_scale0からlora_scale9
    • 利用するLoRAの適用度合い
  • trigger_word0からtrigger_word9
    • 利用するLoRAで推奨されるTrigger Wordの設定
  • select_solver
    • 利用するサンプラーの指定。FLUX.1シリーズの場合はFMEulerを指定する
  • widthheight
    • 生成する画像のサイズ
  • save_latent_simplesave_latent_overstep
    • 生成途中も出力するかどうか、出力する場合どのような形式で出力するかのフラグ
    • save_latent_simpleはノイズから生成される過程を出力
    • save_latent_overstepはある程度綺麗な画像が遷移するような過程を出力

8セル目は下記です。

controlnet_conditioning_scale = [0.5]
temp_strength = None

for i in range(5):
      start = time.time()
      image = flux.generate_image(main_prompt,c_image_path = output_refer_image_folder, i_image_path = i2i_image_path, controlnet_conditioning_scale = controlnet_conditioning_scale, temp_strength = temp_strength)
      print("generate image time: ", time.time()-start)
      image.save("./outputs/FLUX_result_{}.png".format(i))

こちらのcontrolnet_conditioning_scale = [0.5]temp_strength = Noneの部分が変更可能なパラメータです。
controlnet_conditioning_scaleは、ControlNetを利用する場合に、どの程度反映させるかを設定するパラメータです。公式的には0.3-0.8くらいが推奨になっています。

temp_strengthは5セル目のstrengthと同じです。
temp_strengthが設定されている場合(None以外の実数)はそちらが優先されます。

プロンプトの変更

FluxDev_sample2.ipynbの6セル目が該当します。

#読み上げるプロンプトを設定する。


main_prompt = """
modern anime style, a cute Japanese idle singer girl, long Yellowish-white hair with red small ribbon,red eyes,small red hat ,She sings with a smile on her face ,make a peace sign with her left hand, wearing white idle costume, standing on the stage ,This is a flashy concert with 3D holograms and laser effects
"""

#コントロールネットに入力する参照画像は0から6までの数字を想定(0.png,6.webpなど。数字はcontrolnetのモードと一致させてください)
#参照)コントロールネットのモード canny (0), tile (1), depth (2), blur (3), pose (4), gray (5), low quality (6).
input_refer_image_folder = "./inputs/controlnet/refer"
#コントロールネットに入力する変換後の参照画像の格納フォルダ。命名規則は上記と同様
output_refer_image_folder = "./inputs/controlnet/refer_prepared"

#image to imageモードを利用する際の画像path
i2i_image_path = "./inputs/imagetoimage/inputs1.png"

6セル目のmain_promptを変更して、8セル目を実行することで、変更されたプロンプトが反映されて実行されます。

実行結果

ここからは、Google Colabでさまざまパラメータを変更した実験結果を記載します・
プロンプトと参照画像を固定して、strengthを0.6-0.9まで変更した結果を確認します。

設定

今後、明記しない限り、下記の設定を利用する

参照画像

i2i_image_path = "./inputs/imagetoimage/inputs1.png"で指定している画像です。

よく見ると、リボンの部分とかがボロボロになっていることがわかります。
さらに、手の指の本数も6本になっています。
この辺りがどう変化するかが見ものです。

プロンプト

modern anime style, a cute Japanese idle singer girl, long Yellowish-white hair with red small ribbon,red eyes,small red hat ,She sings with a smile on her face ,make a peace sign with her left hand, wearing white idle costume, standing on the stage ,This is a flashy concert with 3D holograms and laser effects

反映させたい構図をプロンプトに盛り込んでいます。
make a peace sign with her left handの部分)
こうすることで、生成された画像がより参照画像の構図に近くなります。

設定

6セル目は下記のように指定します。

#モデルの設定を行う。

config_text = """
[FLUX]
device = auto
n_steps=24
seed=42

;model_mode = t2i
model_mode = i2i

strength = 0.6

base_model_path = black-forest-labs/FLUX.1-dev

use_controlnet = False

select_solver = FMEuler

use_karras_sigmas = True
scheduler_algorithm_type = dpmsolver++
solver_order = 2

cfg_scale = 3.5
width = 832
height = 1216
output_type = pil
aesthetic_score = 6
negative_aesthetic_score = 2.5

save_latent_simple = False
save_latent_overstep = False


"""

with open("configs/config.ini", "w", encoding="utf-8") as f:
  f.write(config_text)

実験1 strength:0.6

8セル目で
temp_strength=0.6として実行します

生成結果

生成された画像は下記です。(一枚にまとめています)

0.6時点でかなり元の画像に忠実ながらも、画像の質が向上していることがわかります。
例えば、崩壊していたリボンは綺麗になっています。
一方で、手の指の本数は大多数の画像で、残念ながら6本のままです

今後、strengthを上げていくことで、元の画像の忠実度を下げていきながら、構図の忠実性と画像の質のバランスを見ていきます。

実験1 strength:0.7

8セル目で
temp_strength=0.7として実行します

生成結果

生成された画像は下記です。(一枚にまとめています)

0.7でも構図に忠実に画像が生成されています。
しかしながら、実験1と比較して画像の質が上がっています。

例えば、
崩壊しているリボンは綺麗に生成されており、また、半分くらいの画像は手の指の本数も5本に修正されています。

実験1 strength:0.8

8セル目で
temp_strength=0.8として実行します

生成結果

生成された画像は下記です。(一枚にまとめています)

多少構図が崩れてきました(ピースの位置がバラバラ)が、まだ大まかな構図は一致しています。
画像の質はどんどん上がっていきます。
4枚の画像で指の本数も5本に修正されています。

実験1 strength:0.9

8セル目で
temp_strength=0.9として実行します

生成結果

生成された画像は下記です。(一枚にまとめています)

ここまでくると、元の参照画像の構図はほぼなく、プロンプトに従って画像が生成されていることがわかります。
画質も通常のFLUX.1-devと同様、非常に高い質になっています。

まとめ

今回は、controlNetを利用して、質が下がってしまった画像に対して、Img2Img Pipelineを適用して画像の質を改善させてみました。

画像の構図を完全に一致させたい場合はstrength=0.7程度のパラメータを指定することで、構図を一致させながら、画像の質が向上できることがわかりました。

今回の記事は前回の記事を見ていただいていることが前提なので、前回の記事も併せてご覧ください。
https://zenn.dev/asap/articles/e4c199cecf1836

このように工夫すると、作りたい構図の質の高い画像が生成できることがわかりました。
元の参照画像の構図を維持しながら、プロンプトの内容を反映させていくことが、ControlNetと似ていますが使い分けができていていいですね。

ではここまで読んでくださってありがとうございました!

コード解説

FluxDev_sample2.ipynb

FluxDev_sample2.ipynbについて解説します。

コードは下記よりご覧ください
https://github.com/personabb/colab_AI_sample/blob/main/colab_fluxdev_sample2/FluxDev_sample2.ipynb

1セル目


#Google Driveのフォルダをマウント(認証入る)
from google.colab import drive
drive.mount('/content/drive')

マイドライブのマウントを行っています。
認証が入り、一定時間認証しないとエラーになってしまうので、最初に持ってきて実行と認証をまとめて実施します。

2セル目


#FLUX で必要なモジュールのインストール
%rm -r /content/diffusers
%cd /content/
!git clone https://github.com/huggingface/diffusers.git
!pip install -U optimum-quanto peft tensorflow-metadata transformers scikit-learn ftfy accelerate invisible_watermark safetensors controlnet-aux mediapipe timm

必要なモジュールをインストールしています。
実験日時点で、pipからインストール可能なDiffusersモジュールの中にFLUX.1を実行できるコードが入っていなかったため、直接GithubからDiffusersリポジトリをクローンして利用しています。
日が経つにつれて、通常のpipでインストールが可能になると思います。

3セル目

from huggingface_hub import login
from google.colab import userdata
HF_LOGIN = userdata.get('HF_LOGIN')
login(HF_LOGIN)

# カレントディレクトリを本ファイルが存在するディレクトリに変更する。
import glob
import os
pwd = os.path.dirname(glob.glob('/content/drive/MyDrive/colabzenn/colab_fluxdev_sample2/FluxDev_sample2.ipynb', recursive=True)[0])
print(pwd)

%cd $pwd
!pwd

import sys
sys.path.append("/content/diffusers/src")

HuggingFaceのログイン用の秘密鍵の取得と、カレントディレクトリの設定と、DiffusersモジュールのPATHの追加を実施しています。
Diffusersモジュールが通常のpipでインストールできるようになりましたら、下記部分は不要になります

import sys
sys.path.append("/content/diffusers/src")

4セル目

#モジュールをimportする
from module.module_flux import FLUX
import time

モジュールをinportしています。
FLUX.1の実装コードはmodule/module_flux.pyで記述しております。
そちらも後述します

5セル目

#モデルの設定を行う。

config_text = """
[FLUX]
device = auto
n_steps=24
seed=42

;model_mode = t2i
model_mode = i2i

strength = 0.6

;from_single_file = False
;base_model_path = black-forest-labs/FLUX.1-schnell
base_model_path = black-forest-labs/FLUX.1-dev

use_controlnet = False
controlnet_path = Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro

;control_mode0 = canny
;control_mode1 = tile
control_mode2 = depth
;control_mode3 = blur
;control_mode4 = pose
;control_mode5 = gray
;control_mode6 = lq

;lora_weight_repo0 = alfredplpl/flux.1-dev-modern-anime-lora
;lora_weight_path0 = modern-anime-lora-2.safetensors
;lora_weight_repo1 = aleksa-codes/flux-ghibsky-illustration
;lora_weight_path1 = lora.safetensors
;lora_scale0 = 1.0
;lora_scale1 = 1.0
;trigger_word0 = "modern anime style"
;trigger_word1 = "GHIBSKY style painting"
;lora_weight_repo2 = None
;lora_weight_path2 = ./inputs/QTCanimation_lora_v1-PAseer.safetensors
;lora_scale2 = 1.0
;trigger_word2 = ""

;select_solver = LCM
;select_solver = DPM
;select_solver = Eulera
select_solver = FMEuler

use_karras_sigmas = True
scheduler_algorithm_type = dpmsolver++
solver_order = 2

cfg_scale = 3.5
;cfg_scale = 1.0
width = 832
height = 1216
output_type = pil
aesthetic_score = 6
negative_aesthetic_score = 2.5

save_latent_simple = False
save_latent_overstep = False


"""

with open("configs/config.ini", "w", encoding="utf-8") as f:
  f.write(config_text)

実験の設定などを記述する部分です。
直接configs/config.iniを書き換えても、このセルを実行したら上書きされてしまうので注意してください。

;をつけるとコメントアウトできるので、今回の実験では不要だが、残しておきたい実験設定などはコメントアウトして残しておくことをお勧めします。

6セル目

#読み上げるプロンプトを設定する。

main_prompt = """
modern anime style, a cute Japanese idle singer girl, long Yellowish-white hair with red small ribbon,red eyes,small red hat ,She sings with a smile on her face ,make a peace sign with her left hand, wearing white idle costume, standing on the stage ,This is a flashy concert with 3D holograms and laser effects
"""

#コントロールネットに入力する参照画像は0から6までの数字を想定(0.png,6.webpなど。数字はcontrolnetのモードと一致させてください)
#参照)コントロールネットのモード canny (0), tile (1), depth (2), blur (3), pose (4), gray (5), low quality (6).
input_refer_image_folder = "./inputs/controlnet/refer"
#コントロールネットに入力する変換後の参照画像の格納フォルダ。命名規則は上記と同様
output_refer_image_folder = "./inputs/controlnet/refer_prepared"

#image to imageモードを利用する際の画像path
i2i_image_path = "./inputs/imagetoimage/inputs1.png"

生成するプロンプトと、ControlNetやImg2Img Pilelineで利用する参照画像の格納場所を記述してます。

7セル目

flux = FLUX()
#指定しているControlNetで利用できる"./inputs/refer_prepared/0.png"などがすでにある場合は、下記はコメントアウトしても良い
#flux.prepare_multi_referimage(input_refer_image_folder = input_refer_image_folder,output_refer_image_folder = output_refer_image_folder, low_threshold = 100, high_threshold = 200, noise_level=25, blur_radius=5)

画像生成AIのFLUX.1のクラスインスタンスを立ち上げています。
立ち上げ時に、モデルのダウンロードと読み込み量子化まで実施されます。

今回はControlNetを利用しないので、下のコードはコメントアウトしています。
ControlNetを利用する場合は、下のコードも実行してください。
使い方は前回の記事を参考にしてください。

8セル目

controlnet_conditioning_scale = [0.5]
temp_strength = 0.6
#temp_strength = None

for i in range(5):
      start = time.time()
      image = flux.generate_image(main_prompt,c_image_path = output_refer_image_folder, i_image_path = i2i_image_path, controlnet_conditioning_scale = controlnet_conditioning_scale, temp_strength = temp_strength)
      print("generate image time: ", time.time()-start)
      image.save("./outputs/FLUX_result_{}.png".format(i))

ここで指定したseed値から連続で5枚分の画像を生成し、生成にかかった時間も計測します。
生成された画像はoutputsフォルダに格納されます。

module/module_flux.py

module/module_flux.pyについて解説します。
コードは下記よりご覧ください
https://github.com/personabb/colab_AI_sample/blob/main/colab_fluxdev_sample2/module/module_flux.py

設定ファイルの中身を取得

class FLUXconfig:
    def __init__(self, config_ini_path = './configs/config.ini'):
        # iniファイルの読み込み
        self.config_ini = configparser.ConfigParser()

        # 指定したiniファイルが存在しない場合、エラー発生
        if not os.path.exists(config_ini_path):
            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), config_ini_path)

        self.config_ini.read(config_ini_path, encoding='utf-8')
        FLUX_items = self.config_ini.items('FLUX')
        self.FLUX_config_dict = dict(FLUX_items)

上記コードでは設定ファイル(./configs/config.ini)の中身を辞書型として取得しています。
設定ファイルはノートブックから書き換えることができますので、普段は気にする必要はありません。

FLUXクラスのコンストラクタ

initメソッド

class FLUX:
    def __init__(self,device = None, config_ini_path = './configs/config.ini'):

        FLUX_config = FLUXconfig(config_ini_path = config_ini_path)
        config_dict = FLUX_config.FLUX_config_dict


        if device is not None:
            self.device = device
        else:
            device = config_dict["device"]

            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            if device != "auto":
                self.device = device

        self.last_latents = None
        self.last_step = -1
        self.last_timestep = 1000

        self.n_steps = int(config_dict["n_steps"])

        self.seed = int(config_dict["seed"])
        self.generator = torch.Generator(device=self.device).manual_seed(self.seed)

        self.use_controlnet = config_dict.get("use_controlnet", "False")
        if self.use_controlnet == "False":
            self.use_controlnet = False
        else:
            self.use_controlnet = True

        self.model_mode = config_dict.get("model_mode", "None")
        if self.model_mode == "t2i":
            self.model_mode = "t2i"
            print("t2i mode")
        elif self.model_mode == "i2i":
            self.model_mode = "i2i"
            print("i2i mode")
        else:
            raise ValueError("model_mode is only 't2i' or 'i2i'.")
        
        self.strength = config_dict.get("strength", "None")
        if self.strength == "None":
            self.strength = None
            if self.model_mode == "i2i":
                raise ValueError("strength is not set")
        else:
            self.strength = float(self.strength)

        self.controlnet_path = config_dict.get("controlnet_path", "None")
        if not self.use_controlnet and self.controlnet_path == "None":
            raise ValueError("controlnet_path is not set")

        self.control_modes_number_list = []
        control_mode_dict = {
            "canny":0,
            "tile":1,
            "depth":2,
            "blur":3,
            "pose":4,
            "gray":5,
            "lq":6
        }
        for i in range(7):
            if config_dict.get(f"control_mode{i}", "None") != "None":
                self.control_modes_number_list.append(control_mode_dict[config_dict[f"control_mode{i}"]])

        self.from_single_file = config_dict.get("from_single_file", "None")
        self.SINGLE_FILE_FLAG = True
        if self.from_single_file != "True":
            self.from_single_file = None
            self.SINGLE_FILE_FLAG = False


        self.base_model_path = config_dict["base_model_path"]


        self.lora_repo_list = []
        self.lora_path_list = []
        self.lora_scale_list = []
        self.lora_trigger_word_list = []
        self.lora_nums = 0
        self.LORA_FLAG = False
        for i in range(10):
            if config_dict.get(f"lora_weight_path{i}", "None") != "None":
                self.LORA_FLAG = True
                #Huggingfaceのrepoから取得する場合は、repoとpath(ファイル名)の両方を指定する
                if config_dict.get(f"lora_weight_repo{i}", "None") != "None":
                    self.lora_repo_list.append(config_dict[f"lora_weight_repo{i}"])
                #Civitaiなどからダウンロードして利用する場合は、ディレクトリ名含むpathのみで指定する
                else:
                    self.lora_repo_list.append(None)
                self.lora_path_list.append(config_dict[f"lora_weight_path{i}"])
                self.lora_nums += 1
                self.lora_scale_list.append(float(config_dict.get(f"lora_scale{i}", "1.0")))
                self.lora_trigger_word_list.append(config_dict.get(f"trigger_word{i}", "None"))


        self.select_solver = config_dict.get("select_solver", "FMEuler")

        self.use_karras_sigmas = config_dict.get("use_karras_sigmas", "True")
        if self.use_karras_sigmas == "True":
            self.use_karras_sigmas = True
        else:
            self.use_karras_sigmas = False

        self.scheduler_algorithm_type = config_dict.get("scheduler_algorithm_type", "dpmsolver++")
        self.solver_order = config_dict.get("solver_order", "None")
        if self.solver_order != "None":
            self.solver_order = int(self.solver_order)
        else:
            self.solver_order = None


        self.cfg_scale = float(config_dict["cfg_scale"])
        self.width = int(config_dict["width"])
        self.height = int(config_dict["height"])
        self.output_type = config_dict["output_type"]
        self.aesthetic_score = float(config_dict["aesthetic_score"])
        self.negative_aesthetic_score = float(config_dict["negative_aesthetic_score"])

        self.save_latent_simple = config_dict["save_latent_simple"]
        if self.save_latent_simple == "True":
            self.save_latent_simple = True
            print("use callback save_latent_simple")
        else:
            self.save_latent_simple = False

        self.save_latent_overstep = config_dict["save_latent_overstep"]
        if self.save_latent_overstep == "True":
            self.save_latent_overstep = True
            print("use callback save_latent_overstep")
        else:
            self.save_latent_overstep = False

        self.use_callback = False
        if self.save_latent_simple or self.save_latent_overstep:
            self.use_callback = True

        if (self.save_latent_simple and self.save_latent_overstep):
            raise ValueError("save_latent_simple and save_latent_overstep cannot be set at the same time")

        if self.model_mode == "t2i":
            if self.LORA_FLAG:
                self.base = self.preprepare_model_withLoRAandQuantize()
            else:
                self.base = self.preprepare_model()
        elif self.model_mode == "i2i":
            if self.LORA_FLAG:
                self.base = self.preprepare_img2img_model_withLoRAandQuantize()
            else:
                self.base = self.preprepare_img2img_model()
        else:
            raise ValueError("model_mode is only 't2i' or 'i2i'.")


設定ファイルを読み込んで、クラスアトリビュートに設定しています。またLoRAやControlNetなど、場合によっては複数読み込まれる可能性のある設定などが、幾つ読み込む必要があるかなどを取得して、別メソットで必要はフラグなどを立てています。

最後に、preprepare_modelにて、モデルのダウンロード、読み込み、量子化などを実施しています。
前回の記事で記載しましたが、2024年9月11日現在、Quantoによる量子化とLoRAを併用するとエラーが出てしまう問題があるようです。
https://github.com/huggingface/diffusers/issues/9270

従って、LoRAを利用するかどうかで、量子化をするかしないかのモデルロードを変えています。
また、Text to ImageとImage to Imageでは利用するPipelineが違うため、それ次第でもモデルロードが違います。

FLUX.1モデルの読み込み(Image to Image without LoRA)

Text to Imageに関しては前回の記事をご覧ください。

preprepare_modelメソッド
class FLUX:
・・・
    def preprepare_img2img_model(self):
        if not self.SINGLE_FILE_FLAG:
            transformer = FluxTransformer2DModel.from_pretrained(
                self.base_model_path,
                subfolder="transformer",
                torch_dtype=torch.float16
            )
            print("transformer quantizing")
            quantize(transformer, weights=qfloat8)
            freeze(transformer)
            print("loaded transformer")

            text_encoder_2 = T5EncoderModel.from_pretrained(
                self.base_model_path,
                subfolder="text_encoder_2",
                torch_dtype=torch.float16
            )
            print("text_encoder_2 quantizing")
            quantize(text_encoder_2, weights=qfloat8)
            freeze(text_encoder_2)
            print("loaded text_encoder_2")

            base = FluxImg2ImgPipeline.from_pretrained(
                self.base_model_path, 
                transformer=transformer, 
                text_encoder_2=text_encoder_2,
                torch_dtype=torch.float16
                )

            print("loaded base model")
        else:
            base = FluxImg2ImgPipeline.from_pretrained(self.base_model_path, torch_dtype=torch.float16)
            print("transformer quantizing")
            quantize(base.transformer, weights=qfloat8)
            freeze(base.transformer)
            print("text_encoder_2 quantizing")
            quantize(base.text_encoder_2, weights=qfloat8)
            freeze(base.text_encoder_2)

            print("loaded base model")

        print("cpu offloading")
        base.enable_model_cpu_offload()

        lora_adapter_name_list = []
        lora_adapter_weights_list = []
        if self.LORA_FLAG:
            for i in range(self.lora_nums):
                if self.lora_repo_list[i] is not None:
                    base.load_lora_weights(
                        pretrained_model_name_or_path_or_dict = self.lora_repo_list[i], 
                        weight_name=self.lora_path_list[i],
                        adapter_name=f"lora{i}")
                else:
                    base.load_lora_weights(self.lora_path_list[i], adapter_name=f"lora{i}")
                lora_adapter_name_list.append(f"lora{i}")
                lora_adapter_weights_list.append(self.lora_scale_list[i])
            if self.lora_nums > 1:
                base.set_adapters(lora_adapter_name_list, adapter_weights=lora_adapter_weights_list)

            print("finish lora settings")

        if self.select_solver == "DPM":
            if self.solver_order is not None:
                base.scheduler = DPMSolverMultistepScheduler.from_config(
                        base.scheduler.config,
                        use_karras_sigmas=self.use_karras_sigmas,
                        Algorithm_type =self.scheduler_algorithm_type,
                        solver_order=self.solver_order,
                        )

            else:
                base.scheduler = DPMSolverMultistepScheduler.from_config(
                        base.scheduler.config,
                        use_karras_sigmas=self.use_karras_sigmas,
                        Algorithm_type =self.scheduler_algorithm_type,
                        )
        elif self.select_solver == "LCM":
            base.scheduler = LCMScheduler.from_config(base.scheduler.config)
        elif self.select_solver == "Eulera":
            base.scheduler = EulerAncestralDiscreteScheduler.from_config(base.scheduler.config)
        elif self.select_solver == "FMEuler":
            base.scheduler = FlowMatchEulerDiscreteScheduler.from_config(base.scheduler.config)
        else:
            raise ValueError("select_solver is only 'DPM' or 'LCM' or 'Eulera' or 'FMEuler'.")


        return base

基本的には前回の記事の内容と同じです。

ただし利用するPipeLineがFluxPipelineではなく、FluxImg2ImgPipelineになります。
LoRAを利用するか、もしくは単一のsafetensorファイルからモデルを読み込むのか、どのスケジューラを利用するのかなど、設定ファイルに則って上から順番にモデルを構築していきます。

FLUX.1による画像生成

generate_imageメソッド
class FLUX:
・・・

    def generate_image(self, prompt, neg_prompt = None, c_image_path = None, i_image_path = None, seed = None, controlnet_conditioning_scale = [1.0], temp_strength = None):
        def decode_tensors(pipe, step, timestep, callback_kwargs):
            if self.save_latent_simple:
                callback_kwargs = decode_tensors_simple(pipe, step, timestep, callback_kwargs)
            elif self.save_latent_overstep:
                callback_kwargs = decode_tensors_residual(pipe, step, timestep, callback_kwargs)
            else:
                raise ValueError("save_latent_simple or save_latent_overstep must be set")
            return callback_kwargs


        def decode_tensors_simple(pipe, step, timestep, callback_kwargs):
            latents = callback_kwargs["latents"]
            imege = None
            if self.save_latent_simple:
                image = latents_to_rgb_vae(latents,pipe)
            else:
                raise ValueError("save_latent_simple is not set")
            gettime = time.time()
            formatted_time_human_readable = time.strftime("%Y%m%d_%H%M%S", time.localtime(gettime))
            image.save(f"./outputs/latent_{formatted_time_human_readable}_{step}_{timestep}.png")

            return callback_kwargs

        def decode_tensors_residual(pipe, step, timestep, callback_kwargs):
            latents = callback_kwargs["latents"]
            if step > 0:
                residual = latents - self.last_latents
                goal = self.last_latents + residual * ((self.last_timestep) / (self.last_timestep - timestep))
                #print( ((self.last_timestep) / (self.last_timestep - timestep)))
            else:
                goal = latents

            if self.save_latent_overstep:
                image = latents_to_rgb_vae(goal,pipe)
            else:
                raise ValueError("save_latent_overstep is not set")

            gettime = time.time()
            formatted_time_human_readable = time.strftime("%Y%m%d_%H%M%S", time.localtime(gettime))
            image.save(f"./outputs/latent_{formatted_time_human_readable}_{step}_{timestep}.png")

            self.last_latents = latents
            self.last_step = step
            self.last_timestep = timestep

            if timestep == 0:
                self.last_latents = None
                self.last_step = -1
                self.last_timestep = 100

            return callback_kwargs

        def latents_to_rgb_vae(latents,pipe):

            latents = pipe._unpack_latents(latents, self.height, self.width, pipe.vae_scale_factor)
            latents = (latents / pipe.vae.config.scaling_factor) + pipe.vae.config.shift_factor

            image = pipe.vae.decode(latents, return_dict=False)[0]
            image = pipe.image_processor.postprocess(image, output_type=self.output_type)

            return FluxPipelineOutput(images=image).images[0]


        def load_image_path(image_path):
            # もし image_path が str ならファイルパスとして画像を読み込む
            if isinstance(image_path, str):
                image = load_image(image_path)
                print("Image loaded from file path.")
            # もし image_path が PIL イメージならそのまま使用
            elif isinstance(image_path, Image.Image):
                image = image_path
                print("PIL Image object provided.")
            # もし image_path が Torch テンソルならそのまま使用
            elif isinstance(image_path, torch.Tensor):
                image = image_path.unsqueeze(0)
                image = image.permute(0, 3, 1, 2)
                image = image/255.0
                print("Torch Tensor object provided.")
            else:
                raise TypeError("Unsupported type. Provide a file path, PIL Image, or Torch Tensor.")

            return image

        def find_file_with_extension(image_path, i):
            # パターンに一致するファイルを検索
            file_pattern = f"{image_path}/{i}.*"
            matching_files = glob.glob(file_pattern)

            # マッチするファイルが存在する場合、そのファイルのパスを返す
            if matching_files:
                # 例: ./image_path/0.png のような完全なファイルパスが取得される
                return matching_files[0]
            else:
                # マッチするファイルがない場合
                raise FileNotFoundError(f"No file found matching pattern: {file_pattern}")
                return None

        if seed is not None:
            self.generator = torch.Generator(device=self.device).manual_seed(seed)
        if temp_strength is not None:
            self.strength = temp_strength

        control_image_list = []
        if self.use_controlnet and self.model_mode == "t2i":
            print("use controlnet mode: ",self.control_modes_number_list)
            for i in self.control_modes_number_list:
                if c_image_path is None:
                    raise ValueError("when use controlnet ,control_image_path must be set")
                control_image_name = load_image_path(find_file_with_extension(c_image_path, i))
                control_image = load_image_path(control_image_name)
                control_image_list.append(control_image)
        
        init_image = None
        if self.model_mode == "i2i":
            print("use i2i mode")
            if i_image_path is None:
                raise ValueError("when use i2i mode, init_image_path must be set")
            init_image = load_image_path(i_image_path)


        lora_weight_average = 0
        if self.LORA_FLAG:
            print("use LoRA")
            lora_weight_average = sum(self.lora_scale_list) / len(self.lora_scale_list)
            for word in self.lora_trigger_word_list:
                if (word is not None) and (word != "None"):
                    prompt = prompt + ", " + word

        image = None
        if self.model_mode == "i2i":
            if self.use_callback:
                if self.LORA_FLAG:
                    image = self.base(
                        prompt=prompt,
                        image=init_image, 
                        strength= self.strength,
                        guidance_scale=self.cfg_scale,
                        num_inference_steps=self.n_steps,
                        output_type=self.output_type,
                        width = self.width,
                        height = self.height,
                        generator=self.generator,
                        callback_on_step_end=decode_tensors,
                        callback_on_step_end_tensor_inputs=["latents"],
                        joint_attention_kwargs={"scale": lora_weight_average},
                        ).images[0]
                #LORAを利用しない場合
                else:
                    image = self.base(
                        prompt=prompt,
                        image=init_image, 
                        strength= self.strength,
                        guidance_scale=self.cfg_scale,
                        num_inference_steps=self.n_steps,
                        output_type=self.output_type,
                        width = self.width,
                        height = self.height,
                        generator=self.generator,
                        callback_on_step_end=decode_tensors,
                        callback_on_step_end_tensor_inputs=["latents"],
                        ).images[0]
            #latentを保存しない場合
            else:
                if self.LORA_FLAG:
                    image = self.base(
                        prompt=prompt,
                        image=init_image, 
                        strength= self.strength,
                        guidance_scale=self.cfg_scale,
                        num_inference_steps=self.n_steps,
                        output_type=self.output_type,
                        width = self.width,
                        height = self.height,
                        generator=self.generator,
                        joint_attention_kwargs={"scale": lora_weight_average},
                        ).images[0]
                # LORAを利用しない場合
                else:
                    image = self.base(
                        prompt=prompt,
                        image=init_image, 
                        strength= self.strength,
                        guidance_scale=self.cfg_scale,
                        num_inference_steps=self.n_steps,
                        output_type=self.output_type,
                        width = self.width,
                        height = self.height,
                        generator=self.generator
                        ).images[0]

        #t2iモードの場合
        elif self.model_mode == "t2i":
            if self.use_callback:
                if self.LORA_FLAG:
                    if self.use_controlnet:
                        image = self.base(
                            prompt=prompt,
                            control_image=control_image_list,
                            control_mode=self.control_modes_number_list,
                            guidance_scale=self.cfg_scale,
                            controlnet_conditioning_scale=controlnet_conditioning_scale,
                            num_inference_steps=self.n_steps,
                            output_type=self.output_type,
                            width = self.width,
                            height = self.height,
                            generator=self.generator,
                            callback_on_step_end=decode_tensors,
                            callback_on_step_end_tensor_inputs=["latents"],
                            joint_attention_kwargs={"scale": lora_weight_average},
                            ).images[0]
                    else:
                        image = self.base(
                            prompt=prompt,
                            guidance_scale=self.cfg_scale,
                            num_inference_steps=self.n_steps,
                            output_type=self.output_type,
                            width = self.width,
                            height = self.height,
                            generator=self.generator,
                            callback_on_step_end=decode_tensors,
                            callback_on_step_end_tensor_inputs=["latents"],
                            joint_attention_kwargs={"scale": lora_weight_average},
                            ).images[0]
                #LORAを利用しない場合
                else:
                    if self.use_controlnet:
                        image = self.base(
                            prompt=prompt,
                            control_image=control_image_list,
                            control_mode=self.control_modes_number_list,
                            guidance_scale=self.cfg_scale,
                            controlnet_conditioning_scale=controlnet_conditioning_scale,
                            num_inference_steps=self.n_steps,
                            output_type=self.output_type,
                            width = self.width,
                            height = self.height,
                            callback_on_step_end=decode_tensors,
                            callback_on_step_end_tensor_inputs=["latents"],
                            generator=self.generator
                            ).images[0]
                    else:
                        image = self.base(
                            prompt=prompt,
                            guidance_scale=self.cfg_scale,
                            num_inference_steps=self.n_steps,
                            output_type=self.output_type,
                            width = self.width,
                            height = self.height,
                            callback_on_step_end=decode_tensors,
                            callback_on_step_end_tensor_inputs=["latents"],
                            generator=self.generator
                            ).images[0]
            #latentを保存しない場合
            else:
                if self.LORA_FLAG:
                    if self.use_controlnet:
                        image = self.base(
                            prompt=prompt,
                            control_image=control_image_list,
                            control_mode=self.control_modes_number_list,
                            guidance_scale=self.cfg_scale,
                            controlnet_conditioning_scale=controlnet_conditioning_scale,
                            num_inference_steps=self.n_steps,
                            output_type=self.output_type,
                            width = self.width,
                            height = self.height,
                            generator=self.generator,
                            joint_attention_kwargs={"scale": lora_weight_average},
                            ).images[0]
                    else:
                        image = self.base(
                            prompt=prompt,
                            guidance_scale=self.cfg_scale,
                            num_inference_steps=self.n_steps,
                            output_type=self.output_type,
                            width = self.width,
                            height = self.height,
                            generator=self.generator,
                            joint_attention_kwargs={"scale": lora_weight_average},
                            ).images[0]

                # LORAを利用しない場合
                else:
                    if self.use_controlnet:
                        image = self.base(
                            prompt=prompt,
                            control_image=control_image_list,
                            control_mode=self.control_modes_number_list,
                            guidance_scale=self.cfg_scale,
                            controlnet_conditioning_scale=controlnet_conditioning_scale,
                            num_inference_steps=self.n_steps,
                            output_type=self.output_type,
                            width = self.width,
                            height = self.height,
                            generator=self.generator
                            ).images[0]
                    else:
                        image = self.base(
                            prompt=prompt,
                            guidance_scale=self.cfg_scale,
                            num_inference_steps=self.n_steps,
                            output_type=self.output_type,
                            width = self.width,
                            height = self.height,
                            generator=self.generator
                            ).images[0]
        else:
            raise ValueError("model_mode is only 't2i' or 'i2i'.")


        return image

ここでは、FLUXによる画像生成を行っています。
こちらも基本的に前回の記事と同じですが、Image to Imageの場合は、若干__call__メソッドの引数が変わります。

LoRAを導入し、さらに生成途中の画像を出力する場合は、下記が実行されます。
下記では、設定ファイルにて記載されている設定とプロンプト、そしてLoRAの設定と、参照画像の画像データ、ノイズの付与率、そして、生成途中の画像を出力するためのコールバック関数を設定しています。

if self.model_mode == "i2i":
    if self.use_callback:
        if self.LORA_FLAG:
            image = self.base(
                prompt=prompt,
                image=init_image, 
                strength= self.strength,
                guidance_scale=self.cfg_scale,
                num_inference_steps=self.n_steps,
                output_type=self.output_type,
                width = self.width,
                height = self.height,
                generator=self.generator,
                callback_on_step_end=decode_tensors,
                callback_on_step_end_tensor_inputs=["latents"],
                joint_attention_kwargs={"scale": lora_weight_average},
                ).images[0]

まとめ(2回目)

ここまで、実装コードの解説をしていきました。
基本的にはFLUXは前回の記事で説明した通りですが、今回はそこにImage to ImageのPipelineを導入しました。

ここまで読んでくださって、本当にありがとうございました!

Discussion