transformers で複数のトークナイザーを一つのプロセッサーで扱う
この記事は、LLM・LLM活用 Advent Calendar 2024 の 10 日目の記事になります。
はじめに
テキストを生成するモデルである大規模言語モデル (LLM) は transformers などのライブラリで簡単に扱えるようになりました。
また、近年はテキストだけでなく、画像や動画をもとにテキストを生成できるマルチモーダルなモデルも増えており、その際にテキストや画像を機械学習モデルへ入力するための前処理は段々と複雑になってきています。
テキストをトークンに分割する処理や画像をパッチに分割する処理などの前処理担当を、プリプロセッサーやプロセッサーと呼びますが、transformers では、テキストと画像を一緒に前処理するプロセッサーの一例として LlavaProcessorが実装されています。
LlavaProcessor では画像認識モデル用の CIPImageProcessor と テキストトークナイズ用の LlavaTokenizer をまとめて扱うことができますが、このようなプロセッサーをカスタムする方法はあまり紹介されていないので紹介したいと思います。
今回は、特殊な例として 二つの異なるトークナイザーを一つのプロセッサーで扱う方法 と、そのプロセッサーを簡単に push_to_hub()
や from_pretrained()
で 保存・読み込みできるように する方法を紹介します。この方法を応用することで、好きなだけトークナイザーやプロセッサーをまとめて管理することができるようになります。
カスタムコードで独自のモデルを transformers で扱う方法については、LLM・LLM活用 Advent Calendar 2024 の9日目の記事 Huggingface Transformersに自分のモデルを追加してみた!@weak_kajuma を先に読んでおくと良いかもしれません。
プロセッサーの例
二つのトークナイザーを内包するプロセッサーの例を示します。
実際に動作する例を HuggingFace Hub にアップしています:
ディレクトリ構成
./
├── models
│ ├── __init__.py
│ └── processor_multi.py <- プロセッサー本体
...
models/__init__.py
は空ファイルですが、カスタムコード登録の関係で必要になります。
実装
models/processor_multi.py
では MultiProcessor
という名前でプロセッサーを実装しています。これは、基本的に LlavaProcessor
と同様の処理になっていますが、複数のトークナイザーを扱うために幾つか変更 が加えられています。
コード本体
import os
import json
import warnings
from pathlib import Path
import torch
import torch.nn as nn
from transformers import (
PreTrainedTokenizer,
PreTrainedTokenizerBase,
ProcessorMixin,
BatchFeature,
)
from transformers.utils import (
logging,
direct_transformers_import,
PROCESSOR_NAME,
CHAT_TEMPLATE_NAME,
)
from transformers.image_utils import ImageInput
from transformers.dynamic_module_utils import custom_object_save
logger = logging.get_logger(__name__)
# Dynamically import the Transformers module to grab the attribute classes of the processor form their names.
transformers_module = direct_transformers_import(Path(__file__).parent)
# それぞれのトークナイザーに渡す用のデフォルト引数の定義
class MultiProcessorKwargs:
_defaults = {
"tokenizer_1_kwargs": {
"padding": False,
},
"tokenizer_2_kwargs": {
"padding": False,
},
}
# LlavaProcessor ベースのプロセッサー
class MultiProcessor(ProcessorMixin):
attributes = ["tokenizer_1", "tokenizer_2"] # ここでプロセッサーが持つプロセッサーを指定
valid_kwargs = ["chat_template"]
tokenizer_1_class = "AutoTokenizer" # それぞれのプロセッサーの `from_pretrained`する時のクラスを指定
tokenizer_2_class = "AutoTokenizer"
tokenizer_1: PreTrainedTokenizer # それぞれのプロセッサーの型を指定
tokenizer_2: PreTrainedTokenizer
def __init__(
self,
tokenizer_1=None, # プロセッサーを作成するときにトークナイザーを渡す
tokenizer_2=None,
chat_template=None,
**kwargs,
):
super().__init__(
tokenizer_1, # super().__init__ に渡してあげる
tokenizer_2,
chat_template=chat_template,
**kwargs,
)
# __call__ で定義することで processor(text_1="テキスト", text_2="テキスト") で呼び出せる
def __call__(
self,
text_1: str | list[str] | None = None, # 一つ目のトークナイザー用
text_2: str | list[str] | None = None, # 二つ目のトークナイザー用
**kwargs,
) -> BatchFeature:
# ただの方チェック#
def _validate_text_input(text) -> str | list[str]:
if isinstance(text, list):
assert all(
isinstance(t, str) for t in text
), f"Expected list of str but got {type(text)}"
assert all(len(t) > 0 for t in text), "Expected non-empty strings"
else:
assert isinstance(text, str), f"Expected str but got {type(text)}"
return text
def _normalize_text_input(text: str | list[str]) -> list[str]:
if isinstance(text, str):
return [text]
return text
# ここは型を list[str] に揃えてるだけ
_text_1: str | list[str] = _validate_text_input(text_1)
text_1_list: list[str] = _normalize_text_input(_text_1)
_text_2: str | list[str] = _validate_text_input(text_2)
text_2_list: list[str] = _normalize_text_input(_text_2)
# デフォの引数を MultiProcessorKwargs から引っ張ってきてるが方法はなんでもいい
# kwargs と統合することで、オプションで上書きできる
tokenizer_1_output_kwargs = {
**MultiProcessorKwargs._defaults["tokenizer_1_kwargs"],
"return_tensors": "pt",
**kwargs,
}
tokenizer_2_output_kwargs = {
**MultiProcessorKwargs._defaults["tokenizer_2_kwargs"],
"return_tensors": "pt",
**kwargs,
}
# それぞれトークナイズする
text_1_inputs = self.tokenizer_1(
text_1_list,
**tokenizer_1_output_kwargs,
)
text_2_inputs = self.tokenizer_2(
text_2_list,
**tokenizer_2_output_kwargs,
)
# BatchFeature は出力をいい感じに扱えるようにしてくれるやつ
return BatchFeature(
data={
"input_ids": text_1_inputs.get("input_ids"),
"attention_mask": text_1_inputs.get("attention_mask"),
# 二つ目のトークナイズ結果を追加
"input_ids_2": text_2_inputs.get("input_ids"),
"attention_mask_2": text_2_inputs.get("attention_mask"),
}
)
# デコード時は二つ目のトークナイザーだけでデコード
def batch_decode(self, *args, **kwargs):
return self.tokenizer_2_tokenizer.batch_decode(*args, **kwargs)
# デコード時は二つ目のトークナイザーだけでデコード
def decode(self, *args, **kwargs):
return self.tokenizer_2_tokenizer.decode(*args, **kwargs)
# プロセッサーの引数名
@property
def model_input_names(self):
return ["text_1", "text_2"]
# プロセッサーを保存するためのトリック
# ベース: https://github.com/huggingface/transformers/blob/1d063793318b20654ebb850f48f43e0a247ab7bb/src/transformers/processing_utils.py#L980-L995
@classmethod
def _get_arguments_from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
args = []
for attribute_name in cls.attributes:
class_name = getattr(cls, f"{attribute_name}_class")
subfolder = attribute_name # subfolder is the same as attribute_name
if isinstance(class_name, tuple):
classes = tuple(
getattr(transformers_module, n) if n is not None else None
for n in class_name
)
use_fast = kwargs.get("use_fast", True)
if use_fast and classes[1] is not None:
attribute_class = classes[1]
else:
attribute_class = classes[0]
else:
attribute_class = getattr(transformers_module, class_name)
assert attribute_class is not None, f"Missing attribute class: {class_name}"
args.append(
attribute_class.from_pretrained(
pretrained_model_name_or_path,
subfolder=subfolder,
**kwargs,
)
)
return args
# プロセッサーの保存
# ベース: https://github.com/huggingface/transformers/blob/1d063793318b20654ebb850f48f43e0a247ab7bb/src/transformers/processing_utils.py#L460-L560
def save_pretrained(self, save_directory, push_to_hub: bool = False, **kwargs):
use_auth_token = kwargs.pop("use_auth_token", None)
if use_auth_token is not None:
warnings.warn(
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
FutureWarning,
)
if kwargs.get("token", None) is not None:
raise ValueError(
"`token` and `use_auth_token` are both specified. Please set only the argument `token`."
)
kwargs["token"] = use_auth_token
os.makedirs(save_directory, exist_ok=True)
if push_to_hub:
commit_message = kwargs.pop("commit_message", None)
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
repo_id = self._create_repo(repo_id, **kwargs)
files_timestamps = self._get_files_timestamps(save_directory)
# If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
# loaded from the Hub.
if self._auto_class is not None:
attrs = [
getattr(self, attribute_name) for attribute_name in self.attributes
]
configs = [
(a.init_kwargs if isinstance(a, PreTrainedTokenizerBase) else a)
for a in attrs
]
configs.append(self)
custom_object_save(self, save_directory, config=configs)
for attribute_name in self.attributes:
attribute = getattr(self, attribute_name)
# Include the processor class in the attribute config so this processor can then be reloaded with the
# `AutoProcessor` API.
if hasattr(attribute, "_set_processor_class"):
attribute._set_processor_class(self.__class__.__name__)
attribute.save_pretrained(
os.path.join(
save_directory,
attribute_name, # CHANGED: save to subfolder
),
)
if self._auto_class is not None:
# We added an attribute to the init_kwargs of the tokenizers, which needs to be cleaned up.
for attribute_name in self.attributes:
attribute = getattr(self, attribute_name)
if isinstance(attribute, PreTrainedTokenizerBase):
del attribute.init_kwargs["auto_map"]
# If we save using the predefined names, we can load using `from_pretrained`
# plus we save chat_template in its own file
output_processor_file = os.path.join(save_directory, PROCESSOR_NAME)
output_chat_template_file = os.path.join(save_directory, CHAT_TEMPLATE_NAME)
processor_dict = self.to_dict()
# Save `chat_template` in its own file. We can't get it from `processor_dict` as we popped it in `to_dict`
# to avoid serializing chat template in json config file. So let's get it from `self` directly
if self.chat_template is not None:
chat_template_json_string = (
json.dumps(
{"chat_template": self.chat_template}, indent=2, sort_keys=True
)
+ "\n"
)
with open(output_chat_template_file, "w", encoding="utf-8") as writer:
writer.write(chat_template_json_string)
logger.info(f"chat template saved in {output_chat_template_file}")
# For now, let's not save to `processor_config.json` if the processor doesn't have extra attributes and
# `auto_map` is not specified.
if set(processor_dict.keys()) != {"processor_class"}:
self.to_json_file(output_processor_file)
logger.info(f"processor saved in {output_processor_file}")
if push_to_hub:
self._upload_modified_files(
save_directory,
repo_id,
files_timestamps,
commit_message=commit_message,
token=kwargs.get("token"),
)
if set(processor_dict.keys()) == {"processor_class"}:
return []
return [output_processor_file]
コメントにも書いてありますが、実装されているクラス・関数はそれぞれ次のような役割です:
-
MultiProcessorKwargs
: トークナイザーに渡すデフォルト引数を定義-
LlavaProcessor
で使われているデフォルト引数の扱いと若干異なりますが、最終的に同じことが実現できればいいのでここの形式は重要ではないです
-
-
MultiProcessor
: プロセッサー本体のクラス- 変数
-
attributes
: プロセッサーが持つ子トークナイザの名前を指定 -
valid_kwargs
: (よくわかってない) おそらく前処理実行時に受け取れる引数名? -
tokenizer_1_class
,tokenizer_2_class
: それぞれのトークナイザーのインスタンス化に使うクラス名。 -
tokenizer_1
,tokenizer_2
: それぞれのトークナイザーのインスタンス- それぞれのトークナイザーのクラス指定には
attributes
で指定した名前に揃える必要があります
- それぞれのトークナイザーのクラス指定には
-
- 関数
-
__init__
: プロセッサーの初期化 -
__call__
:processor(text_1="テキスト", text_2="テキスト")
で呼び出されるときの関数。ここで受け取れる引数を指定&前処理実行 -
batch_decode
: 複数のシーケンスをまとめてデコードする関数- ここでは二つ目のトークナイザーだけでデコードしています
-
decode
: 一つのシーケンスをデコードする関数- こちらも同様に二つ目のトークナイザーだけでデコードしています
-
model_input_names
: (よくわかってない) おそらく__call__
するときに受け取れる引数名? -
_get_arguments_from_pretrained
:save_pretrained
する際に呼ばれる、子トークナイザーをディスクに保存する処理 -
save_pretrained
: ローカルにプロセッサーを保存する処理。通常の実装では複数のトークナイザーを持つと正しく保存できないので、修正した処理に変更しています
-
- 変数
使用例
以下の二つのトークナイザーを持たせてみます
- 一つ目のトークナイザー (
tokenizer_1
): llm-jp/llm-jp-3-1.8b のトークナイザー - 二つ目のトークナイザー (
tokenizer_2
): Qwen/QwQ-32B-Preview のトークナイザー
これらのトークナイザーは、AutoTokenizer
で読み込めて PreTrainedTokenizer
として扱えるのならば何でも使えます。
コードではこのようになります:
from transformers import AutoTokenizer, AutoProcessor
from models.processor_multi import MultiProcessor
# push_to_hub 用に AutoProcessor に登録
MultiProcessor.register_for_auto_class("AutoProcessor")
# プロセッサーを作成
processor = MultiProcessor(
tokenizer_1=AutoTokenizer.from_pretrained("llm-jp/llm-jp-3-1.8b"),
tokenizer_2=AutoTokenizer.from_pretrained("Qwen/QwQ-32B-Preview"),
)
# エンコード
print(processor(
text_1="テキスト1",
text_2="テキスト2",
))
# {'input_ids': tensor([[ 1, 43412, 28745]]), 'attention_mask': tensor([[1, 1, 1]]), 'input_ids_2': tensor([[56833, 61803, 70534, 17]]), 'attention_mask_2': tensor([[1, 1, 1, 1]])}
# push_to_hub で huggingface hub にアップロード
processor.push_to_hub(MY_REPO_NAME, private=True)
from transformers import AutoProcessor
processor = AutoProcessor.from_pretrained(
MY_REPO_NAME,
trust_remote_code=True, # カスタムコードなので必要
)
複数のトークナイザーを持ったまま保存するための小技
transformers では、非常に簡単に save_pretrained()
, push_to_hub()
, from_pretrained()
などで保存やhubへのアップロード、読み込みができるわけですが、裏側でどんな処理が行われているかを考えたことがあるでしょうか?
プロセッサーで push_to_hub()
が呼び出さると、以下の図のような処理が行われます。
attributes
や attribute
は、コード中におけるプロセッサーが持っている子トークナイザー・プロセッサーのことです。
push_to_hub()
は大まかな流れとして、 save_pretrained()
してから huggingface_hub.hf_api を用いて、保存されたファイルをすべて HuggingFace Hub にアップロードするという形になります。
そのため、適切に push_to_hub
するためには save_pretrained
でローカルディスクに書き込む処理が正常に行えれば良いことがわかります。
プロセッサーにおける from_pretrained
は以下ののうになります:
ただし、既存の ProcessorMixin は現時点 (2024/12/04) では、 save_pretrained
、from_pretrained
する際に サブディレクトリを指定せず、保存ディレクトリ直下にすべてを保存してしまう ため、一つのプロセッサーが複数のトークナイザー・プロセッサーを持つと同じディレクトリに保存・上書きしてしまうため、競合してしまいます。実際に保存処理を修正せずに save_pretrained
、from_pretrained
をすると、二つとも同じトークナイザー(一番最後に保存された方)が使われてしまいます。
そこで、保存・読み込み処理を少し調整し、それぞれの子トークナイザーをサブディレクトリに保存するように変更することで、好きなだけたくさんのトークナイザー持ったプロセッサーを作成できるようになります。トークナイザーだけでなく画像のプロセッサーについても同様です。
注意点
カスタムコードを Huggingface Hub にアップする際に AutoProcessor
に register
する必要があるわけですが、この登録処理では クラス名に依存して処理が若干変わる罠 があり、それによって適切に from_pretraiend
できなくなることがあります。
具体的には、 save_pretrained
の途中で呼び出される custom_object_save
関数が保存するカスタムプロセッサーのクラス名に基づいた処理をしており、Tokenizer
が名前に含まれていると auto_map
に一般のトークナイザー用のオプションである slow_tokenizer
の枠を用意してしまい、from_pretrained
する際に auto_class への引数に余計なものが増えて正常に読み込めなくなります。 そのため、カスタムプロセッサーのクラス名には tokenizer
を含めないようにする必要があります。
当該処理:
まとめ
本記事では transformers ライブラリのカスタムプロセッサーとして、複数のトークナイザーを競合しないように扱う方法を紹介しました。
今後の LLM・LLM活用 Advent Calendar 2024 もお楽しみください。
関連情報
Discussion