😍

生成AIをGoogle Colabで簡単に 【FLUX.1-dev + ControlNet + LoRA】【Diffusers】

2024/08/31に公開

はじめに

FLUX.1ってご存知でしょうか。
https://huggingface.co/black-forest-labs/FLUX.1-dev

画像生成AIであるかの有名な「Stable Diffusion」の開発に携わったAI研究者が立ちあげた新しいAI開発企業である「Black Forest Labs」により生み出された高性能な画像生成AIモデルです。
https://blackforestlabs.ai/

生成される画像の質も、他のモデルと比較して現時点で最高性能と言われているモデルでありながら、FLUX.1-devモデルはモデル重みがオープンに公開されていることもあり、LoRAやControlNetといったStable Diffusionにて効果を発揮したモジュールの開発も積極的に進められており、現在熱いAIの一つです。

今回は、FLUX.1シリーズのうち、モデル重みが公開されている中で最高の品質であるFLUX.1-devをGoogle Colaboratoryで実行します。

加えて、現時点で開発途中ではありますが、LoRAやControlNetなども併せて実装していきます。

生成画像例

FLUX.1-devでは下記のような画像を生成できます。
恐ろしいのがFLUX.1-devがベースモデルで下記のレベルのような画像を出力できることです。
よりドメイン特化したFineTuningモデルが出てくるのが期待されます。楽しみです。

通常

ControlNet

Depth

Canny

LoRA

alfredplpl/flux.1-dev-modern-anime-lora

alfredplpl/flux.1-dev-modern-anime-lora + aleksa-codes/flux-ghibsky-illustration

生成途中gif

(3.0MBの制限に収めるために画像サイズ、質ともに圧縮の結果、劣化していることをご了承ください)

成果物

下記のリポジトリをご覧ください。
https://github.com/personabb/colab_AI_sample/tree/main/colab_fluxdev_sample

事前準備

Hugging Faceのlogin tokenの取得と登録

SD3のモデルをローカルで利用可能にするために、Huggingfaceからログイン用のtokenを取得する必要があります。
ログインtokenの取得方法は下記の記事を参考にしてください。
https://zenn.dev/protoout/articles/73-hugging-face-setup

また、取得したログインtokenをGoogle Colabに登録する必要があります。
下記の記事を参考に登録してください。
https://note.com/npaka/n/n79bb63e17685

HF_LOGINという名前で登録してください。

FLUX.1-dev重みへのアクセス準備

続いて、HuggingfaceでFLUX.1-devの重みにアクセスできるようにします。

https://huggingface.co/black-forest-labs/FLUX.1-dev
上記のURLにアクセスしてください。

初めてアクセスした場合、上記のような画面になると思うので、作成したアカウントでログインしてください。

その後、入力フォームが表示されるかと思います。
そのフォームをすべて埋めて、提出することで、モデル重みにアクセスできるようになります。

リポジトリのクローン

まずは上記のリポジトリをcloneしてください。

./
git clone https://github.com/personabb/colab_AI_sample.git

その後、cloneしたフォルダの中である「colab_AI_sample/colab_fluxdev_sample」をマイドライブの適当な場所においてください。

ControlNet用の参照画像を用意

今回利用するControlNetは下記です。
https://huggingface.co/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro

コントロールネットの入力に利用する画像を取得し、後述するフォルダの「inputs/refer」フォルダに格納してください。

その時に、「(数字).(拡張子)」の形で保存してください
Ex) 「0.png」、「2.webp」など

また、この数字の部分には、ControlNetのモードに合わせて記載してください。
今回利用するコントロールネットのモードは下記です。

canny (0)
tile (1)
depth (2)
blur (3)
pose (4)
gray (5)
low quality (6)

今回の実験では、上記の画像を「0.png」から「6.png」までコピーして「inputs/refer」フォルダに格納しました。

想定されるディレクトリ構造

Google Driveのディレクトリ構造は下記を想定します。

MyDrive/
    └ colab_AI_sample/
          └ colab_fluxdev_sample/
                  ├ configs/
                  |    └ config.ini
                  ├ inputs/
                  |    ├ refer/
                  |    |    ├ 0.png or jpg or webpなど
                  |    |    ├ 1.png or jpg or webpなど
                  |    |    ├ 2.png or jpg or webpなど
                  |    |    | ・・・
                  |    |    └ 6.png or jpg or webpなど
                  |    └ refer_prepared/
                  ├ outputs/
                  ├ module/
                  |    └ module_flux.py
                  └ FluxDev_sample.ipynb

使い方

実行方法

FluxDev_sample.ipynbをGoogle Colabratoryアプリで開いて、一番上のセルから順番に一番下まで実行すると、画像が5枚「outputs」フォルダに生成されます。

また、最後まで実行後、パラメータを変更して再度実行する場合は、「ランタイム」→「セッションを再起動して全て実行する」をクリックしてください。

パラメータの変更

FluxDev_sample.ipynbの5セル目と8セル目が該当します。

#モデルの設定を行う。

config_text = """
[FLUX]
device = auto
n_steps=24
seed=42

;from_single_file = False
;base_model_path = black-forest-labs/FLUX.1-schnell
base_model_path = black-forest-labs/FLUX.1-dev

use_controlnet = False
controlnet_path = Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro

;control_mode0 = canny
;control_mode1 = tile
control_mode2 = depth
;control_mode3 = blur
;control_mode4 = pose
;control_mode5 = gray
;control_mode6 = lq

;lora_weight_repo0 = alfredplpl/flux.1-dev-modern-anime-lora
;lora_weight_path0 = modern-anime-lora_diffusers.safetensors
;lora_weight_repo1 = XLabs-AI/flux-lora-collection
;lora_weight_path1 = scenery_lora.safetensors
;lora_scale0 = 1.0
;lora_scale1 = 1.0
;trigger_word0 = "modern anime style"
;trigger_word1 = "scenery style"

;select_solver = LCM
;select_solver = DPM
;select_solver = Eulera
select_solver = FMEuler

use_karras_sigmas = True
scheduler_algorithm_type = dpmsolver++
solver_order = 2

cfg_scale = 3.5
width = 832
height = 1216
output_type = pil
aesthetic_score = 6
negative_aesthetic_score = 2.5

save_latent_simple = False
save_latent_overstep = False


"""

