Adding My Own Model to Hugging Face Transformers
This is the Day 9 entry in the "LLM・LLM活用 Advent Calendar 2024"!
Overview
I opened a pull request against transformers and added my own model architecture.
Introduction
Have you ever wished a particular model were available in transformers, or built a groundbreaking architecture that you want everyone to be able to use? That was my situation, so I added a new model to transformers, and this post is a record of the process.
Reading the Transformers Codebase
Before adding our own model, let's first understand transformers. For our purposes it is enough to read the model-related code, so let's walk through the Llama model as an example.
The Llama model definition lives in src/transformers/models/llama:
src/transformers/models/llama/
├── __init__.py
├── convert_llama_weights_to_hf.py
├── configuration_llama.py
├── modeling_flax_llama.py
├── modeling_llama.py
├── tokenization_llama_fast.py
└── tokenization_llama.py
Of these, the essential files are configuration_llama.py and modeling_llama.py. The rest are the script for converting Meta's official weights, the Jax/Flax model file, and the tokenizer implementations. (transformers actually has partial support not just for PyTorch but also for TF and Jax/Flax.)
configuration_llama.py
This file defines LlamaConfig. transformers provides general base classes such as PretrainedConfig and PreTrainedModel. They already implement methods like from_pretrained(), so saving and loading is easy, and inheriting from them is the standard pattern in transformers. This is the same config you always set when pretraining a model: it simply holds the hyperparameters the model needs.
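To make this concrete, here is a minimal sketch of a config class (my own illustration, not the actual LlamaConfig source; the class name and model_type are hypothetical): inherit PretrainedConfig, declare a model_type, store the hyperparameters, and save_pretrained()/from_pretrained() work out of the box.
from transformers import PretrainedConfig


class MyTinyLlamaConfig(PretrainedConfig):
    model_type = "my_tiny_llama"  # hypothetical identifier

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=4096,
        num_hidden_layers=32,
        num_attention_heads=32,
        **kwargs,
    ):
        # Just store the hyperparameters; PretrainedConfig handles (de)serialization.
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        super().__init__(**kwargs)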
modeling_llama.py
(I ran out of time, so this part will be added later.)
Adding a Model
We proceed by following How to add a model to 🤗 Transformers?. The goal is DiffLlama: a Llama model whose attention is replaced with the attention from Differential Transformer.
Setting Up the Environment
First, fork transformers and clone it locally.
# after forking
git clone https://github.com/[your Github handle]/transformers.git
cd transformers
git remote add upstream https://github.com/huggingface/transformers.git
Next, set up a Python virtual environment.
python -m venv .env
source .env/bin/activate
# this will probably fail, so use the command below instead
#pip install -e ".[dev]"
pip install -e ".[quality]"
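To double-check that the editable install is actually the one being picked up (my own sanity check, not part of the official guide), something like this works:
python -c "import transformers; print(transformers.__version__, transformers.__file__)"
# the printed path should point inside your clone, and the version typically ends in .dev0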
Adding the Model Code
Now you can finally add new code to 🤗 Transformers. Move into the clone of your 🤗 Transformers fork and run:
transformers-cli add-new-model-like
What is the model you would like to duplicate? Please provide the lowercase `model_type` (e.g. roberta): llama
What is the name (with no special casing) for your new model in the paper (e.g. RoBERTa)? DiffLlama
What identifier would you like to use for the `model_type` of this model? [diffllama]
What lowercase name would you like to use for the module (folder) of this model? [diffllama]
What prefix (camel-cased) would you like to use for the model classes of this model (e.g. Roberta)? [DiffLlama]
What prefix (upper-cased) would you like to use for the constants relative to this model? [DIFFLLAMA]
What will be the name of the config class for this model? [DiffLlamaConfig]
Please give a checkpoint identifier (on the model Hub) for this new model (e.g. facebook/FacebookAI/roberta-base):
Will your new model use the same processing class as llama (LlamaTokenizerFast) (yes/no)? yes
Should we add # Copied from statements when creating the new modeling file (yes/no)? [yes]
Should we add a version of your new model in all the frameworks implemented by llama (['pt']) (yes/no)? [yes]
The tests for symbolic tracing with torch.fx were disabled, you can add those once symbolic tracing works for your new model.
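When the command finishes, it should have generated a skeleton roughly like the following (reconstructed from memory, so the exact file set may vary by version) and registered the model in the auto classes and the documentation index:
src/transformers/models/diffllama/
├── __init__.py
├── configuration_diffllama.py
└── modeling_diffllama.py
tests/models/diffllama/
└── test_modeling_diffllama.py
docs/source/en/model_doc/diffllama.md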
Opening a Pull Request on the transformers Repository
Before you start adapting the auto-generated code, it is time to open a "work in progress (WIP)" pull request against 🤗 Transformers, e.g. "[WIP] Add DiffLlama". This lets you and the Hugging Face team work in parallel on integrating the model into 🤗 Transformers.
Follow these steps:
1. Create a branch with a descriptive name off the main branch.
git checkout -b add_diffllama
2. Commit the auto-generated code:
git add .
git commit
3. Fetch and rebase onto the current main branch:
git fetch upstream
git rebase upstream/main
4. Push the changes to your fork:
git push -u origin add_diffllama
5. Once you're happy with it, go to the web page of your fork on GitHub and click "Pull request". Add the GitHub handle of a Hugging Face team member as a reviewer so they will be notified of future changes.
6. Convert the PR into a draft by clicking "Convert to draft" on the right side of the GitHub pull request page.
The Hugging Face folks will guide you carefully inside the PR, so mention them whenever you get stuck. (There was a list somewhere saying whom to ping; for model additions I believe it was @ArthurZucker.)
Heads-Down Coding Time
The model files live under src/transformers/models/diffllama/, but instead of editing them directly, create modular_diffllama.py. We will use Modular Transformers.
How to use it
When a class is almost identical to an existing one (for example, it differs only in the config class), inheriting and writing pass is all you need:
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
from ...utils import logging
from ..gemma.modeling_gemma import GemmaForCausalLM
from ..llama.modeling_llama import (
LlamaDecoderLayer,
LlamaForQuestionAnswering,
LlamaForSequenceClassification,
LlamaForTokenClassification,
LlamaModel,
LlamaPreTrainedModel,
LlamaRMSNorm,
LlamaRotaryEmbedding,
apply_rotary_pos_emb,
repeat_kv,
)
from ..mistral.modeling_mistral import MistralMLP
from .configuration_diffllama import DiffLlamaConfig
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "DiffLlamaConfig"
class DiffLlamaRMSNorm(LlamaRMSNorm):
pass
ALL_LAYERNORM_LAYERS.append(DiffLlamaRMSNorm)
class DiffLlamaRotaryEmbedding(LlamaRotaryEmbedding):
pass
class DiffLlamaMLP(MistralMLP):
pass
When you actually need a new implementation, you can simply write it out as usual, as in the attention code below.
Attention code (it's long)
import math
from typing import Optional, Tuple

import torch
import torch.nn as nn

from ...cache_utils import Cache, StaticCache
from ...modeling_flash_attention_utils import _flash_attention_forward
from ...utils import is_flash_attn_greater_or_equal_2_10


def lambda_init_fn(layer_idx):
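    # lambda_init from the Differential Transformer paper: 0.8 - 0.6 * exp(-0.3 * l), with a 0-based layer index here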
return 0.8 - 0.6 * math.exp(-0.3 * layer_idx)
class DiffLlamaAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: DiffLlamaConfig, layer_idx: Optional[int] = None):
super().__init__()
self.config = config
self.layer_idx = layer_idx
if layer_idx is None:
logger.warning_once(
f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
"lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
"when creating this class."
)
self.attention_dropout = config.attention_dropout
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        # the following are not used
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.is_causal = True
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
self.lambda_init = lambda_init_fn(layer_idx)
self.lambda_q1 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,)))
self.lambda_k1 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,)))
self.lambda_q2 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,)))
self.lambda_k2 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,)))
self.groupnorm = nn.RMSNorm(2 * self.head_dim, eps=config.rms_norm_eps, elementwise_affine=False)
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, target_len, _ = hidden_states.size()
q_len = target_len
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
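        # Differential attention: regroup the value heads so that the first and second halves of the
        # attention heads act as the two softmax attention maps over the same set of values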
value_states = torch.cat(torch.chunk(value_states, 2, dim=1), dim=-1)
value_states = value_states.repeat(1, 2, 1, 1)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1, dtype=torch.float32)).to(
query_states.dtype
)
lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1, dtype=torch.float32)).to(
query_states.dtype
)
lambda_full = lambda_1 - lambda_2 + self.lambda_init
attn_output = torch.matmul(attn_weights, value_states)
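        # Split the heads back into the two attention maps and take their lambda-weighted difference
        # (the core of differential attention)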
attn_output1, attn_output2 = torch.chunk(attn_output, 2, dim=1)
attn_output = attn_output1 - lambda_full * attn_output2
attn_output = (1 - self.lambda_init) * self.groupnorm(attn_output)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, -1)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
class DiffLlamaFlashAttention2(DiffLlamaAttention):
"""
DiffLlama flash attention module. This module inherits from `DiffLlamaAttention` as the weights of the module stays
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if isinstance(past_key_value, StaticCache):
raise ValueError(
"`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
"make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
)
output_attentions = False
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
# Flash attention requires the input to have the shape
# batch_size x seq_length x head_dim x hidden_dim
# therefore we just need to keep the original shape
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
if position_embeddings is None:
logger.warning_once(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
"removed and `position_embeddings` will be mandatory."
)
cos, sin = self.rotary_emb(value_states, position_ids)
else:
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
# to be able to avoid many of these transpose/reshape/view.
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
dropout_rate = self.attention_dropout if self.training else 0.0
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need
# cast them back in the correct dtype just to be sure everything works as expected.
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
# in fp32. (DiffLlamaRMSNorm handles it correctly)
input_dtype = query_states.dtype
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
else:
target_dtype = self.q_proj.weight.dtype
logger.warning_once(
f"The input hidden states seems to be silently casted in float32, this might be related to"
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f" {target_dtype}."
)
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
value_states1, value_states2 = torch.chunk(value_states, 2, dim=2)
value_states1 = value_states1.repeat(1, 1, 2, 1)
value_states2 = value_states2.repeat(1, 1, 2, 1)
attn_output1 = _flash_attention_forward(
query_states,
key_states,
value_states1,
attention_mask,
q_len,
position_ids=position_ids,
dropout=dropout_rate,
sliding_window=getattr(self, "sliding_window", None),
use_top_left_mask=self._flash_attn_uses_top_left_mask,
is_causal=self.is_causal,
)
attn_output2 = _flash_attention_forward(
query_states,
key_states,
value_states2,
attention_mask,
q_len,
position_ids=position_ids,
dropout=dropout_rate,
sliding_window=getattr(self, "sliding_window", None),
use_top_left_mask=self._flash_attn_uses_top_left_mask,
is_causal=self.is_causal,
)
attn_output = torch.cat([attn_output1, attn_output2], dim=-1)
attn_output1, attn_output2 = torch.chunk(attn_output, 2, dim=2)
lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1, dtype=torch.float32)).to(
query_states.dtype
)
lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1, dtype=torch.float32)).to(
query_states.dtype
)
lambda_full = lambda_1 - lambda_2 + self.lambda_init
attn_output = attn_output1 - lambda_full * attn_output2
attn_output = (1 - self.lambda_init) * self.groupnorm(attn_output)
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
class DiffLlamaSdpaAttention(DiffLlamaAttention):
"""
DiffLlama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
`DiffLlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
SDPA API.
"""
# Adapted from DiffLlamaAttention.forward
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
logger.warning_once(
"DiffLlamaModel is using DiffLlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
value_states = torch.cat(torch.chunk(value_states, 2, dim=1), dim=-1)
value_states = value_states.repeat(1, 2, 1, 1)
causal_mask = attention_mask
if attention_mask is not None:
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if query_states.device.type == "cuda" and causal_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
is_causal = True if causal_mask is None and q_len > 1 else False
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=causal_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
is_causal=is_causal,
)
attn_output1, attn_output2 = torch.chunk(attn_output, 2, dim=1)
lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1, dtype=torch.float32)).to(
query_states.dtype
)
lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1, dtype=torch.float32)).to(
query_states.dtype
)
lambda_full = lambda_1 - lambda_2 + self.lambda_init
attn_output = attn_output1 - lambda_full * attn_output2
attn_output = (1 - self.lambda_init) * self.groupnorm(attn_output)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, -1)
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
After finishing modular_diffllama.py, you can auto-generate modeling_diffllama.py with python utils/modular_model_converter.py --files_to_parse src/transformers/models/diffllama/modular_diffllama.py.
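Once the file has been generated, a quick smoke test like the following (my own sketch; the tiny config values are arbitrary and assume DiffLlamaConfig already defines the extra hyperparameters such as lambda_std_dev) is handy for confirming that the classes at least instantiate and run a forward pass:
import torch

from transformers.models.diffllama.configuration_diffllama import DiffLlamaConfig
from transformers.models.diffllama.modeling_diffllama import DiffLlamaForCausalLM

# A deliberately tiny config so this runs on CPU in a few seconds.
config = DiffLlamaConfig(
    vocab_size=1000,
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=2,
)
model = DiffLlamaForCausalLM(config)
input_ids = torch.randint(0, config.vocab_size, (1, 8))
out = model(input_ids=input_ids)
print(out.logits.shape)  # expected: torch.Size([1, 8, 1000])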
Testing the Model
At this point, the new model has been successfully added. However, it is very likely that it does not yet fully conform to the required design. To make sure it is fully compatible with 🤗 Transformers, all of the common tests must pass. The add-new-model-like command should have automatically created a test file for the model at tests/models/diffllama/test_modeling_diffllama.py. Run it and confirm that all the common tests pass:
pytest tests/models/diffllama/test_modeling_diffllama.py
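While iterating, it is often convenient to run only a subset of the tests with pytest's -k filter (the test class name below is just my guess at what the template generates):
pytest tests/models/diffllama/test_modeling_diffllama.py -k "DiffLlamaModelTest" -v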
Adding Documentation
I wasn't entirely sure about this part, so refer to step 11, "Add the documentation", in the official guide.
Refactoring the Code
Finally, clean up the code.
make style
make quality
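If I remember correctly, the repository also provides make fixup, which only reformats and checks the files you actually changed and is therefore much faster than the full style/quality targets:
make fixup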
Once you've gotten this far, just follow the instructions in the PR.
Closing Thoughts
You have created another model that anyone in the community can easily access! 🤯
I opened the PR on October 11, not long after I first heard about the Differential Transformer paper, and while juggling other work I finally reached what I hope is the final review a few days ago. It is surely easier than developing a model from scratch, but it was still a lot of work. That said, in the course of adding the model I came to understand the transformers library I use every day. I encourage you to try adding your own model architecture as well.