🔥

🤗transformersで特定の文字列が出力されたときに生成を止めたい

2024/01/22に公開

どうも、@ksterxです。
現在はSpiral.AIという会社でインターンをしています。

いきなりですが、みなさんはモデルの生成で次のような事象を経験したことはないでしょうか?

###質問:
富士山の高さは?

###回答:
3776 m

###追加の質問:
では、エベレスト山の高さは?

本当は

###質問:
富士山の高さは?

###回答:
3776 m

と回答してほしいだけなのに、、、
毎回、「###」で始まるなにかを出力するんだよなあ、、、

そこで今回は、Transformersを使ったテキスト生成で、特定の文字列(###とか)が出力された際に、生成を停止する方法について話したいと思います。

生成プロセスの制御

テキスト生成では、特定の条件下で生成を終了させたい場面がしばしばあります。transformersのGenerationConfigを受け付けるモデルであれば、generateの引数にrepetition_penaltyno_repeat_ngram_sizeなど、いくつかの生成プロセスを制御する方法がありますが、今回の行いたい文字列で制御が可能なstopping_criteriaを渡す方法もあります。transformers側の実装でStoppingCriteriaクラスはあるのですが、ただ、特定の文字列を引数に渡すだけみたいなことはできません。

そこで、今回はstopping_criteriaに渡すクラスの実装を行います。

具体的な実装

Transformersライブラリでテキスト生成を制御するために、StoppingCriteriaを継承したカスタムクラスを作成します。以下のクラスは、指定したstop_tokensが出現した際に生成を停止するように設計されています。

class GenerationStopper(StoppingCriteria):
    def __init__(self, stop_tokens: dict[str, list[int | list[int]]]):
        self.stop_token_ids = []
        for t in stop_tokens.values():
            if any(isinstance(x, list) for x in t):  # if t is nested list
                for x in t:
                    self.stop_token_ids.append(torch.tensor(x))
            else:
                self.stop_token_ids.append(torch.tensor(t))
            assert isinstance(t, list) or isinstance(t, int)
        self.stop_token_words = stop_tokens.keys()

    def __repr__(self):
        return f"Stopping words: {self.stop_token_words}"

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
    ) -> bool:
        for t in self.stop_token_ids:
            if torch.eq(input_ids[0][-len(t) :].to("cpu"), t).all():
                return True
        return False

    @property
    def criteria(self):
        return StoppingCriteriaList([self])

    def format(self, sentence: str):
        for w in self.stop_token_words:
            if w in sentence[-len(w) :]:
                sentence = sentence[: -len(w)]
        return sentence

-> PyPIでインストールできるようにしました!
https://github.com/ksterx/gstop/tree/main

pip install gstop
from gstop import GenerationStopper

stop_tokensの設定

stop_tokensの設定には注意が必要です。このクラスでは、↓のようにあえて文字列とそれに対応したidsの辞書を渡すようにしています。

stop_tokens = {"###": [774]}

なぜ、stop_token_ids=tokenizer.encode("###", add_special_tokens=False)のように、直接渡さないのか。この設定は、トークナイザのエンコードとデコードの不可逆性に起因する問題を避けるために重要です。

2024年1月25日追記

mistral系のtokenizerを使うと###のidsが77427332で異なることがあります。
@ken11さんに教えていただきました。)
これに対応するために、stop_tokensに、ネストしたリストを渡せるように変更しました。

stop_tokens = {"###": [[774], [27332]]}

tokenizerの出力を見てみる

簡単な実験をしてみましょう。
今、2つの改行が続く場合を考えます(\n\n)。
トークナイザーの出力を見てみると

from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
word = "\n\n"
encoded = tokenizer.encode(word, add_special_tokens=False)  # add_special_tokens=FalseはBOS等を追加しないようにするため
print(encoded)

この結果は[28705, 13, 13]です。
あれ、待てよと。13\nを表しますが、28705お前は何だと。
これが実は、 (空白)なんです。
ここが僕がハマったポイントでした。勝手に追加してくれてありがとう。
一応こちらでも語られていますが、sentencepiece側の問題なよう。

こういったように、tokenizerで文字列をencodeしたものを直接渡すと、所望の動作にならない可能性があるので、このようなまどろっこしいことをしているわけです。(まあ、idレベルで指定しても、tokenizerが前の文字とくっつけて別のidを割り当てるとかもあるので、完璧ではないのですが)

モデルごとに、辞書を定義してもいいかもですね〜

使い方

model.generatestopper.criteriaを指定することで、生成を適切に制御できます。

stop_tokens = {"###": [[774], [27332]], "\n\n": [13, 13]}
stopper = GenerationStopper(stop_tokens)

question = """###質問:
富士山の高さは?

###回答:
"""

input_ids = tokenizer.encode(question, add_special_tokens=False)
answer = model.generate(
    input_ids,
    stopping_criteria=stopper.criteria,
)
answer = tokenizer.decode(answer[0])
print(answer)
###質問:
富士山の高さは?

###回答:
3776 m

###

もし、stop_tokens自体も要らなければ

answer = stopper.format(answer)
print(answer)

で、stopper.formatを使用することで

###質問:
富士山の高さは?

###回答:
3776 m

のように消すこともできます。

まとめ

今回は、Transformersにおけるテキスト生成の制御方法を詳しく見てきました。もし、生成結果が学習やプロンプトで制御がしきれないと感じたときは試してみてください〜

Spiral.AIテックブログ

Discussion