with open("configs/config.ini", "w", encoding="utf-8") as f:
  f.write(config_text)
  • n_steps
    • 20step以上から綺麗な画像が生成される。公式では50stepを推奨しているが時間がかかる
  • base_model_path
    • 利用するモデルの指定。FLUX.1-schnellなども指定できる。
    • 今後、高性能はFineTuningモデルが登場したら、そちらを利用することもできる
      • そのモデルが単一のsafetensorモデルの場合はfrom_single_file = Trueを設定する
  • use_controlnet
    • ControlNetを利用するかどうかのFlag
  • controlnet_path
    • 利用するControlNetのモデルパス
  • control_mode0からcontrol_mode6
    • 利用するControlNetのモード。複数記載すれば多重に適用可能
  • lora_weight_repo0からlora_weight_repo9
    • 利用するLoRAのリポジトリ指定。複数記載すれば多重に適用可能
    • HuggingFaceにあるLoRA利用する場合は、ここにリポジトリ名も指定する
  • lora_weight_path0からlora_weight_path9
    • 利用するLoRAのパスを指定する。複数記載すれば多重に適用可能
    • HuggingFaceにあるLoRA利用する場合は、ここにはLoRAで利用するファイル名のみを指定する
    • CivitaiなどのLoRAファイルをダウンロードして利用するファイルは、ファイルを保存した場所の相対パスを記述する('lora_weight_repo`はNoneに設定する)
  • lora_scale0からlora_scale9
    • 利用するLoRAの適用度合い
  • trigger_word0からtrigger_word9
    • 利用するLoRAで推奨されるTrigger Wordの設定
  • select_solver
    • 利用するサンプラーの指定。FLUX.1シリーズの場合はFMEulerを指定する
  • widthheight
    • 生成する画像のサイズ
  • save_latent_simplesave_latent_overstep
    • 生成途中も出力するかどうか、出力する場合どのような形式で出力するかのフラグ
    • save_latent_simpleはノイズから生成される過程を出力
    • save_latent_overstepはある程度綺麗な画像が遷移するような過程を出力

8セル目は下記です。

for i in range(5):
      start = time.time()
      image = flux.generate_image(main_prompt,image_path = output_refer_image_folder, controlnet_conditioning_scale = [0.5])
      print("generate image time: ", time.time()-start)
      image.save("./outputs/FLUX_result_{}.png".format(i))

こちらのcontrolnet_conditioning_scale = [0.5]の部分が変更可能なパラメータです。
こちらはControlNetを利用する場合に、どの程度反映させるかを設定するパラメータです。
公式的には0.3-0.8くらいが推奨になっています。

また、複数のControlNetを利用する場合は、下記のように、モードが若い順にリスト形式で指定してください。

controlnet_conditioning_scale = [0.5, 0.8, 0.3, 0.4]

プロンプトの変更

FluxDev_sample.ipynbの6セル目が該当します。

#読み上げるプロンプトを設定する。


main_prompt = """
modern anime style, a cute Japanese idle singer girl, long Yellowish-white hair with red small ribbon,red eyes,small red hat ,She is singing ,holding a microphone, wearing white idle costume, standing on the stage ,This is a flashy concert with 3D holograms and laser effects
"""

#コントロールネットに入力する参照画像は0から6までの数字を想定(0.png,6.webpなど。数字はcontrolnetのモードと一致させてください)
#参照)コントロールネットのモード canny (0), tile (1), depth (2), blur (3), pose (4), gray (5), low quality (6).
input_refer_image_folder = "./inputs/refer"
#コントロールネットに入力する変換後の参照画像の格納フォルダ。命名規則は上記と同様
output_refer_image_folder = "./inputs/refer_prepared"

6セル目のmain_promptを変更して、8セル目を実行することで、変更されたプロンプトが反映されて実行されます。

実行結果

ここからは、Google Colabでさまざまパラメータを変更した実験結果を記載します

実験1

まずはLoRAのControlNetも導入しないノーマルな状態で実験をします。
この時点でRAM最大27GB程度、VRAM14GB程度利用します。
(この実験であればT4 GPUでも動作します。ただしRAMはハイメモリを設定する必要があります)

設定

FluxDev_sample.ipynb(5セル目)
#モデルの設定を行う。

config_text = """
[FLUX]
device = auto
n_steps=24
seed=42

base_model_path = black-forest-labs/FLUX.1-dev

use_controlnet = False

select_solver = FMEuler

use_karras_sigmas = True
scheduler_algorithm_type = dpmsolver++
solver_order = 2

cfg_scale = 3.5
width = 832
height = 1216
output_type = pil
aesthetic_score = 6
negative_aesthetic_score = 2.5

save_latent_simple = False
save_latent_overstep = False


"""

with open("configs/config.ini", "w", encoding="utf-8") as f:
  f.write(config_text)

プロンプト

modern anime style, a cute Japanese idle singer girl, long Yellowish-white hair with red small ribbon,red eyes,small red hat ,She is singing ,holding a microphone, wearing white idle costume, standing on the stage ,This is a flashy concert with 3D holograms and laser effects

生成結果

生成された画像は下記です。(一枚にまとめています)

生成時間は1枚(24step)あたり80秒ほどでした。

実験2

ControlNetを導入してみます。
まずはdepth(2)です。
この時点でRAM30GB程度、VRAM17.5GB程度利用します。
(量子化のタイミングをずらせばRAMはもう少し減らせるかも)

参照画像

変換前と変換後の画像を提示します。
上述した変換前画像をFluxDev_sample.ipynbの7セル目のflux.prepare_multi_referimageにて変換すると右側の画像になります。

設定

FluxDev_sample.ipynb(5セル目)
#モデルの設定を行う。

config_text = """
[FLUX]
device = auto
n_steps=24
seed=42

base_model_path = black-forest-labs/FLUX.1-dev

use_controlnet = True
controlnet_path = Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro

;control_mode0 = canny
;control_mode1 = tile
control_mode2 = depth
;control_mode3 = blur
;control_mode4 = pose
;control_mode5 = gray
;control_mode6 = lq

select_solver = FMEuler

use_karras_sigmas = True
scheduler_algorithm_type = dpmsolver++
solver_order = 2

cfg_scale = 3.5
width = 832
height = 1216
output_type = pil
aesthetic_score = 6
negative_aesthetic_score = 2.5

save_latent_simple = False
save_latent_overstep = False


"""

with open("configs/config.ini", "w", encoding="utf-8") as f:
  f.write(config_text)
FluxDev_sample.ipynb(8セル目)
for i in range(5):
      start = time.time()
      image = flux.generate_image(main_prompt,image_path = output_refer_image_folder, controlnet_conditioning_scale = [0.5])
      print("generate image time: ", time.time()-start)
      image.save("./outputs/FLUX_result_{}.png".format(i))

プロンプト

modern anime style, a cute Japanese idle singer girl, long Yellowish-white hair with red small ribbon,red eyes,small red hat ,She sings with a smile on her face , wearing white idle costume, standing on the stage ,This is a flashy concert with 3D holograms and laser effects

生成結果

生成された画像は下記です。(一枚にまとめています)

なんとなくですが、ControlNetを利用しない方が質が高いような気がしますね。
生成にかかった時間は、1枚(24step)あたり、100秒ほどでした

実験3

ControlNetを導入してみます。
続いてはcanny(0)です。

参照画像

変換前と変換後の画像を提示します。
上述した変換前画像をFluxDev_sample.ipynbの7セル目のflux.prepare_multi_referimageにて変換すると右側の画像になります。

設定

FluxDev_sample.ipynb(5セル目)

#モデルの設定を行う。

config_text = """
[FLUX]
device = auto
n_steps=24
seed=42

base_model_path = black-forest-labs/FLUX.1-dev

use_controlnet = True
controlnet_path = Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro

control_mode0 = canny
;control_mode1 = tile
;control_mode2 = depth
;control_mode3 = blur
;control_mode4 = pose
;control_mode5 = gray
;control_mode6 = lq

select_solver = FMEuler

use_karras_sigmas = True
scheduler_algorithm_type = dpmsolver++
solver_order = 2

cfg_scale = 3.5
width = 832
height = 1216
output_type = pil
aesthetic_score = 6
negative_aesthetic_score = 2.5

save_latent_simple = False
save_latent_overstep = False


"""

with open("configs/config.ini", "w", encoding="utf-8") as f:
  f.write(config_text)


FluxDev_sample.ipynb(8セル目)
for i in range(5):
      start = time.time()
      image = flux.generate_image(main_prompt,image_path = output_refer_image_folder, controlnet_conditioning_scale = [0.5])
      print("generate image time: ", time.time()-start)
      image.save("./outputs/FLUX_result_{}.png".format(i))

プロンプト

modern anime style, a cute Japanese idle singer girl, long Yellowish-white hair with red red hat ,small ribbon,red eyes

生成結果

生成された画像は下記です。(一枚にまとめています)

ちなみに、8セル目でcontrolnet_conditioning_scale = [0.8]を設定した場合は下記のような出力になりました。

うーん。
確かにエッジ画像の通りに出力はされますが、その分ベースモデルが持っていた高品質な画像を生成する能力が犠牲になっているような気がしますね。
(アニメ画像だからかもしれないですね。実写画像を生成する際にはControlNetを利用しても質は下がらないのかも)

そういう観点では、ControlNetを利用しても、生成される画像の質を下げずに完璧に構図を再現できていたSDXLのControlNetはすごいですね。

実験4

続いてはLoRAを試してみます。

今回の実験ではRAM最大40GB程度、VRAM27GB程度利用します。

利用するLoRAは下記です
https://huggingface.co/alfredplpl/flux.1-dev-modern-anime-lora

設定

FluxDev_sample.ipynb(5セル目)

#モデルの設定を行う。

config_text = """
[FLUX]
device = auto
n_steps=24
seed=42

base_model_path = black-forest-labs/FLUX.1-dev

use_controlnet = True
controlnet_path = Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro

control_mode2 = depth

lora_weight_repo0 = alfredplpl/flux.1-dev-modern-anime-lora
lora_weight_path0 = modern-anime-lora-2.safetensors
;lora_weight_repo1 = aleksa-codes/flux-ghibsky-illustration
;lora_weight_path1 = lora.safetensors
lora_scale0 = 1.0
;lora_scale1 = 1.0
trigger_word0 = "modern anime style"
;trigger_word1 = "GHIBSKY style painting"

select_solver = FMEuler

use_karras_sigmas = True
scheduler_algorithm_type = dpmsolver++
solver_order = 2

cfg_scale = 3.5
width = 832
height = 1216
output_type = pil
aesthetic_score = 6
negative_aesthetic_score = 2.5

save_latent_simple = False
save_latent_overstep = False

"""

with open("configs/config.ini", "w", encoding="utf-8") as f:
  f.write(config_text)


FluxDev_sample.ipynb(8セル目)
for i in range(5):
      start = time.time()
      image = flux.generate_image(main_prompt,image_path = output_refer_image_folder, controlnet_conditioning_scale = [0.5])
      print("generate image time: ", time.time()-start)
      image.save("./outputs/FLUX_result_{}.png".format(i))

プロンプト

a cute Japanese idle singer girl, long Yellowish-white hair with red small ribbon,red eyes,small red hat ,She sings with a smile on her face , wearing white idle costume, standing on the stage

ここに利用したLoRAのtrigger wordが最後に追加されます。
(下記の部分が該当します)

trigger_word0 = "modern anime style"
;trigger_word1 = "GHIBSKY style painting"

生成結果

生成された画像は下記です。(一枚にまとめています)

生成時間は1枚(24step)あたり50秒ほどでした。(A100GPU、LoRAのみ)
実験1の画像と比較して、キャラクターの絵がより可愛くなったような気がします。
(その分、生成される画像のキャラクターの統一感が減った気がするので一長一短かもしれないですね)

さらに、2つ目のLoRAも同時に適用させてみました。
追加したLoRAは下記です。
https://huggingface.co/aleksa-codes/flux-ghibsky-illustration

プロンプトは変わらず、設定ファイルは、LoRA部分のコメントアウトを全て外し、下記のようにするだけです

lora_weight_repo0 = alfredplpl/flux.1-dev-modern-anime-lora
lora_weight_path0 = modern-anime-lora-2.safetensors
lora_weight_repo1 = aleksa-codes/flux-ghibsky-illustration
lora_weight_path1 = lora.safetensors
lora_scale0 = 1.0
lora_scale1 = 1.0
trigger_word0 = "modern anime style"
trigger_word1 = "GHIBSKY style painting"

その結果生成された画像は下記になります。

一つしか適応していないLoRAと比較してわかるように、背景部分の書き込みが若干増えたかなと感じます。
ControlNetと比較して、LoRAはかなり使いやすいのではないかなと思いました。
(量子化との不具合がなければL4 GPUで動くのでより使いやすいのですが・・・)

ちなみにControlNetのDepthを併用して崩れてしまった画像は下記です。
理由ご存知の方がいらっしゃればご教授いただけると非常に嬉しいです。

実験5

続いては、生成された画像の生成途中をみていきます。
ノイズ混じりの生成途中と、クリーン画像の生成途中の両方を可視化していきます。
(画像生成にかかる時間が長くなるのでご留意ください)

実験の設定は「実験1」の設定と同様ですが、下記部分のみが違います。

ノイズ混じりの生成途中

5セル目

・・・
save_latent_simple = True
save_latent_overstep = False
・・・

生成された画像は下記です。
(全部のステップで画像として生成されるため、全てをまとめてgif形式で表示しています)

1枚(24step)あたり、354秒ほどかかりました。

クリーン画像の生成途中

5セル目

・・・
save_latent_simple = False
save_latent_overstep = True
・・・

生成された画像は下記です。
(全部のステップで画像として生成されるため、全てをまとめてgif形式で表示しています)

まとめ

FLUX.1-devをDiffusersライブラリを利用して実行してみました。
SD3 Midiumとは比べ物にならないくらいベースモデルとしての性能が高いですね。

もしFLUX.1-devでSDXLのように高性能なFinetuinngモデルやLoRAなどが出てくると、いよいよ世代交代になってくるかなと感じてきました。
(合わせてPCに要求されるスペックも今より段違いにレベルが上がっていくんだろうな・・・)

一方で、量子化した場合にLoRAが利用できなくなる問題や、ControlNetとLoRAを組み合わせると画像が崩れる問題(これは私のせいかもですが)など、まだまだ使いやすいかと言われると課題はあるなと感じました。
(その点SDXLはどう組み合わせても、質を維持してくれるので使いやすいです)

どちらにせよFLUX.1-devはまだ生み出されてから1ヶ月ほどしか経っていない生まれたてのモデルで、これからどんどん研究が進むにつれて、最適な使い方などが見えてくるかなと思いました。

ではここまで読んでくださってありがとうございます。
また、本記事を書くにあたりさまざまな記事を参考にさせていただきました。
感謝申し上げます。
FLUX.1-dev-ControlNet-Union-Pro
画像生成AI FLUX.1 に Diffusers で LoRA を適用してみた
Memory-efficient Diffusion Transformers with Quanto and Diffusers

また、これまでも拡散モデルの理論についての記事やSDXLの記事なども書いているので、ぜひご覧ください!

これ以降は、忘備録的なコードの解説になります。興味がある方だけご覧ください。

コード解説

FluxDev_sample.ipynb

FluxDev_sample.ipynbについて解説します。

コードは下記よりご覧ください。
https://github.com/personabb/colab_AI_sample/blob/main/colab_fluxdev_sample/FluxDev_sample.ipynb

1セル目


#Google Driveのフォルダをマウント(認証入る)
from google.colab import drive
drive.mount('/content/drive')

マイドライブのマウントを行っています。
認証が入り、一定時間認証しないとエラーになってしまうので、最初に持ってきて実行と認証をまとめて実施します。

2セル目


#FLUX で必要なモジュールのインストール
%rm -r /content/diffusers
%cd /content/
!git clone https://github.com/huggingface/diffusers.git
!pip install -U optimum-quanto peft tensorflow-metadata transformers scikit-learn ftfy accelerate invisible_watermark safetensors controlnet-aux mediapipe timm

必要なモジュールをインストールしています。
実験日時点で、pipからインストール可能なDiffusersモジュールの中にFLUX.1を実行できるコードが入っていなかったため、直接GithubからDiffusersリポジトリをクローンして利用しています。
日が経つにつれて、通常のpipでインストールが可能になると思います。

3セル目

from huggingface_hub import login
from google.colab import userdata
HF_LOGIN = userdata.get('HF_LOGIN')
login(HF_LOGIN)

# カレントディレクトリを本ファイルが存在するディレクトリに変更する。
import glob
import os
pwd = os.path.dirname(glob.glob('/content/drive/MyDrive/colabzenn/colab_fluxdev_sample/FluxDev_sample.ipynb', recursive=True)[0])
print(pwd)

%cd $pwd
!pwd

import sys
sys.path.append("/content/diffusers/src")

HuggingFaceのログイン用の秘密鍵の取得と、カレントディレクトリの設定と、DiffusersモジュールのPATHの追加を実施しています。
Diffusersモジュールが通常のpipでインストールできるようになりましたら、下記部分は不要になります

import sys
sys.path.append("/content/diffusers/src")

4セル目

#モジュールをimportする
from module.module_flux import FLUX
import time

モジュールをinportしています。
FLUX.1の実装コードはmodule/module_flux.pyで記述しております。
そちらも後述します

5セル目

#モデルの設定を行う。

config_text = """
[FLUX]
device = auto
n_steps=24
seed=42

;from_single_file = False
;base_model_path = black-forest-labs/FLUX.1-schnell
base_model_path = black-forest-labs/FLUX.1-dev

use_controlnet = False
controlnet_path = Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro

control_mode0 = canny
;control_mode1 = tile
;control_mode2 = depth
;control_mode3 = blur
;control_mode4 = pose
;control_mode5 = gray
;control_mode6 = lq

;lora_weight_repo0 = alfredplpl/flux.1-dev-modern-anime-lora
;lora_weight_path0 = modern-anime-lora-2.safetensors
;lora_weight_repo1 = aleksa-codes/flux-ghibsky-illustration
;lora_weight_path1 = lora.safetensors
;lora_scale0 = 1.0
;lora_scale1 = 1.0
;trigger_word0 = "modern anime style"
;trigger_word1 = "GHIBSKY style painting"

select_solver = FMEuler

use_karras_sigmas = True
scheduler_algorithm_type = dpmsolver++
solver_order = 2

cfg_scale = 3.5
width = 832
height = 1216
output_type = pil
aesthetic_score = 6
negative_aesthetic_score = 2.5

save_latent_simple = True
save_latent_overstep = False


"""

with open("configs/config.ini", "w", encoding="utf-8") as f:
  f.write(config_text)

実験の設定などを記述する部分です。
直接configs/config.iniを書き換えても、このセルを実行したら上書きされてしまうので注意してください。

;をつけるとコメントアウトできるので、今回の実験では不要だが、残しておきたい実験設定などはコメントアウトして残しておくことをお勧めします。

6セル目

#読み上げるプロンプトを設定する。

main_prompt = """
modern anime style, a cute Japanese idle singer girl, long Yellowish-white hair with red small ribbon,red eyes,small red hat ,She is singing ,holding a microphone, wearing white idle costume, standing on the stage ,This is a flashy concert with 3D holograms and laser effects
"""

#コントロールネットに入力する参照画像は0から6までの数字を想定(0.png,6.webpなど。数字はcontrolnetのモードと一致させてください)
#参照)コントロールネットのモード canny (0), tile (1), depth (2), blur (3), pose (4), gray (5), low quality (6).
input_refer_image_folder = "./inputs/refer"
#コントロールネットに入力する変換後の参照画像の格納フォルダ。命名規則は上記と同様
output_refer_image_folder = "./inputs/refer_prepared"

生成するプロンプトと、ControlNetで利用する参照画像の格納場所を記述してます。

7セル目

#読み上げるプロンプトを設定する。

flux = FLUX()
#指定しているControlNetで利用できる"./inputs/refer_prepared/0.png"などがすでにある場合は、下記はコメントアウトしても良い
flux.prepare_multi_referimage(input_refer_image_folder = input_refer_image_folder,output_refer_image_folder = output_refer_image_folder, low_threshold = 100, high_threshold = 200, noise_level=25, blur_radius=5)

画像生成AIのFLUX.1のクラスインスタンスを立ち上げています。
立ち上げ時に、モデルのダウンロードと読み込み量子化まで実施されます。

さらに、flux.prepare_multi_referimageメソッドにて、ControlNetで利用する参照画像を適切な形式に変換します。
例えば、Depth用に格納した参照画像を、深度マップに変換してoutput_refer_image_folderに格納します。
すでに、コントロールネットに入力する上で適切な形式に変換された画像が、適切なファイル名でoutput_refer_image_folderに存在する場合は、このメソッドは実行する必要はありません。

8セル目

for i in range(5):
      start = time.time()
      image = flux.generate_image(main_prompt,image_path = output_refer_image_folder, controlnet_conditioning_scale = [0.8])
      print("generate image time: ", time.time()-start)
      image.save("./outputs/FLUX_result_{}.png".format(i))

ここで指定したseed値から連続で5枚分の画像を生成し、生成にかかった時間も計測します。
生成された画像はoutputsフォルダに格納されます。

module/module_flux.py

module/module_flux.pyについて解説します。

コードは下記よりご覧ください
https://github.com/personabb/colab_AI_sample/blob/main/colab_fluxdev_sample/module/module_flux.py

設定ファイルの中身を取得

class FLUXconfig:
    def __init__(self, config_ini_path = './configs/config.ini'):
        # iniファイルの読み込み
        self.config_ini = configparser.ConfigParser()

        # 指定したiniファイルが存在しない場合、エラー発生
        if not os.path.exists(config_ini_path):
            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), config_ini_path)

        self.config_ini.read(config_ini_path, encoding='utf-8')
        FLUX_items = self.config_ini.items('FLUX')
        self.FLUX_config_dict = dict(FLUX_items)

