transformersで特定の文字列が出力されたときに生成を止めたい
どうも、@ksterxです。
現在はSpiralAIという会社でインターンをしています。
いきなりですが、みなさんはモデルの生成で次のような事象を経験したことはないでしょうか?
###質問:
富士山の高さは?
###回答:
3776 m
###追加の質問:
では、エベレスト山の高さは?
本当は
###質問:
富士山の高さは?
###回答:
3776 m
と回答してほしいだけなのに、、、
毎回、「###」で始まるなにかを出力するんだよなあ、、、
そこで今回は、Transformersを使ったテキスト生成で、特定の文字列(###
とか)が出力された際に、生成を停止する方法について話したいと思います。
生成プロセスの制御
テキスト生成では、特定の条件下で生成を終了させたい場面がしばしばあります。transformersのGenerationConfigを受け付けるモデルであれば、generate
の引数にrepetition_penalty
やno_repeat_ngram_size
など、いくつかの生成プロセスを制御する方法がありますが、今回の行いたい文字列で制御が可能なstopping_criteria
を渡す方法もあります。transformers側の実装でStoppingCriteria
クラスはあるのですが、ただ、特定の文字列を引数に渡すだけみたいなことはできません。
そこで、今回はstopping_criteria
に渡すクラスの実装を行います。
具体的な実装
Transformersライブラリでテキスト生成を制御するために、StoppingCriteria
を継承したカスタムクラスを作成します。以下のクラスは、指定したstop_tokens
が出現した際に生成を停止するように設計されています。
class GenerationStopper(StoppingCriteria):
def __init__(self, stop_tokens: dict[str, list[int | list[int]]]):
self.stop_token_ids = []
for t in stop_tokens.values():
if any(isinstance(x, list) for x in t): # if t is nested list
for x in t:
self.stop_token_ids.append(torch.tensor(x))
else:
self.stop_token_ids.append(torch.tensor(t))
assert isinstance(t, list) or isinstance(t, int)
self.stop_token_words = stop_tokens.keys()
def __repr__(self):
return f"Stopping words: {self.stop_token_words}"
def __call__(
self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
) -> bool:
for t in self.stop_token_ids:
if torch.eq(input_ids[0][-len(t) :].to("cpu"), t).all():
return True
return False
@property
def criteria(self):
return StoppingCriteriaList([self])
def format(self, sentence: str):
for w in self.stop_token_words:
if w in sentence[-len(w) :]:
sentence = sentence[: -len(w)]
return sentence
-> PyPIでインストールできるようにしました!
pip install gstop
from gstop import GenerationStopper
stop_tokens
の設定
stop_tokens
の設定には注意が必要です。このクラスでは、↓のようにあえて文字列とそれに対応したidsの辞書を渡すようにしています。
stop_tokens = {"###": [774]}
なぜ、stop_token_ids=tokenizer.encode("###", add_special_tokens=False)
のように、直接渡さないのか。この設定は、トークナイザのエンコードとデコードの不可逆性に起因する問題を避けるために重要です。
2024年1月25日追記
mistral系のtokenizerを使うと###
のidsが774
と27332
で異なることがあります。
(@ken11さんに教えていただきました。)
これに対応するために、stop_tokens
に、ネストしたリストを渡せるように変更しました。
stop_tokens = {"###": [[774], [27332]]}
tokenizerの出力を見てみる
簡単な実験をしてみましょう。
今、2つの改行が続く場合を考えます(\n\n
)。
トークナイザーの出力を見てみると
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
word = "\n\n"
encoded = tokenizer.encode(word, add_special_tokens=False) # add_special_tokens=FalseはBOS等を追加しないようにするため
print(encoded)
この結果は[28705, 13, 13]
です。
あれ、待てよと。13
は\n
を表しますが、28705
お前は何だと。
これが実は、
(空白)なんです。
ここが僕がハマったポイントでした。勝手に追加してくれてありがとう。
一応こちらでも語られていますが、sentencepiece側の問題なよう。
こういったように、tokenizerで文字列をencodeしたものを直接渡すと、所望の動作にならない可能性があるので、このようなまどろっこしいことをしているわけです。(まあ、idレベルで指定しても、tokenizerが前の文字とくっつけて別のidを割り当てるとかもあるので、完璧ではないのですが)
モデルごとに、辞書を定義してもいいかもですね〜
使い方
model.generate
にstopper.criteria
を指定することで、生成を適切に制御できます。
stop_tokens = {"###": [[774], [27332]], "\n\n": [13, 13]}
stopper = GenerationStopper(stop_tokens)
question = """###質問:
富士山の高さは?
###回答:
"""
input_ids = tokenizer.encode(question, add_special_tokens=False)
answer = model.generate(
input_ids,
stopping_criteria=stopper.criteria,
)
answer = tokenizer.decode(answer[0])
print(answer)
###質問:
富士山の高さは?
###回答:
3776 m
###
もし、stop_tokens
自体も要らなければ
answer = stopper.format(answer)
print(answer)
で、stopper.format
を使用することで
###質問:
富士山の高さは?
###回答:
3776 m
のように消すこともできます。
まとめ
今回は、Transformersにおけるテキスト生成の制御方法を詳しく見てきました。もし、生成結果が学習やプロンプトで制御がしきれないと感じたときは試してみてください〜
Discussion