プロンプトは考えたくないけど画像生成がしたい!
まとめ
- Danbooru タグをいい感じに生成・補完してくれる LLM を作ることができた
- データセットやトークナイザーの作成、事前学習、SFT、推論の最適化まで一通り体験できた
作成したもの:
- モデル (SFT):
- モデル (事前学習):
- デモ: 🤗 Space
はじめに
最近いい感じの画像生成 AI が流行ってきていて、プロンプトを指定するといい感じの画像が生成できるようになってきました。
しかし、いい感じの画像生成モデルを使っていてもプロンプトがしっかりしていないといい感じになってくれません!困りました!
DALL-E 3 では画像生成をお願いすると、ChatGPT が指定された情報にさらに詳細な情報を追加し、長いプロンプトにしてから画像を生成するようになっています。
少ない情報から情景を想像して詳しい情報を追加している (実際の生成では英語が使われる)
DALL-E 3 の論文によると、長いキャプションで学習して長いプロンプトで生成するのがめっちゃいい感じらしいです。詳細は次の記事が参考になります。
なので画像生成するときは基本的に長いプロンプトにしたいというモチベーションがあります。
前提
Danbooru タグについて
(知っている人は飛ばしてOK)
上で紹介している DALL-E 3 は自然言語で画像を生成しますが、イラスト系の画像生成モデルではもっぱら Danbooru タグと呼ばれるタグを使用します。これは学習する画像にかなり詳細につけられているタグで、このタグを利用して生成したい画像の要素を指定することになります。
画像生成をする際に指定する Danbooru タグはこのようになります:
1girl, solo, black hair, looking at viewer, upper body
これが
1人の少女、1人のみ、黒髪、カメラ目線、上半身
という意味になります。
上ようにカンマ区切りで使われるのが一般的なのですが。順序に特に意味がなかったり、「誰が何をしているのか」などの情報がわからないという問題点もあります。(今回は特に問題にならないです。)
Danbooru タグにはカテゴリが存在し、今回は次のように呼ぶことにします:
-
レーティングタグ (
rating
): 「一般 (rating:general
)」、「微妙 (rating:sensitive
)」、「際どい (rating:questionable
)」、「露骨 (rating:explicit
)」の4つに分けられる -
版権タグ (
copyright
): 「原神 (genshin impact
)」、「葬送のフリーレン (sousou no frieren
)」など -
キャラクタータグ (
character
): 「胡桃 (hu tao (genshin_impact)
)」、「フリーレン (frieren
)」など -
一般タグ (
general
): 「少女1人 (1girl
)」、「カメラ目線 (looking at viewer
)」など- レーティングの
rating:general
とは無関係
- レーティングの
- メタタグ: 高解像度 (
highres
) などの画像のメタ的な情報を示すタグ。今回は出てこない。
既存の解決策
midjourney, Stable Diffusion 等の画像生成プロンプトを自動で生成したい、手動で入力したくないといったモチベーションは古くから存在しており、いろいろな手法が取られています。まずはそのうちのいくつかを紹介したいと思います。
ランダムに当てはめる
AUTOMATIC1111 の Stable Diffusion Web UI の拡張機能である、sd-dynamic-prompts を使うことでプロンプトにランダム性を与えることができるようになります。
{1girl|1boy}, {black hair|blue hair}, {short hair|medium hair|long hair}, {portrait|upepr body|full body}
このように記述することで、{}
内の |
で区切られた単語がランダムに選ばれてプロンプトが作成されます。上の例では、ランダムで
1girl, black hair, medium hair, full body
1boy, blue hair, short hair, upper body
1girl, blue hair, short hair, portrait
のようになります。この例では候補が非常に少なくなっていますが、同系列のタグ、「人数のタグ」や「髪色のタグ」などでまとめたタグのリスト(ワイルドカード)が存在するので、それを指定することで大量にある選択肢からタグを選ぶことができ、非常にランダムなプロンプトを作成することができます。
しかし、これは機械的にランダムに選ぶことになるため、ナンセンスなタグの組み合わせ (そのシチュエーションでその髪型でそのポーズはおかしい、みたいな) が発生しやすかったり、要素がランダムに選ばれることから 操作が難しい という問題点があります。
LLM を使う
自然言語の生成
HuggingFace で prompt
, stable diffusion prompt
などで検索するとプロンプトを生成する LLM が結構見つかります。
以下は Midjourney や Stable Diffusion のプロンプトで学習されたモデルたちで、自然言語のプロンプトを生成できる LLM です。
私は自然言語で画像生成をあんまりしないのでわからないです。ただ、昔からこういう取り組みがあったということがわかればよいかと思います。
キャプションアップサンプル
DALL-E 3で紹介された、そこそこ話しの通じる LLM にお願いしてプロンプトを長くしてもらう手法です。
上のレポでは HuggingFaceH4/zephyr-7b-alpha という 7B サイズの LLM を使っています。
しかしローカルで画像を生成する場合では、ただでさえ VRAM がカツカツなので GPU に LLM のための場所なんて存在しません。そんな場所があれば画像生成モデルを配置したいです。仮に VRAM を明け渡したとしてもこのサイズではプロンプトの生成が完了する間に画像が10枚くらい生成できる時間が経過してしまいます。
LLM を使うアプローチは良いですが、この手法ではLLMがさすがにオーバースペックすぎるという問題点があります。
Danbooru タグ生成
Danbooru タグにも対応した言語モデルもいくつか存在します。danbooru
や anime prompt
などで検索すると良いと思います。
これらは Danbooru タグで学習されているためそれっぽいものを生成することができますし、普通にいい感じになると思います。
しかし、多くのモデルで以下の問題点があります:
-
勝手に版権要素が挟まる:
- 別に版権要素を生成したくないのに、勝手に版権キャラクター名を差し込まれると迷惑
- 版権タグ (原神
genshin impact
や アークナイツarknights
) を指定してないのに、その版権世界でしか出現しないタグ (原神の「神の目」vision (genshin_impact)
や公式代替衣装official alternate costume
) が頻繁に出てくる -
commentary request
などの画像の見た目と関係のないメタタグが出てくることがある
-
必要な長さ生成できない:
- 長いプロンプトが欲しいのに、短い長さで生成が終わってしまうことがある
- 無理やり長くすると不自然な繰り返しが発生する
-
タグが不正確:
- 多くのモデルはトークナイザーが自然言語向けなので、しばしば存在しないタグや奇妙な文字列を生成することがある
というような問題点があり、(個人的に)便利に使えるものはまだありません。
LLM を作ろう!
そこで、自作しようと思います。
データセットの用意
まずはデータセットを用意します。ここに既に収集済みの Danbooru タグデータセットがあります。
2005年(danbooruサービス開始)~2023年の投稿のうち、
-
score
が 1 以上 - ファイル形式が
png,jpg,webp
のどれか
の投稿のタグなどの情報が含まれます。全部で 600 万件のデータになりますが、画像は入っていないので見た目以上に軽量なデータセットです。
前処理
上で挙げた既存の Danbooru タグ生成モデルの問題点を解決するには前処理が大事になります。
まず、事前の調査として人気なタグのリストを作成しました:
- カテゴリごとにすべてのタグの登場回数を調べた
- 登場回数が 100 回に満たないタグを不人気タグとして学習候補から除外することにする
- それ以外を人気タグリストとして保存
これを次のトークナイザーの作成に利用しました。めったに現れない不人気なタグは、それのために語彙を用意するだけ無駄だと考えたため、学習の対象としないことにしました。
トークナイザーを作る
先にトークナイザーを作ってしまいます。
今までのモデルの問題点として、不自然なタグ・文字列を生成してしまうことがあると述べました。
自然言語に特化したトークナイザーであることが原因の1つになっているのですが、今回は自然言語は取り扱わないので Danbooru タグ特化のトークナイザーを作ることで、トークナイザー側で不自然な文字列の生成を防ぐことにします。
一般的な言語モデルのトークナイザーの作成では、BPE やら Unigram やらで効率のいいトークナイザーを作りますが、今回のデータセットはカンマで区切られたタグの列挙であるため、わざわざこれらを学習するのは 逆に非効率 になります。
先ほど人気タグリストを作成したので、これをトークナイザーの語彙とすることで、1つのタグに対して必ず1つのトークン とが割り当てられるため、そのまま効率のいいトークナイズができる上に、1文字ずつの生成ができないのでどう頑張っても不自然な文字列が発生することはありません。
🤗 tokenizers を用いたトークナイザーの作成コード例:
def load_tags(path: str):
with open(path, "r", encoding="utf-8") as file:
tags_txt = file.read()
tags = tags_txt.split("\n")
return tags
# タグを読み込みます
general_general = load_tags("./popular-tags/general-general.txt")
general_sensitive = load_tags("./popular-tags/general-sensitive.txt")
general_questionable = load_tags("./popular-tags/general-questionable.txt")
general_explicit = load_tags("./popular-tags/general-explicit.txt")
character = load_tags("./popular-tags/character.txt")
copyright = load_tags("./popular-tags/copyright.txt")
# 特殊トークン
special_tokens = [
"<|bos|>",
"<|eos|>",
"<|pad|>",
"<|unknown|>",
"<rating>",
"</rating>",
"<copyright>",
"</copyright>",
"<character>",
"</character>",
"<general>",
"</general>",
]
# 特殊トークンの予約 (NAIがやっていたので真似してみた)
reserved_tokens = [f"<|reserved_{i}|>" for i in range(32)]
# まとめる
all_tags = (
special_tokens
+ reserved_tokens
+ rating_tags
+ copyright
+ character
+ general_general
+ general_sensitive
+ general_questionable
+ general_explicit
)
ここでの all_tags
の長さ(タグの合計数)は 67996 となりました。
from tokenizers import Tokenizer, AddedToken, Regex
from tokenizers.models import WordLevel
from tokenizers.pre_tokenizers import Split
from tokenizers.normalizers import Lowercase
tokenizer = Tokenizer(
# 頭から順番に番号をつける
WordLevel(vocab={tag: i for i, tag in enumerate(all_tags)}, unk_token="<|unknown|>")
)
# 大文字もいらない!
tokenizer.normalizer = Lowercase()
# カンマ区切りの文章をトークナイズできるようにする
tokenizer.pre_tokenizer = Split(
pattern=Regex(r",(?:\s)*"), behavior="removed", invert=False
)
# スペシャルトークンをスペシャルトークンとして登録
tokenizer.add_special_tokens(
[
AddedToken(
content=tag,
)
for tag in special_tokens + reserved_tokens
]
)
# パディングの設定
PAD_TOKEN = "<|pad|>"
tokenizer.enable_padding(pad_token=PAD_TOKEN)
tokenizer.padding
# 保存
tokenizer.save("tokenizer.json")
このトークナイザーでトークナイズするとこんな感じになります
tokenizer.encode(
"1girl, 2girls, aaa, long hair, very long hair, honkai: star rail, arknights, hogeeeeeeeee"
).tokens
# ['1girl',
# '2girls',
# '<|unknown|>',
# 'long hair',
# 'very long hair',
# 'honkai: star rail',
# 'arknights',
# '<|unknown|>']
正しいタグじゃないもの (aaa
や hogeeeeeeeee
) は Unknown トークン (<|unknown|>
) になりますが、それ以外はスペースが含まれていても正しくトークナイズできているのがわかります。
以下で Transformers で扱えるようにします
from transformers import PreTrainedTokenizerFast
pretrained_tokenizer = PreTrainedTokenizerFast(tokenizer_file="tokenizer.json")
# 特殊トークンの設定
pretrained_tokenizer.bos_token = "<|bos|>"
pretrained_tokenizer.eos_token = "<|eos|>"
pretrained_tokenizer.unk_token = "<|unknown|>"
# 保存したりする
pretrained_tokenizer.save_pretrained("./dart-tokenizer-20240219")
作成されたトークナイザー:
途中でいくつか HTML タグのような特殊トークンを定義していましたが、これはタグのカテゴリを明示的に分けるために導入しています。後ろで詳しく説明します。
ここで作成したトークナイザーは、トークナイズとエンコード、分割してそれぞれのトークン ID に割り振ることはできますが、逆にトークン ID からカンマ区切りの文章に戻すことは(適切に)できません。
これは、一般的な文章がカンマ区切りで単語をデコードすることを想定していないためです。そのため、扱いやすい形式でデコードするにはカスタムのトークナイザーを定義する必要があります。
カスタムのトークナイザーを定義する
これは生成時に便利になるために定義するので、学習時にはなくても大丈夫です。
import logging
from typing import List
from pydantic.dataclasses import dataclass
from transformers import PreTrainedTokenizerFast
from tokenizers.decoders import Decoder
logger = logging.getLogger(__name__)
class DartDecoder:
def __init__(self, special_tokens: List[str]):
self.special_tokens = list(special_tokens)
def decode_chain(self, tokens: List[str]) -> List[str]:
new_tokens = []
is_specials = []
for i, token in enumerate(tokens):
is_specials.append(token in self.special_tokens)
if i == 0:
new_tokens.append(token)
continue
# this token or previous token is special
if is_specials[i] or is_specials[i - 1]:
new_tokens.append(token)
continue
new_tokens.append(f", {token}")
return new_tokens
class DartTokenizer(PreTrainedTokenizerFast):
"""Dart tokenizer"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._tokenizer.decoder = Decoder.custom( # type: ignore
DartDecoder(list(self.get_added_vocab().keys()))
)
参考: https://github.com/huggingface/tokenizers/issues/636
ここで定義している DartTokenizer
は PreTrainedTokenizerFast
の拡張ですが、__init__()
で self._tokenizer.decoder
を上書きしています。
self._tokenizer.decoder
では、DartDecoder
を指定しています。普段はデフォルトで設定されている decode_chain
が文章をデコードするときに単語同士をくっつけたり空白で繋げたりする役割をしているのですが、ここではタグをカンマで区切って出力するように実装しています。
スペシャルトークン (<general></general>
等) の前後以外でいい感じにカンマで区切ってくれるようになっています。
これを実際に使用するには、tokenizer_config.json
でファイルとクラスを指定してあげる必要がります。
{
"tokenizer_class": "DartTokenizer",
"auto_map": {
"AutoTokenizer": [
"tokenization_dart.DartTokenizer",
"tokenization_dart.DartTokenizer"
]
},
"added_tokens_decoder": {
"0": {
"content": "<|bos|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
...
tokenizer_class
に作成したクラスの名前、auto_map
には AutoTokenizer.from_pretrained
を使って読み込むときに割り当てるクラスを指定します。
配列の1つ目は AutoTokenizer
、2つ目は AutoTokenizerFast
に該当しますが今回は適当に両方同じものを指定しました。
これを設定することで、
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("トークナイザーのファイルパスかレポ名", trust_remote_code=True)
でカスタムのトークナイザーが使えるようになります。trust_remote_code=True
が必要になるため、一部の Trainer や Inference API では使えなくなってしまうのがデメリットです。
事前学習のためのデータセットフィルタリング
データフィルタリング:
- 一般タグ情報が欠落しているデータの削除
- 版権・キャラクタータグに 不人気タグを含む投稿を削除
- 生成したくないタグの削除
- サイン (
signature
)、透かし (watermark
) などの出てきても困るタグや、作画ミス (artistic error
)、悪い解剖学 (bad anatomy
) などの人体おかしい系タグは画像生成のときにも入っていて欲しくないのでこの段階で取り除いた - 文字 (
text
) などはポスター風などの際に必要になると思ったので、ここでは取り除いていない
- サイン (
- タグが過度に多い投稿を削除
- この段階において、1つの投稿に一般タグが 100 個以上、または版権タグが 5個以上、またはキャラクタータグが 10個以上ついているデータを削除
- これは集合写真のようなイラストが多く、学習するときにトークン数が多いとちょっと困るのと、学習が難しそうだったため取り除いた
これをやったあとは、特殊トークンを交えて以下のように連結します:
<|bos|><rating>rating:sfw, rating:general</rating><copyright>vocaloid</copyright><character>hatsune miku</character><general>1girl, blue hair, ...</general><|eos|>
見やすくすると、
<|bos|>
<rating>rating:sfw, rating:general</rating>
<copyright>A, B</copyright>
<character>C, D, E, F</character>
<general>1girl, blue hair, ...</general>
<|eos|>
となります。HTML タグぽくしたおかげでHTMLのシンタックスハイライトが効いて嬉しいです。
<|bos|>
、<|eos|>
のそれぞれ文章の開始と終了を示すトークンと、カテゴリごとのブロックに分かれています。
<rating> ブロック
レーティングカテゴリになりますが、2つのタグが入っているのと今までに説明してないレーティングタグが登場しているのがわかると思います。
先ほど紹介したレーティングタグに加えて rating:sfw
と rating:nsfw
を用意しました。それぞれ rating:general
と rating:sensitive
、rating:questionable
と rating:nsfw
の親タグとなります。
投稿が rating:sensitive
であれば、自動的に rating:sfw
であることになります。
<copyright> ブロック
版権タグが入るブロックですが、特に版権タグが指定されていなければ空っぽのままになります。
空の場合でもブロックはそのまま残すことで学習が簡単になることと、「特に何も指定されていない」ということを理解してもらうことを期待しています。
<character> ブロック
<copyright>
ブロックと同様です。
<general> ブロック
<copyright>
ブロックとほぼ同様です。
ここは空の場合はフィルタリングされているので空になる可能性はないですが、不人気タグが入る可能性があります。
トークナイザーは不人気タグを学習していないため不人気タグは全て <|unknown|>
になります。 <|unknown|>
トークンは全て除去しています。
一度、不人気タグに関するクリーニングをせずに学習した結果、版権タグやキャラクタータグでも <|unknown|>
が発生して気持ち悪かったので除去しています。
版権・キャラクタータグで <|unknown|>
が含まれていた際にそのタグだけ除去するのではなくその行ごと削除しているのは、その版権タグに基づいた謎の一般タグが混入する可能性があり、予測が難しくなったり生成時に謎のタグが出現しやすくなると考えたからです。
共通事項
独自の順序ルールが存在する <rating>
ブロックを除いた、<copyright>
、<character>
、<general>
ブロック内のタグは一切シャッフルを行わず、アルファベット順のまま配置するようにしています。
これにはいくつか理由がありますが、想像してみるとわかると思いますが、アルファベット順の単語を予測するのとランダムな 6 万単語の中から単語を予測するのとでは圧倒的に前者のほうが簡単になるのと、Danbooru タグには前後関係が一切ないため何かしらの秩序が必要だと考えたからです。
ランダムに生成して欲しいからシャッフルしたくなる気持ちもわかるのですが、今回学習する言語モデルは前までの単語から次の単語を予測するため、今までの単語の情報が次選ぶ単語の決定に寄与しなくなると、何も学習できなくなってしまいます。
実際にシャッフルして学習したところ、1girl, solo, looking at viewer
のような 9 割の投稿についてそうなありきたりなタグだけを生成して終了するモデルが完成した。
ただし、これはタグに関する知識を学習する事前学習における話であり、後述する SFT ではまた異なることをやっています。
事前学習を行う
トークナイザーを自作しているので、既存のモデルからのファインチューンは行うことができません。フルスクラッチで学習することになります。
今回は、OPT (Open Pretrained Transformer) というモデルをベースに、位置埋め込みを取り除いたアーキテクチャで学習しましたが実質 OPT です。OPT を選んだ理由としては、特に目立った特徴がなくてシンプルだったので、とりあえず最初の実験に良さそうだと思ったからです。
オセロニアのチーム編成を Transformer で生成する という面白い記事を読んで、Danbooru タグも順序そんな関係ないし位置埋め込みいらんかーって思って抜いてみました。他にもこの記事にいろいろ影響されているところがあります。
ただ、今回は位置埋め込みを含めたものを学習していないので、位置埋め込みなくても普通にアルファベット順に生成できるんだなあというのがわかったくらいで、比較検証はできていません。
モデルのサイズはデフォルト設定を使ったので、OPT公式の 125M のものと同じかと思います。これより小さい既存の設定を調べてないので、とりまこのサイズで学習しうてみるかという感じです。
学習には 🤗Transformers の Trainer を使いました。torch.compile
がなぜかうまくいかなかったこと以外は特に言うことはないです。
学習エポック数は 1 のみです。Stable Diffusion の LoRA 学習やってると足りなそうな感じがしてきますが、LLM の学習ではこれくらいじゃないと過学習して毎回同じものしか生成できなくなってしまいます。データ量がデータ量なので、自宅の RTX 3070 Ti だと学習に 1 日くらいかかりました。
学習したものがこれになります:
プロンプトに以下をいれると、
<|bos|><rating>rating:sfw, rating:general</rating><copyright>original</copyright><character></character><general>1girl
次のようなものがだいたい 1 秒前後で補完されます (カスタムしたトークナイザーが必要)
ahoge, black hair, blue eyes, blush, closed mouth, ear piercing, earrings, jewelry, looking at viewer, mole, mole under eye, piercing, portrait, shirt, short hair, solo, white shirt</general><|eos|>
今回の実験の面白い(当然といえば当然かも知れないけど)ところが、オセロニアの記事で触れられているような出力タグの制限をつけなくても、前の方にタグ(レーティング、版権、キャラクター)によって、しっかりと出力タグが影響を受けてくれていてとても嬉しかった。特に、rating:general
を指定すれば niji・journey にも怒られないようなタグだけで生成される ので非常に便利です。
はじめはオセロニアの記事を真似して、 「rating:general
が指定されたら安全そうなタグだけ許可するように logits をマスクする」ということや、「一度使ったタグは二回使えないようにする」というような制限を機械的に加えることを考えていたのですが、それを実装する前にとりあえず学習してみたところ、勝手にそのルールを学習してくれたので、Transformer 賢いんだなあというのを実感しました。
しかし、このままではまだ便利とはいえません。アルファベット順でしか生成できないので、先頭の方に s
や w
なんてきたらそこで生成終了です。(今までのモデルはたまたま 1girl
が先頭に来るようになっていたので成り立っていたのかもしれません)
たとえば、
(略)<general>1girl, solo
と指定してしまうと、
(略)<general>1girl, solo, white background</general><|eos|>
となってしまいます。これでは 1girl
で solo
であることを指定したいのに、白背景しか生成されなくなってしまいます。
これはタグの出現順を強く学習していることの裏返しでもありますが、このままでは使い物にならないので、次の SFT でこれを解決します。
前提 - SFT について
SFT (Supervised Fine-Tuning) は完結に言えば、指示応答に従うようにファインチューンすることです。
普通の言語モデルであれば例えば、
### 質問
〇〇 を説明してください
### 応答
それは、...
のうちで、質問の返答である ### 応答
以降を学習します。
この SFT を使う利点として、応答部分のみを学習することができる という点があります。応答以前の指示文が学習しないということが可能になるため、上のほうで述べた、「シャッフルして学習すると何も学習できなくなってしまう」という問題を回避して指示に従うような学習が可能になります。
SFT のためのデータセット作成
少ない要素を指定して、それに関連したタグを補完して欲しいこと等を考慮して次の方針でデータセットを作成します。
- ブロックの構成、順序は変更しない
-
指示の終了・補完の開始を示すタグを導入
- 指示終了タグ以前のタグはシャッフルする
- ただし、重要な要素である
1girl
等の人数タグは高確率でこのシャッフルされる側 (条件入力) に入るようにする - 低確率でシャッフルされない側に入るようにすることで何も入力されなかったときも安定するようにする
- ただし、重要な要素である
- それ以降のタグはアルファベット順で配置
- 指示終了タグ以前のタグはシャッフルする
- 生成されるタグの量を操作できるように、タグ量を表すタグの導入
これを簡易的に示すとこのようなテキストになります:
<|bos|>
<rating>rating:general, rating:sfw</rating>
<copyright>B, A</copyright>
<character>D, B, E, C</character>
<general><|long|>medium hair, 2girls<|input_end|>animal ears, ... yuri</general>
<|eos|>
-
<rating>
: 2つだけしか入りませんが内部でシャッフルして位置の乱れに対する耐性をつけます -
<copyright>
: 同様に内部でシャッフルします -
<character>
: 同様に内部でシャッフルします -
<general>
: 大きく変化した点が2つあります-
<|long|>
: この位置に一般タグの総量を表すタグを配置します。今回は以下の条件で4つ用意しました-
<|very_short|>
: 10個以下 -
<|short|>
: 20個以下 -
<|long|>
: 40個以下 -
<|very_long|>
: 40個以上
-
-
<|input_end|>
: これが入力の終了を意味し、これ以前のタグは学習されません- シャッフル側:
1girl
やno humans
などの人数タグは 95% の確率でこちらに確定で入ります。5% の確率で他のタグと同様に扱われます。 - 補完側: 事前学習同様アルファベット順に配置することで適切に学習できるようにします
- シャッフル側:
-
また、SFTでは年代で学習データをフィルタリングしました。今回使ったデータは 2020年~2023年 のデータに絞っています。
これは、SFT自体そこまで量が必要なさそうな気がするのと、最近流行りの版権・キャラクターの容姿情報を学習しやすくなったり、キャラクターがいない場合でも出現する要素が流行りの感じになることを期待しました。
SFT を行う
先ほどさらっと新しいトークンを追加していましたが、もともとのトークナイザーには存在しない語彙です。
そのため、予約していたスペシャルトークンの枠を使って新しい語彙を設定しました。(手動で書き換えた...)
SFT には 🤗 trl ライブラリ の SFTTrainer を使うことで簡単に行うことができます。
こちらも事前学習同様学習エポック数は 1 のみです。RTX 3070 Ti で 6 時間くらいでした。
特筆すべきことはないのですが、SFTTrainer が学習時にトークナイザーを保存しようとするのですが、上で紹介した Decoder を改造したトークナイザーは保存ができないため、エンコード専用のトークナイザーを指定して学習を行いました。
学習されたモデルはこれです:
次のようなプロンプトをいれると、
(略)<general><|long|>1girl, solo<|input_end|>
このようなものが生成されました
:q, animal ear fluff, animal ears, ass, bare shoulders, black footwear, black hair, blush, bow, bowtie, breasts, brown eyes, bug, butterfly, cleavage, detached collar, fake animal ears, flower, full body, high heels, leotard, long hair, looking at viewer, playboy bunny, rabbit ears, rabbit tail, red bow, red bowtie, red leotard, small breasts, tail, thighhighs, tongue, tongue out, wrist cuffs</general><|eos|>
solo
というアルファベット順だと後ろの方に来るタグを入れましたがちゃんと生成されました!
今度は <|very_short|>
を指定して長さを操作できるか調べてみましょう
(略)<general><|very_short|>1girl, solo<|input_end|>
black eyes, black hair, long hair</general><|eos|>
期待通りにとても短いプロンプトになりました。タグの総量を指定することもちゃんとできているようです!
Optimum で最適化する
🤗 Optimum を利用して推論のためにモデルを最適化してみます。
Optimum とはいろいろな最適化のためのツールが入ったライブラリで、今回は ONNX Runtime 用に変換することを行ってみます。
参考: https://huggingface.co/docs/optimum/onnxruntime/quickstart
optimum ライブラリが必要になるためインストールしていなければインストールします。
pip install "optimum[onnxruntime]""
from optimum.onnxruntime.configuration import (
AutoQuantizationConfig,
)
from optimum.onnxruntime import ORTModelForCausalLM
# 変換したい対象のモデル
MODEL_NAME = "p1atdev/dart-v1-sft"
# 出力先
SAVE_DIR = "./onnx"
# transformers形式から読み込むので export=True が必要
ort_model = ORTModelForCausalLM.from_pretrained(MODEL_NAME, export=True)
# 保存!
quantizer.quantize(save_dir=SAVE_DIR, quantization_config=qconfig)
実行すると SAVE_DIR
に model.onnx
と config.json
が作成されます。
これだけで ONNX に変換できました。
使う際はこのようになります。
from optimum.onnxruntime import ORTModelForCausalLM
from transformers import AutoTokenizer
# トークナイザーが入っている
MODEL_NAME = "p1atdev/dart-v1-sft"
# さっき保存したディレクトリ
SAVE_DIR = "./onnx"
# 変換済みなので export=True は不要
ort_model = ORTModelForCausalLM.from_pretrained(SAVE_DIR)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
# 通常の CausalLM と同じインターフェースで利用できる
with torch.no_grad():
outputs = ort_model.generate(tokenizer("<|bos|>", return_tensors="pt").input_ids)
...
ONNX に変換する際についでに量子化もできるらしいです。
from optimum.onnxruntime.configuration import (
AutoQuantizationConfig,
)
from optimum.onnxruntime import ORTQuantizer, ORTModelForCausalLM
# 変換したい対象のモデル
MODEL_NAME = "p1atdev/dart-v1-sft"
# 出力先
SAVE_DIR = "./onnx"
# transformers形式から読み込むので export=True が必要
ort_model = ORTModelForCausalLM.from_pretrained(MODEL_NAME, export=True)
# 量子化方法を定義する (is_static=Trueはうちの環境では動作しなかった...)
qconfig = AutoQuantizationConfig.arm64(is_static=False, per_channel=False)
quantizer = ORTQuantizer.from_pretrained(ort_model)
# 保存!
quantizer.quantize(save_dir=SAVE_DIR, quantization_config=qconfig)
実行すると SAVE_DIR
に model_quantized.onnx
と config.json
が作成されます。
読み込むときはファイル名指定つきで
# 変換済みなので export=True は不要
ort_model = ORTModelForCausalLM.from_pretrained(SAVE_DIR, file_name="model_quantized.onnx")
と変更するだけで同様に使えます。
プロンプトテンプレートを用意する
SFT も行って、最適化もやって、モデルの方の準備はバッチリですが、まだ問題は残っています。
それは生成するためのプロンプトがスペシャルトークンまみれで複雑になっていることです。慣れてしまえば大丈夫だと思うのですが、知らない人にとってはずいぶんと不親切なプロンプトフォーマットになっています。
そこで、トークナイザーにプロンプトを簡単に構築するためのテンプレートを追加します。
参考: https://huggingface.co/docs/transformers/main/en/chat_templating
インストラクション系のモデルが流行りだしたころに追加された機能だと思うのですが、優しいインターフェースで一切変なトークンに触らずとも複雑怪奇なインストラクション用プロンプトを簡単に構築することができます。機能名からもチャット形式のプロンプトを想定していると思うのですが、使えるもんは使います。
これをやるにはカスタムのトークナイザー定義が必要になりますが、ちょうど先程作ったものがあるのでそれに追加します。
次のようになりました:
import logging
from typing import List
from transformers import PreTrainedTokenizerFast
from tokenizers.decoders import Decoder
logger = logging.getLogger(__name__)
# fmt: off
# https://huggingface.co/docs/transformers/main/en/chat_templating
PROMPT_TEMPLATE = (
"{{ '<|bos|>' }}"
"{{ '<rating>' }}"
"{% if 'rating' not in messages or messages['rating'] is none %}"
"{{ 'rating:sfw, rating:general' }}"
"{% else %}"
"{{ messages['rating'] }}"
"{% endif %}"
"{{ '</rating>' }}"
"{{ '<copyright>' }}"
"{% if 'copyright' not in messages or messages['copyright'] is none %}"
"{{ '' }}"
"{% else %}"
"{{ messages['copyright'] }}"
"{% endif %}"
"{{ '</copyright>' }}"
"{{ '<character>' }}"
"{% if 'character' not in messages or messages['character'] is none %}"
"{{ '' }}"
"{% else %}"
"{{ messages['character'] }}"
"{% endif %}"
"{{ '</character>' }}"
"{{ '<general>' }}"
# length token
"{% if 'length' not in messages or messages['length'] is none %}"
"{{ '<|long|>' }}"
"{% else %}"
"{{ messages['length'] }}"
"{% endif %}"
# general token
"{% if 'general' not in messages or messages['general'] is none %}"
"{{ '' }}"
"{% else %}"
"{{ messages['general'] }}"
"{% endif %}"
"{{ '<|input_end|>' }}"
).strip()
# fmt: on
class DartDecoder:
def __init__(self, special_tokens: List[str]):
self.special_tokens = list(special_tokens)
def decode_chain(self, tokens: List[str]) -> List[str]:
new_tokens = []
is_specials = []
for i, token in enumerate(tokens):
is_specials.append(token in self.special_tokens)
if i == 0:
new_tokens.append(token)
continue
# this token or previous token is special
if is_specials[i] or is_specials[i - 1]:
new_tokens.append(token)
continue
new_tokens.append(f", {token}")
return new_tokens
class DartTokenizer(PreTrainedTokenizerFast):
"""Dart tokenizer"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._tokenizer.decoder = Decoder.custom( # type: ignore
DartDecoder(list(self.get_added_vocab().keys()))
)
@property
def default_chat_template(self):
"""
Danbooru Tags Transformer uses special format prompt to generate danbooru tags.
"""
return PROMPT_TEMPLATE
PROMPT_TEMPLATE
で Jinja 形式のテンプレートを定義し、DartTokenizer
の default_chat_template()
でそれを返しています。
Jinja テンプレートは初めて触ったのですが ChatGPT のおかげでどうにかなりました。やはり持つべきは ChatGPT ですね。
定義は正直どうでもよくてどう使うか知りたいと思います。このように使うことができます。
from transformers import AutoTokenizer
MODEL_NAME = "p1atdev/dart-v1-sft"
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.bfloat16)
inputs = tokenizer.apply_chat_template({
"rating": "rating:sfw, rating:general",
"copyright": "original",
"character": "",
"general": "1girl",
"length": "<|long|>"
}, tokenize=False)
print(inputs)
実行すると、
<|bos|><rating>rating:sfw, rating:general</rating><copyright>original</copyright><character></character><general><|long|>1girl<|input_end|>
となります。かなりすっきりしたと思います。テンプレートにデフォルト値を設定しているので、今回においては次も同値になります:
inputs = tokenizer.apply_chat_template({
"copyright": "original",
"general": "1girl",
}, tokenize=False)
tokenize=False
を tokenize=True
にすると、そのままトークン ID の配列が返ってくるので、
inputs = tokenizer.apply_chat_template({
"copyright": "original",
"general": "1girl",
}, tokenize=True)
with torch.no_grad():
outputs = model.generate(inputs)
と、すぐに生成に渡すこともできます。
最初の生プロンプトを扱っていると比べてかなり安全・簡単に扱えるようになりましたが、"general"
などにスペシャルトークンを勝手に挿入されるインジェクション等は考慮してないので、そこらへんはどうなるのかわからないです。
画像生成に使ってみる
データセットの用意、トークナイザーの作成、事前学習、SFT、最適化を行うことができました。ここまでくると当初の目的を忘れてしまいそうになりますが、本来の目的は画像生成のプロンプトを考えることを放棄しながらいい感じの画像が欲しいという欲張りな要望を叶えることでした。
実際にそれができているか試してみましょう。モデルは AnimagineXL v3.0 を使います。
ネガティブプロンプトの生成には対応していないので公式推奨のものを使います。
nsfw, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, artist name
指定なし
まずはデフォルトのテンプレートから生成してみました
<|bos|><rating>rating:sfw, rating:general</rating><copyright></copyright><character></character><general><|long|><|input_end|>
1girl, animal, blurry, blurry background, bottle, brown hair, cat, closed eyes, depth of field, flower, grass, holding, holding bottle, indoors, leaf, long hair, long sleeves, nemophila (flower), plant, potted plant, puffy long sleeves, puffy sleeves, red flower, solo, standing, rating:general
生成された画像(832x1216 4枚):
1girl, animal ear fluff, animal ears, blonde hair, blue eyes, blurry, blurry background, blush, bow, closed mouth, depth of field, dress, fang, fox ears, fox girl, fox tail, hair between eyes, hair bow, hand up, long hair, long sleeves, looking to the side, portrait, red bow, solo, tail, very long hair, wavy hair, white bow, white dress, rating:general
生成された画像(832x1216 4枚):
軽い指定
ざっくりと荒廃した風景のシーン no humans, scenery, abandoned
を指定して生成してみます。
<|bos|><rating>rating:sfw, rating:general</rating><copyright></copyright><character></character><general><|long|>no humans, scenery, abandoned<|input_end|>
no humans, scenery, abandoned, building, clear sky, bridge, cityscape, cloud, cloudy sky, day, grass, house, landscape, moon, outdoors, railing, road, shirt, sky, tree, vehicle name, white sky, yunomi, rating:general
生成された画像(832x1216 4枚):
no humans, scenery, abandoned, bird, blue theme, cloud, cloudy sky, dilapidated, fence, glowing, glowing eye, jar, outdoors, plant, power lines, rice paddy, rural, sky, star (sky), starry sky, water, rating:general
生成された画像(832x1216 4枚):
所感
生成された画像を見ると、同じ生成されたプロンプトであれば 4 枚とも近い雰囲気になっているのがわかると思います。
これは、プロンプトが長くなり詳細が指定されたことで一貫性が維持される[1]からなので、好みの雰囲気・特徴を持ったプロンプトに当たったらそのプロンプトで画像生成ガチャを大量に回すことで好みの画像が沢山生成できそうですね。
ちなみに、一番最初に載せた DALL-E 3の段ボールの画像二枚も、細かい指定があるおかげで一貫性がある思います。
結論
少し使ってみるとわかるかもしれませんが、入力条件に従うかどうかはわりとまちまちで、そのかわりにいい感じの視覚的特徴の組み合わせを出してくれることのほうが嬉しいかもしれません。それでも、今までの画像生成プロンプト生成モデルにあったような、「版権タグが紛れ込む」「変な文字列出てくる」「モデルがバカでかい」「順序に依存しすぎる」という問題は回避しているかと思います。
このモデルのおかげで、私の長いプロンプトを書くための時間は削減されましたが、代わりにプロンプトガチャが追加されることになったようです。(結果として生成できる画像の枚数も増えるし、品質の追求に時間を割けるので良いことだと思います。)
個人的に、画像データセット作成時のランダムなプロンプト作成に使うといいんじゃないかなと思います。(特にControlNet学習スクリプトなどの解像度が固定されてしまっているやつ等)
プロンプト生成のデモはこちらから試すことができます:
補足 - Danbooru タグの変化について
広く知られているかはわかりませんが、Danbooru タグはわりと頻繁に名称が変わったりカテゴリが移動したりします。
わかりやすいものだと、現在「作画ミス」を表すタグ artisitc error
は、以前は error
というタグだったため、NovelAI v3 等のモデルでは error
が使われています。
このように、収集時期によって学習されるタグが変わってしまうことがあるほか、タグの収集に時間をかけすぎると、収集してる間に変更が入ってしまうこともあります。
今回は収集に実質 3 ヶ月かかっている(一度集めて3ヶ月後に続きを収集した) ため、その中で2つのタグのカテゴリが変化しているのにトークナイザーの語彙の作成時に気づきました。 (そこまで頻出するタグではないので大きな影響はないと思います)
語彙数の不一致で気付いたものは以外にも、まだ確認していない同じ意味のタグの被りがあったり使っているモデルのタグと互換性がない可能性もあるのでそこは注意していただくと良いと思います。(このモデルに限らず、WD v14 Tagger 等も最新の画像生成モデルとはタグが異なっていることがあります)
今回のトークナイザーに関してはは1タグ1トークンになっているので、気に入らないタグがあれば直接名前を書き換えちゃってもいいかもしれません。
快適な画像生成ライフを!
-
一貫性についての NovelAI による解説: https://docs.novelai.net/image-jp/tutorial-charactercreation-jp.html タグを詳細に指定することで出力されるキャラクターの一貫性が確保される。 ↩︎
Discussion
Great work!
These are typo correction suggestions.
Thank you for pointing that out!