上記コードでは設定ファイル(./configs/config.ini)の中身を辞書型として取得しています。
設定ファイルはノートブックから書き換えることができますので、普段は気にする必要はありません。

FLUXクラスのコンストラクタ

initメソッド
class FLUX:
    def __init__(self,device = None, config_ini_path = './configs/config.ini'):

        FLUX_config = FLUXconfig(config_ini_path = config_ini_path)
        config_dict = FLUX_config.FLUX_config_dict


        if device is not None:
            self.device = device
        else:
            device = config_dict["device"]

            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            if device != "auto":
                self.device = device

        self.last_latents = None
        self.last_step = -1
        self.last_timestep = 1000

        self.n_steps = int(config_dict["n_steps"])

        self.seed = int(config_dict["seed"])
        self.generator = torch.Generator(device=self.device).manual_seed(self.seed)

        self.use_controlnet = config_dict.get("use_controlnet", "False")
        if self.use_controlnet == "False":
            self.use_controlnet = False
        else:
            self.use_controlnet = True

        self.controlnet_path = config_dict.get("controlnet_path", "None")
        if not self.use_controlnet and self.controlnet_path == "None":
            raise ValueError("controlnet_path is not set")

        self.control_modes_number_list = []
        control_mode_dict = {
            "canny":0,
            "tile":1,
            "depth":2,
            "blur":3,
            "pose":4,
            "gray":5,
            "lq":6
        }
        for i in range(7):
            if config_dict.get(f"control_mode{i}", "None") != "None":
                self.control_modes_number_list.append(control_mode_dict[config_dict[f"control_mode{i}"]])

        self.from_single_file = config_dict.get("from_single_file", "None")
        self.SINGLE_FILE_FLAG = True
        if self.from_single_file != "True":
            self.from_single_file = None
            self.SINGLE_FILE_FLAG = False


        self.base_model_path = config_dict["base_model_path"]


        self.lora_repo_list = []
        self.lora_path_list = []
        self.lora_scale_list = []
        self.lora_trigger_word_list = []
        self.lora_nums = 0
        self.LORA_FLAG = False
        for i in range(10):
            if config_dict.get(f"lora_weight_path{i}", "None") != "None":
                self.LORA_FLAG = True
                #Huggingfaceのrepoから取得する場合は、repoとpath(ファイル名)の両方を指定する
                if config_dict.get(f"lora_weight_repo{i}", "None") != "None":
                    self.lora_repo_list.append(config_dict[f"lora_weight_repo{i}"])
                #Civitaiなどからダウンロードして利用する場合は、ディレクトリ名含むpathのみで指定する
                else:
                    self.lora_repo_list.append(None)
                self.lora_path_list.append(config_dict[f"lora_weight_path{i}"])
                self.lora_nums += 1
                self.lora_scale_list.append(float(config_dict.get(f"lora_scale{i}", "1.0")))
                self.lora_trigger_word_list.append(config_dict.get(f"trigger_word{i}", "None"))


        self.select_solver = config_dict.get("select_solver", "FMEuler")

        self.use_karras_sigmas = config_dict.get("use_karras_sigmas", "True")
        if self.use_karras_sigmas == "True":
            self.use_karras_sigmas = True
        else:
            self.use_karras_sigmas = False

        self.scheduler_algorithm_type = config_dict.get("scheduler_algorithm_type", "dpmsolver++")
        self.solver_order = config_dict.get("solver_order", "None")
        if self.solver_order != "None":
            self.solver_order = int(self.solver_order)
        else:
            self.solver_order = None


        self.cfg_scale = float(config_dict["cfg_scale"])
        self.width = int(config_dict["width"])
        self.height = int(config_dict["height"])
        self.output_type = config_dict["output_type"]
        self.aesthetic_score = float(config_dict["aesthetic_score"])
        self.negative_aesthetic_score = float(config_dict["negative_aesthetic_score"])

        self.save_latent_simple = config_dict["save_latent_simple"]
        if self.save_latent_simple == "True":
            self.save_latent_simple = True
            print("use callback save_latent_simple")
        else:
            self.save_latent_simple = False

        self.save_latent_overstep = config_dict["save_latent_overstep"]
        if self.save_latent_overstep == "True":
            self.save_latent_overstep = True
            print("use callback save_latent_overstep")
        else:
            self.save_latent_overstep = False

        self.use_callback = False
        if self.save_latent_simple or self.save_latent_overstep:
            self.use_callback = True

        if (self.save_latent_simple and self.save_latent_overstep):
            raise ValueError("save_latent_simple and save_latent_overstep cannot be set at the same time")

        if self.LORA_FLAG:
            self.base = self.preprepare_model_withLoRAandQuantize()
        else:
            self.base = self.preprepare_model()

設定ファイルを読み込んで、クラスアトリビュートに設定しています。またLoRAやControlNetなど、場合によっては複数読み込まれる可能性のある設定などが、幾つ読み込む必要があるかなどを取得して、別メソットで必要はフラグなどを立てています。

最後に、preprepare_modelにて、モデルのダウンロード、読み込み、量子化などを実施しています。

FLUX.1モデルの読み込み(without LoRA)

preprepare_modelメソッド
class FLUX:
・・・
    def preprepare_model(self, controlnet_path = None):
        if controlnet_path is not None:
            self.controlnet_path = controlnet_path

        #重みのサイズが大きい順に量子化して、RAMの最大使用量を減らす
        if self.use_controlnet:
            controlnet = FluxControlNetModel.from_pretrained(self.controlnet_path, torch_dtype=torch.float16)
            controlnet = FluxMultiControlNetModel([controlnet])
            print("controlnet quantizing")
            quantize(controlnet, weights=qfloat8)
            freeze(controlnet)
            print("loaded controlnet")

            if not self.SINGLE_FILE_FLAG:
                transformer = FluxTransformer2DModel.from_pretrained(
                    self.base_model_path,
                    subfolder="transformer",
                    torch_dtype=torch.float16
                )
                print("transformer quantizing")
                quantize(transformer, weights=qfloat8)
                freeze(transformer)
                print("loaded transformer")

                text_encoder_2 = T5EncoderModel.from_pretrained(
                    self.base_model_path,
                    subfolder="text_encoder_2",
                    torch_dtype=torch.float16
                )
                print("text_encoder_2 quantizing")
                quantize(text_encoder_2, weights=qfloat8)
                freeze(text_encoder_2)
                print("loaded text_encoder_2")

                base = FluxControlNetPipeline.from_pretrained(
                    self.base_model_path, 
                    transformer=transformer, 
                    text_encoder_2=text_encoder_2,
                    controlnet=controlnet, 
                    torch_dtype=torch.float16
                    )
                print("loaded base model")
            else:
                pipe = FluxPipeline.from_pretrained(self.base_model_path, torch_dtype=torch.float16)

                print("transformer quantizing")
                quantize(pipe.transformer, weights=qfloat8)
                freeze(pipe.transformer)
                print("text_encoder_2 quantizing")
                quantize(pipe.text_encoder_2, weights=qfloat8)
                freeze(pipe.text_encoder_2)

                base = FluxControlNetPipeline(controlnet = controlnet, **pipe.components)
                print("loaded base model")

        else:
            controlnet = None
            if not self.SINGLE_FILE_FLAG:
                transformer = FluxTransformer2DModel.from_pretrained(
                    self.base_model_path,
                    subfolder="transformer",
                    torch_dtype=torch.float16
                )
                print("transformer quantizing")
                quantize(transformer, weights=qfloat8)
                freeze(transformer)
                print("loaded transformer")

                text_encoder_2 = T5EncoderModel.from_pretrained(
                    self.base_model_path,
                    subfolder="text_encoder_2",
                    torch_dtype=torch.float16
                )
                print("text_encoder_2 quantizing")
                quantize(text_encoder_2, weights=qfloat8)
                freeze(text_encoder_2)
                print("loaded text_encoder_2")

                base = FluxPipeline.from_pretrained(
                    self.base_model_path, 
                    transformer=transformer, 
                    text_encoder_2=text_encoder_2,
                    torch_dtype=torch.float16
                    )
                print("loaded base model")

            else:
                base = FluxPipeline.from_pretrained(self.base_model_path, torch_dtype=torch.float16)

                print("transformer quantizing")
                quantize(base.transformer, weights=qfloat8)
                freeze(base.transformer)
                print("text_encoder_2 quantizing")
                quantize(base.text_encoder_2, weights=qfloat8)
                freeze(base.text_encoder_2)

                print("loaded base model")


        print("cpu offloading")
        base.enable_model_cpu_offload()


        lora_adapter_name_list = []
        lora_adapter_weights_list = []
        if self.LORA_FLAG:
            for i in range(self.lora_nums):
                if self.lora_repo_list[i] is not None:
                    base.load_lora_weights(
                        pretrained_model_name_or_path_or_dict = self.lora_repo_list[i], 
                        weight_name=self.lora_path_list[i],
                        adapter_name=f"lora{i}")
                else:
                    base.load_lora_weights(self.lora_path_list[i], adapter_name=f"lora{i}")
                lora_adapter_name_list.append(f"lora{i}")
                lora_adapter_weights_list.append(self.lora_scale_list[i])
            if self.lora_nums > 1:
                base.set_adapters(lora_adapter_name_list, adapter_weights=lora_adapter_weights_list)

            print("finish lora settings")

        if self.select_solver == "DPM":
            if self.solver_order is not None:
                base.scheduler = DPMSolverMultistepScheduler.from_config(
                        base.scheduler.config,
                        use_karras_sigmas=self.use_karras_sigmas,
                        Algorithm_type =self.scheduler_algorithm_type,
                        solver_order=self.solver_order,
                        )

            else:
                base.scheduler = DPMSolverMultistepScheduler.from_config(
                        base.scheduler.config,
                        use_karras_sigmas=self.use_karras_sigmas,
                        Algorithm_type =self.scheduler_algorithm_type,
                        )
        elif self.select_solver == "LCM":
            base.scheduler = LCMScheduler.from_config(base.scheduler.config)
        elif self.select_solver == "Eulera":
            base.scheduler = EulerAncestralDiscreteScheduler.from_config(base.scheduler.config)
        elif self.select_solver == "FMEuler":
            base.scheduler = FlowMatchEulerDiscreteScheduler.from_config(base.scheduler.config)
        else:
            raise ValueError("select_solver is only 'DPM' or 'LCM' or 'Eulera' or 'FMEuler'.")

        return base

ControlNetを利用するか、LoRAを利用するか、もしくは単一のsafetensorファイルからモデルを読み込むのか、どのスケジューラを利用するのかなど、設定ファイルに則って上から順番にモデルを構築していきます。

最低限の設定方法が知りたい場合は、公式実装をご覧ください。
公式のFLUX.1-devの利用方法
公式のFLUX.1-dev+ControlNetの利用方法

下記の部分にてControlNetを読み込んでいます。
加えて読み込んだControlNetをPipelineに組み込む前に量子化を行い、必要なRAM、VRAM量を減らしています

controlnet = FluxControlNetModel.from_pretrained(self.controlnet_path, torch_dtype=torch.float16)
controlnet = FluxMultiControlNetModel([controlnet])
print("controlnet quantizing")
quantize(controlnet, weights=qfloat8)
freeze(controlnet)
print("loaded controlnet")

続いて、下記の部分でFlux.1のモデルを構築していきます。
使用するVRAMやRAMの量を減らすため、なるべく思いモデルから一つ一つ読み込んで、量子化を行なっています。

transformer = FluxTransformer2DModel.from_pretrained(
    self.base_model_path,
    subfolder="transformer",
    torch_dtype=torch.float16
)
print("transformer quantizing")
quantize(transformer, weights=qfloat8)
freeze(transformer)
print("loaded transformer")

text_encoder_2 = T5EncoderModel.from_pretrained(
    self.base_model_path,
    subfolder="text_encoder_2",
    torch_dtype=torch.float16
)
print("text_encoder_2 quantizing")
quantize(text_encoder_2, weights=qfloat8)
freeze(text_encoder_2)
print("loaded text_encoder_2")

base = FluxControlNetPipeline.from_pretrained(
    self.base_model_path, 
    transformer=transformer, 
    text_encoder_2=text_encoder_2,
    controlnet=controlnet, 
    torch_dtype=torch.float16
    )
print("loaded base model")

これは、単一のsafetensorファイルのモデル重みを利用する場合も同様にモデルを読み込んで量子化を行なっています。

続いて、下記部分でLoRAの設定を行っています。
下記ではself.LORA_FLAGに則って、実行可否が決まります。
上記のフラグは設定ファイルにLoRAの設定があれば、個数にかかわらずTrueになります。

下記では、設定ファイルに記載したLoRAのパスに従い、一つ一つ指定されたscaleに応じてLoRAを読み込み、Pipelineに組み込んでいきます。

lora_adapter_name_list = []
        lora_adapter_weights_list = []
        if self.LORA_FLAG:
            for i in range(self.lora_nums):
                if self.lora_repo_list[i] is not None:
                    base.load_lora_weights(
                        pretrained_model_name_or_path_or_dict = self.lora_repo_list[i], 
                        weight_name=self.lora_path_list[i],
                        adapter_name=f"lora{i}")
                else:
                    base.load_lora_weights(self.lora_path_list[i], adapter_name=f"lora{i}")
                lora_adapter_name_list.append(f"lora{i}")
                lora_adapter_weights_list.append(self.lora_scale_list[i])
            if self.lora_nums > 1:
                base.set_adapters(lora_adapter_name_list, adapter_weights=lora_adapter_weights_list)

            print("finish lora settings")

最後に、下記の通り、スケジューラを設定して、モデル全体をクラスアトリビュートに返却します


if self.select_solver == "DPM":
    if self.solver_order is not None:
        base.scheduler = DPMSolverMultistepScheduler.from_config(
                base.scheduler.config,
                use_karras_sigmas=self.use_karras_sigmas,
                Algorithm_type =self.scheduler_algorithm_type,
                solver_order=self.solver_order,
                )

    else:
        base.scheduler = DPMSolverMultistepScheduler.from_config(
                base.scheduler.config,
                use_karras_sigmas=self.use_karras_sigmas,
                Algorithm_type =self.scheduler_algorithm_type,
                )
elif self.select_solver == "LCM":
    base.scheduler = LCMScheduler.from_config(base.scheduler.config)
elif self.select_solver == "Eulera":
    base.scheduler = EulerAncestralDiscreteScheduler.from_config(base.scheduler.config)
elif self.select_solver == "FMEuler":
    base.scheduler = FlowMatchEulerDiscreteScheduler.from_config(base.scheduler.config)
else:
    raise ValueError("select_solver is only 'DPM' or 'LCM' or 'Eulera' or 'FMEuler'.")

return base

FLUX.1モデルの読み込み(with LoRA)

preprepare_model_withLoRAandQuantizeメソッド
class FLUX:
・・・
    def preprepare_model_withLoRAandQuantize(self, controlnet_path = None):
        if controlnet_path is not None:
            self.controlnet_path = controlnet_path

        #重みのサイズが大きい順に量子化して、RAMの最大使用量を減らす
        if not self.SINGLE_FILE_FLAG:
            transformer = FluxTransformer2DModel.from_pretrained(
                self.base_model_path,
                subfolder="transformer",
                torch_dtype=torch.float16
            )

            text_encoder_2 = T5EncoderModel.from_pretrained(
                self.base_model_path,
                subfolder="text_encoder_2",
                torch_dtype=torch.float16
            )

            base = FluxPipeline.from_pretrained(
                self.base_model_path, 
                transformer=transformer, 
                text_encoder_2=text_encoder_2,
                torch_dtype=torch.float16
                )

            print("loaded base model")
        else:
            base = FluxPipeline.from_pretrained(self.base_model_path, torch_dtype=torch.float16)

            print("loaded base model")


        if self.use_controlnet:
            controlnet = FluxControlNetModel.from_pretrained(self.controlnet_path, torch_dtype=torch.float16)
            controlnet = FluxMultiControlNetModel([controlnet])
            print("loaded controlnet")
            base = FluxControlNetPipeline(controlnet = controlnet, **base.components)

        lora_adapter_name_list = []
        lora_adapter_weights_list = []
        if self.LORA_FLAG:
            for i in range(self.lora_nums):
                if self.lora_repo_list[i] is not None:
                    base.load_lora_weights(
                        pretrained_model_name_or_path_or_dict = self.lora_repo_list[i], 
                        weight_name=self.lora_path_list[i],
                        adapter_name=f"lora{i}")
                else:
                    base.load_lora_weights(self.lora_path_list[i], adapter_name=f"lora{i}")
                lora_adapter_name_list.append(f"lora{i}")
                lora_adapter_weights_list.append(self.lora_scale_list[i])
            if self.lora_nums > 1:
                base.set_adapters(lora_adapter_name_list, adapter_weights=lora_adapter_weights_list)

            print("finish lora settings",f"{self.lora_repo_list = },{lora_adapter_weights_list = }")

        print("cpu offloading")
        base.enable_model_cpu_offload()

        if self.select_solver == "DPM":
            if self.solver_order is not None:
                base.scheduler = DPMSolverMultistepScheduler.from_config(
                        base.scheduler.config,
                        use_karras_sigmas=self.use_karras_sigmas,
                        Algorithm_type =self.scheduler_algorithm_type,
                        solver_order=self.solver_order,
                        )

            else:
                base.scheduler = DPMSolverMultistepScheduler.from_config(
                        base.scheduler.config,
                        use_karras_sigmas=self.use_karras_sigmas,
                        Algorithm_type =self.scheduler_algorithm_type,
                        )
        elif self.select_solver == "LCM":
            base.scheduler = LCMScheduler.from_config(base.scheduler.config)
        elif self.select_solver == "Eulera":
            base.scheduler = EulerAncestralDiscreteScheduler.from_config(base.scheduler.config)
        elif self.select_solver == "FMEuler":
            base.scheduler = FlowMatchEulerDiscreteScheduler.from_config(base.scheduler.config)
        else:
            raise ValueError("select_solver is only 'DPM' or 'LCM' or 'Eulera' or 'FMEuler'.")

        return base

上述した量子化とLoRAの相性が悪い問題が発生している間は、A100 GPUを利用することを前提として、量子化を行わないモデルを用意するメソッドを追加で用意しています。
量子化を実施していないこと以外は、preprepare_modelメソッドと大きく変わりません。

(実装の順番などは入れ替わっていますが、これはなんとかして量子化とLoRAを両立できないか足掻いた名残です・・・)

設定ファイルにてLoRAを利用している場合は、こちらのメソッドがpreprepare_modelメソッドの代わりに実行されます。

ControlNetの参照画像作成

prepare_multi_referimageメソッド、prepare_referimageメソッド
class FLUX:
・・・
    def prepare_multi_referimage(self,input_refer_image_folder,output_refer_image_folder, low_threshold = 100, high_threshold = 200, noise_level=25, blur_radius=5):
        #input_refer_image_folderの中にある画像のpathを全て取得する
        def get_image_paths_sorted_by_filename_number(input_refer_image_folder, output_refer_image_folder):
            # 対象の画像拡張子のリスト
            image_extensions = ['.png', '.jpg', '.jpeg', '.bmp', '.gif', '.webp']

            # フォルダ内のファイルを取得し、画像のみフィルタリング
            image_paths = [
                os.path.join(input_refer_image_folder, f)
                for f in os.listdir(input_refer_image_folder)
                if os.path.isfile(os.path.join(input_refer_image_folder, f)) and os.path.splitext(f)[1].lower() in image_extensions
            ]

            # ファイル名の数字部分でソート(小さい順)
            image_paths.sort(key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))

            # 出力用のフォルダパスに変更したリストを作成
            output_image_paths = [
                os.path.join(output_refer_image_folder, os.path.basename(path))
                for path in image_paths
            ]

            return image_paths, output_image_paths

         #output_refer_image_folderが存在しない場合、作成する
        if not os.path.exists(output_refer_image_folder):
            os.makedirs(output_refer_image_folder)

        input_paths, output_paths = get_image_paths_sorted_by_filename_number(input_refer_image_folder, output_refer_image_folder)

        for input_refer_image_path, output_refer_image_path in zip(input_paths, output_paths):
            mode = int(os.path.splitext(os.path.basename(input_refer_image_path))[0])
            self.prepare_referimage(input_refer_image_path,output_refer_image_path, low_threshold = low_threshold, high_threshold = high_threshold, noise_level=noise_level, blur_radius=blur_radius, mode = mode)




    def prepare_referimage(self,input_refer_image_path,output_refer_image_path, low_threshold = 100, high_threshold = 200, noise_level=25, blur_radius=5, mode = 0):
        #mode = canny (0), tile (1), depth (2), blur (3), pose (4), gray (5), low quality (6).

        def prepare_openpose(input_refer_image_path,output_refer_image_path, mode):

            # 初期画像の準備
            init_image = load_image(input_refer_image_path)
            init_image = init_image.resize((self.width, self.height))

            processor = Processor(mode)
            processed_image = processor(init_image, to_pil=True)

            processed_image.save(output_refer_image_path)


        def prepare_canny(input_refer_image_path,output_refer_image_path, low_threshold = 100, high_threshold = 200):
            init_image = load_image(input_refer_image_path)
            init_image = init_image.resize((self.width, self.height))

            # コントロールイメージを作成するメソッド
            def make_canny_condition(image, low_threshold = 100, high_threshold = 200):
                image = np.array(image)
                image = cv2.Canny(image, low_threshold, high_threshold)
                image = image[:, :, None]
                image = np.concatenate([image, image, image], axis=2)
                return Image.fromarray(image)

            control_image = make_canny_condition(init_image, low_threshold, high_threshold)
            control_image.save(output_refer_image_path)

        def prepare_depthmap(input_refer_image_path,output_refer_image_path):

            # 初期画像の準備
            init_image = load_image(input_refer_image_path)
            init_image = init_image.resize((self.width, self.height))
            processor = Processor("depth_midas")
            depth_image = processor(init_image, to_pil=True)
            depth_image.save(output_refer_image_path)

        def prepare_blur(input_refer_image_path, output_refer_image_path, blur_radius=5):
            init_image = load_image(input_refer_image_path)
            init_image = init_image.resize((self.width, self.height))

            # Blur画像を作成するメソッド
            def make_blur_condition(image, blur_radius=5):
                return image.filter(ImageFilter.GaussianBlur(blur_radius))

            blurred_image = make_blur_condition(init_image, blur_radius)
            blurred_image.save(output_refer_image_path)

        def prepare_grayscale(input_refer_image_path, output_refer_image_path):
            init_image = load_image(input_refer_image_path)
            init_image = init_image.resize((self.width, self.height))

            # グレースケール画像を作成するメソッド
            def make_grayscale_condition(image):
                return image.convert("L")

            grayscale_image = make_grayscale_condition(init_image)
            grayscale_image.save(output_refer_image_path)

        def prepare_noise(input_refer_image_path, output_refer_image_path, noise_level=25):
            init_image = load_image(input_refer_image_path)
            init_image = init_image.resize((self.width, self.height))

            # ノイズ付与画像を作成するメソッド
            def make_noise_condition(image, noise_level=25):
                image_array = np.array(image)
                noise = np.random.normal(0, noise_level, image_array.shape)
                noisy_image = image_array + noise
                noisy_image = np.clip(noisy_image, 0, 255).astype(np.uint8)
                return Image.fromarray(noisy_image)

            noisy_image = make_noise_condition(init_image, noise_level)
            noisy_image.save(output_refer_image_path)



        if mode == 0:
            prepare_canny(input_refer_image_path,output_refer_image_path, low_threshold = low_threshold, high_threshold = high_threshold)
        elif mode == 1:
            init_image = load_image(input_refer_image_path)
            init_image.save(output_refer_image_path)
        elif mode == 2:
            prepare_depthmap(input_refer_image_path,output_refer_image_path)
        elif mode == 3:
            prepare_blur(input_refer_image_path, output_refer_image_path, blur_radius=5)
        elif mode == 4:
            prepare_openpose(input_refer_image_path,output_refer_image_path, mode = "openpose_full")
        elif mode == 5:
            prepare_grayscale(input_refer_image_path, output_refer_image_path)
        elif mode == 6:
            prepare_noise(input_refer_image_path, output_refer_image_path, noise_level=30)
        else:
            raise ValueError("control_mode is not set")

ここでは、inputs/referフォルダに格納されている画像をコントロールネットのモードごとに適切な形式に画像変換して、inputs/refer_preparedフォルダに格納するメソッドです。

適切な形式等のは下記のような画像になります。
左から順に、canny (0.png), tile (1.png), depth (2.png), blur (3.png), pose (4.png), gray (5.png), low quality (6.png)です

prepare_referimageメソッドでは、ある一枚の画像を、コントロールネットのモードごとに適切な形式に変換しています。

prepare_multi_referimageメソッドでは、inputs/referフォルダに格納されている全ての画像に対して、ファイル名の数字からコントロールネットのモードを取得して、prepare_referimageメソッドにて、画像を変換しています。

上記のメソッドはノートブックの7セル目にて呼ばれますが、すでにControlNetに入力する上で適切な画像が適切なファイル名で格納されている場合は、このメソッドは実行する必要はありません。

FLUX.1による画像生成

generate_imageメソッド
class FLUX:
・・・
    def generate_image(self, prompt, neg_prompt = None, image_path = None, seed = None, controlnet_conditioning_scale = [1.0]):
        def decode_tensors(pipe, step, timestep, callback_kwargs):
            if self.save_latent_simple:
                callback_kwargs = decode_tensors_simple(pipe, step, timestep, callback_kwargs)
            elif self.save_latent_overstep:
                callback_kwargs = decode_tensors_residual(pipe, step, timestep, callback_kwargs)
            else:
                raise ValueError("save_latent_simple or save_latent_overstep must be set")
            return callback_kwargs


        def decode_tensors_simple(pipe, step, timestep, callback_kwargs):
            latents = callback_kwargs["latents"]
            imege = None
            if self.save_latent_simple:
                image = latents_to_rgb_vae(latents,pipe)
            else:
                raise ValueError("save_latent_simple is not set")
            gettime = time.time()
            formatted_time_human_readable = time.strftime("%Y%m%d_%H%M%S", time.localtime(gettime))
            image.save(f"./outputs/latent_{formatted_time_human_readable}_{step}_{timestep}.png")

            return callback_kwargs

        def decode_tensors_residual(pipe, step, timestep, callback_kwargs):
            latents = callback_kwargs["latents"]
            if step > 0:
                residual = latents - self.last_latents
                goal = self.last_latents + residual * ((self.last_timestep) / (self.last_timestep - timestep))
                #print( ((self.last_timestep) / (self.last_timestep - timestep)))
            else:
                goal = latents

            if self.save_latent_overstep:
                image = latents_to_rgb_vae(goal,pipe)
            else:
                raise ValueError("save_latent_overstep is not set")

            gettime = time.time()
            formatted_time_human_readable = time.strftime("%Y%m%d_%H%M%S", time.localtime(gettime))
            image.save(f"./outputs/latent_{formatted_time_human_readable}_{step}_{timestep}.png")

            self.last_latents = latents
            self.last_step = step
            self.last_timestep = timestep

            if timestep == 0:
                self.last_latents = None
                self.last_step = -1
                self.last_timestep = 100

            return callback_kwargs

        def latents_to_rgb_vae(latents,pipe):

            latents = pipe._unpack_latents(latents, self.height, self.width, pipe.vae_scale_factor)
            latents = (latents / pipe.vae.config.scaling_factor) + pipe.vae.config.shift_factor

            image = pipe.vae.decode(latents, return_dict=False)[0]
            image = pipe.image_processor.postprocess(image, output_type=self.output_type)

            return FluxPipelineOutput(images=image).images[0]


        def load_image_path(image_path):
            # もし image_path が str ならファイルパスとして画像を読み込む
            if isinstance(image_path, str):
                image = load_image(image_path)
                print("Image loaded from file path.")
            # もし image_path が PIL イメージならそのまま使用
            elif isinstance(image_path, Image.Image):
                image = image_path
                print("PIL Image object provided.")
            # もし image_path が Torch テンソルならそのまま使用
            elif isinstance(image_path, torch.Tensor):
                image = image_path.unsqueeze(0)
                image = image.permute(0, 3, 1, 2)
                image = image/255.0
                print("Torch Tensor object provided.")
            else:
                raise TypeError("Unsupported type. Provide a file path, PIL Image, or Torch Tensor.")

            return image

        def find_file_with_extension(image_path, i):
            # パターンに一致するファイルを検索
            file_pattern = f"{image_path}/{i}.*"
            matching_files = glob.glob(file_pattern)

            # マッチするファイルが存在する場合、そのファイルのパスを返す
            if matching_files:
                # 例: ./image_path/0.png のような完全なファイルパスが取得される
                return matching_files[0]
            else:
                # マッチするファイルがない場合
                raise FileNotFoundError(f"No file found matching pattern: {file_pattern}")
                return None

        if seed is not None:
            self.generator = torch.Generator(device=self.device).manual_seed(seed)

        control_image_list = []
        if self.use_controlnet:
            print("use controlnet mode: ",self.control_modes_number_list)

            for i in self.control_modes_number_list:
                control_image_name = load_image_path(find_file_with_extension(image_path, i))
                control_image = load_image_path(control_image_name)
                control_image_list.append(control_image)


        lora_weight_average = 0
        if self.LORA_FLAG:
            print("use LoRA")
            lora_weight_average = sum(self.lora_scale_list) / len(self.lora_scale_list)
            for word in self.lora_trigger_word_list:
                if (word is not None) and (word != "None"):
                    prompt = prompt + ", " + word

        image = None
        if self.use_callback:
            if self.LORA_FLAG:
                if self.use_controlnet:
                    image = self.base(
                        prompt=prompt,
                        control_image=control_image_list,
                        control_mode=self.control_modes_number_list,
                        guidance_scale=self.cfg_scale,
                        controlnet_conditioning_scale=controlnet_conditioning_scale,
                        num_inference_steps=self.n_steps,
                        output_type=self.output_type,
                        width = self.width,
                        height = self.height,
                        generator=self.generator,
                        callback_on_step_end=decode_tensors,
                        callback_on_step_end_tensor_inputs=["latents"],
                        joint_attention_kwargs={"scale": lora_weight_average},
                        ).images[0]
                else:
                    image = self.base(
                        prompt=prompt,
                        guidance_scale=self.cfg_scale,
                        num_inference_steps=self.n_steps,
                        output_type=self.output_type,
                        width = self.width,
                        height = self.height,
                        generator=self.generator,
                        callback_on_step_end=decode_tensors,
                        callback_on_step_end_tensor_inputs=["latents"],
                        joint_attention_kwargs={"scale": lora_weight_average},
                        ).images[0]
            #LORAを利用しない場合
            else:
                if self.use_controlnet:
                    image = self.base(
                        prompt=prompt,
                        control_image=control_image_list,
                        control_mode=self.control_modes_number_list,
                        guidance_scale=self.cfg_scale,
                        controlnet_conditioning_scale=controlnet_conditioning_scale,
                        num_inference_steps=self.n_steps,
                        output_type=self.output_type,
                        width = self.width,
                        height = self.height,
                        callback_on_step_end=decode_tensors,
                        callback_on_step_end_tensor_inputs=["latents"],
                        generator=self.generator
                        ).images[0]
                else:
                  image = self.base(
                        prompt=prompt,
                        guidance_scale=self.cfg_scale,
                        num_inference_steps=self.n_steps,
                        output_type=self.output_type,
                        width = self.width,
                        height = self.height,
                        callback_on_step_end=decode_tensors,
                        callback_on_step_end_tensor_inputs=["latents"],
                        generator=self.generator
                        ).images[0]
        #latentを保存しない場合
        else:
            if self.LORA_FLAG:
                if self.use_controlnet:
                    image = self.base(
                        prompt=prompt,
                        control_image=control_image_list,
                        control_mode=self.control_modes_number_list,
                        guidance_scale=self.cfg_scale,
                        controlnet_conditioning_scale=controlnet_conditioning_scale,
                        num_inference_steps=self.n_steps,
                        output_type=self.output_type,
                        width = self.width,
                        height = self.height,
                        generator=self.generator,
                        joint_attention_kwargs={"scale": lora_weight_average},
                        ).images[0]
                else:
                    image = self.base(
                        prompt=prompt,
                        guidance_scale=self.cfg_scale,
                        num_inference_steps=self.n_steps,
                        output_type=self.output_type,
                        width = self.width,
                        height = self.height,
                        generator=self.generator,
                        joint_attention_kwargs={"scale": lora_weight_average},
                        ).images[0]

            # LORAを利用しない場合
            else:
                if self.use_controlnet:
                    image = self.base(
                        prompt=prompt,
                        control_image=control_image_list,
                        control_mode=self.control_modes_number_list,
                        guidance_scale=self.cfg_scale,
                        controlnet_conditioning_scale=controlnet_conditioning_scale,
                        num_inference_steps=self.n_steps,
                        output_type=self.output_type,
                        width = self.width,
                        height = self.height,
                        generator=self.generator
                        ).images[0]
                else:
                    image = self.base(
                        prompt=prompt,
                        guidance_scale=self.cfg_scale,
                        num_inference_steps=self.n_steps,
                        output_type=self.output_type,
                        width = self.width,
                        height = self.height,
                        generator=self.generator
                        ).images[0]

        return image

ここでは、FLUXによる画像生成を行っています。

LoRAやControlNetを導入し、さらに生成途中の画像を出力する場合は、下記が実行されます。
下記では、設定ファイルにて記載されている設定とプロンプト、そしてLoRAの設定と、ControlNetを利用するモードのリストと、参照画像のリスト、そして、生成途中の画像を出力するためのコールバック関数を設定しています。

if self.use_callback:
    if self.LORA_FLAG:
        if self.use_controlnet:
            image = self.base(
                prompt=prompt,
                control_image=control_image_list,
                control_mode=self.control_modes_number_list,
                guidance_scale=self.cfg_scale,
                controlnet_conditioning_scale=controlnet_conditioning_scale,
                num_inference_steps=self.n_steps,
                output_type=self.output_type,
                width = self.width,
                height = self.height,
                generator=self.generator,
                callback_on_step_end=decode_tensors,
                callback_on_step_end_tensor_inputs=["latents"],
                joint_attention_kwargs={"scale": lora_weight_average},
                ).images[0]

ControlNetの参照画像は、設定ファイルの情報から取得したモードから、必要な参照画像のPATHを取得し、find_file_with_extension関数で、拡張子を含む完全なファイル名を取得します。
その後load_image_path関数により、参照画像をロードして、リストに追加していきます。

また、生成途中の画像を出力するコールバック関数は下記のように実装しています

        def decode_tensors(pipe, step, timestep, callback_kwargs):
            if self.save_latent_simple:
                callback_kwargs = decode_tensors_simple(pipe, step, timestep, callback_kwargs)
            elif self.save_latent_overstep:
                callback_kwargs = decode_tensors_residual(pipe, step, timestep, callback_kwargs)
            else:
                raise ValueError("save_latent_simple or save_latent_overstep must be set")
            return callback_kwargs


        def decode_tensors_simple(pipe, step, timestep, callback_kwargs):
            latents = callback_kwargs["latents"]
            imege = None
            if self.save_latent_simple:
                image = latents_to_rgb_vae(latents,pipe)
            else:
                raise ValueError("save_latent_simple is not set")
            gettime = time.time()
            formatted_time_human_readable = time.strftime("%Y%m%d_%H%M%S", time.localtime(gettime))
            image.save(f"./outputs/latent_{formatted_time_human_readable}_{step}_{timestep}.png")

            return callback_kwargs

        def decode_tensors_residual(pipe, step, timestep, callback_kwargs):
            latents = callback_kwargs["latents"]
            if step > 0:
                residual = latents - self.last_latents
                goal = self.last_latents + residual * ((self.last_timestep) / (self.last_timestep - timestep))
                #print( ((self.last_timestep) / (self.last_timestep - timestep)))
            else:
                goal = latents

            if self.save_latent_overstep:
                image = latents_to_rgb_vae(goal,pipe)
            else:
                raise ValueError("save_latent_overstep is not set")

            gettime = time.time()
            formatted_time_human_readable = time.strftime("%Y%m%d_%H%M%S", time.localtime(gettime))
            image.save(f"./outputs/latent_{formatted_time_human_readable}_{step}_{timestep}.png")

            self.last_latents = latents
            self.last_step = step
            self.last_timestep = timestep

            if timestep == 0:
                self.last_latents = None
                self.last_step = -1
                self.last_timestep = 100

            return callback_kwargs

        def latents_to_rgb_vae(latents,pipe):

            latents = pipe._unpack_latents(latents, self.height, self.width, pipe.vae_scale_factor)
            latents = (latents / pipe.vae.config.scaling_factor) + pipe.vae.config.shift_factor

            image = pipe.vae.decode(latents, return_dict=False)[0]
            image = pipe.image_processor.postprocess(image, output_type=self.output_type)

            return FluxPipelineOutput(images=image).images[0]

基本的にやっていることは、Pipelineに実装されているコールバック機能により、1StepごとにFLUXの潜在表現を取得し、それをFLUXの持つVAEにより復元しています。

save_latent_simpleがTRUEの場合は、潜在表現をそのままVAEに入れて復元しているため、後半のstepまでノイズ混じりの状態が続きます。

一方で、save_latent_overstepがTRUEの場合は、取得した潜在表現に対して少し操作を加えています。

その操作がdecode_tensors_residual関数の下記の部分です

if step > 0:
    residual = latents - self.last_latents
    goal = self.last_latents + residual * ((self.last_timestep) / (self.last_timestep - timestep))

step数が1以上(つまり前回の潜在表現を取得できる状態)において、前回と今回の潜在表現の差分residualを取得しています。
この差分は、本stepにおける潜在表現の変化量(ベクトル)を表しています。

そこで、次の行では、この変化量(ベクトル)のまま最終stepまで更新し続けた場合の潜在表現goalを計算により算出しています。
そして得られた潜在表現goalをVAEにより復元することで、綺麗な画像が各stepで復元されていたというわけです。

この結果から、FLUXにおいて、潜在表現に対する拡散モデルの微分方程式はほとんど直線的な移動によって逆拡散処理を行うことができるということが理解できました。

まとめ(2回目)

ここまで、実装コードの解説をしていきました。
Quantoによる量子化込みでも、LoRAが使えるようになってくれー!

ただ、BitsAndBytesによる量子化をしていた時代から比べると、Quantoによる量子化ができるようになって、わざわざTextEncoderとDiTを分けて実行しなくて良くなったから、そういう点では非常に嬉しいです。

それでは、ここまで長い間お付き合いいただきありがとうございましたー!!!

Discussion