
I Tried Adding My Own Model to Hugging Face Transformers

Published on 2024/12/09

This is the day-9 article of the "LLM・LLM活用 Advent Calendar 2024" (LLMs and LLM applications)!
https://qiita.com/advent-calendar/2024/large-language-model

Overview

I opened a Pull Request against transformers and added my own model architecture.

Introduction

Have you ever thought "I wish that model were in transformers", or built a groundbreaking architecture that you would like everyone to use? I recently added a new model to transformers, so I am writing down how it went.

Reading the Transformers code

Before adding our own model, let's understand transformers. We only need to read the model-related parts for now, so let's look at the Llama model as an example.

The Llama model definition lives in src/transformers/models/llama.

src/transformers/models/llama/
├── __init__.py
├── convert_llama_weights_to_hf.py
├── configuration_llama.py
├── modeling_flax_llama.py
├── modeling_llama.py
├── tokenization_llama_fast.py
└── tokenization_llama.py

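Of these, the two files that define the model itself are configuration_llama.py and modeling_llama.py. The others are the tokenizers, a script for converting Meta's official weights, and the JAX/Flax model file. (transformers actually supports not only PyTorch but also TensorFlow and JAX/Flax, at least partially.)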

configuration_llama.py

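This file defines LlamaConfig. transformers provides general base classes such as PretrainedConfig and PreTrainedModel. They already implement methods such as from_pretrained() and save_pretrained(), so configs and models can be saved and loaded with no extra work. The standard practice in transformers is therefore to subclass them.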

https://github.com/huggingface/transformers/blob/329f5dbf97a5cb2473914c88c05aa3dcb242e19a/src/transformers/models/llama/configuration_llama.py#L155-L216

It is the same kind of config you set when pretraining a model: it simply holds the hyperparameters the model needs.
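As a minimal sketch of the save/load behaviour described above, assuming a recent transformers release: the class name `MyDiffLlamaConfig`, its `model_type`, and the `lambda_std_dev` default are hypothetical illustrations (the attention code later in this article reads `config.lambda_std_dev`), not the real DiffLlamaConfig from the PR. Because it subclasses an existing config, `save_pretrained()` and `from_pretrained()` work without any extra code.

```python
import tempfile

from transformers import LlamaConfig


class MyDiffLlamaConfig(LlamaConfig):
    # Hypothetical config: reuses every Llama hyperparameter and adds one new field.
    model_type = "my_diffllama"

    def __init__(self, lambda_std_dev=0.1, **kwargs):
        # Std-dev for initialising the lambda_q/lambda_k vectors (assumed default, for illustration).
        self.lambda_std_dev = lambda_std_dev
        super().__init__(**kwargs)


config = MyDiffLlamaConfig(hidden_size=128, num_attention_heads=4, lambda_std_dev=0.05)

with tempfile.TemporaryDirectory() as tmp_dir:
    config.save_pretrained(tmp_dir)  # writes tmp_dir/config.json
    loaded = MyDiffLlamaConfig.from_pretrained(tmp_dir)

print(loaded.hidden_size, loaded.lambda_std_dev)  # 128 0.05
```

Only the fields that differ from the class defaults end up in config.json; everything else is filled back in from the defaults when loading.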

modeling_llama.py

(I ran out of time for this part, so it will come in a later update.)

Adding the model

I followed How to add a model to 🤗 Transformers?. The goal is DiffLlama: a Llama model whose attention is replaced by the attention from the Differential Transformer.
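For reference, the differential attention that DiffLlama implements can be summarised as follows, following the Differential Transformer paper. The symbols are chosen to match the code later in this article: l is the 0-indexed `layer_idx`, d is `head_dim`, Q_1, Q_2 (and K_1, K_2) are the two halves of the query (key) heads, each of width d, V has width 2d, and RMSNorm stands for the per-head `groupnorm`.

```latex
\mathrm{DiffAttn}(X) = \left( \operatorname{softmax}\!\left(\frac{Q_1 K_1^{\top}}{\sqrt{d}}\right)
  - \lambda \, \operatorname{softmax}\!\left(\frac{Q_2 K_2^{\top}}{\sqrt{d}}\right) \right) V

\lambda = \exp(\lambda_{q_1} \cdot \lambda_{k_1}) - \exp(\lambda_{q_2} \cdot \lambda_{k_2}) + \lambda_{\text{init}},
\qquad \lambda_{\text{init}} = 0.8 - 0.6 \exp(-0.3\, l)

\text{head} = (1 - \lambda_{\text{init}}) \cdot \operatorname{RMSNorm}\big(\mathrm{DiffAttn}(X)\big)
```

In the code, λ corresponds to `lambda_full`, λ_init to `self.lambda_init` (computed by `lambda_init_fn`), and the final scaling to `(1 - self.lambda_init) * self.groupnorm(...)`.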

Setting up the environment

First, fork transformers and clone your fork locally.

# after forking
git clone https://github.com/[your Github handle]/transformers.git
cd transformers
git remote add upstream https://github.com/huggingface/transformers.git

Next, set up a Python virtual environment.

python -m venv .env
source .env/bin/activate
# This will probably fail, so use the command below instead
#pip install -e ".[dev]"
pip install -e ".[quality]"
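A quick sanity check, assuming the editable install above succeeded: Python should resolve `transformers` to your clone rather than to a copy in site-packages. This sketch only uses the standard library's `importlib.util.find_spec`, which locates a package without importing it.

```python
import importlib.util

# Locate the transformers package without importing it.
spec = importlib.util.find_spec("transformers")
if spec is None:
    print("transformers is not installed in this environment")
else:
    # With `pip install -e`, this should point inside your clone,
    # e.g. .../transformers/src/transformers/__init__.py
    print(spec.origin)
```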

Adding the model code

Now you can finally add new code to 🤗 Transformers. Move into your clone of the 🤗 Transformers fork and run:

transformers-cli add-new-model-like

What is the model you would like to duplicate? Please provide the lowercase `model_type` (e.g. roberta): llama
What is the name (with no special casing) for your new model in the paper (e.g. RoBERTa)? DiffLlama
What identifier would you like to use for the `model_type` of this model?  [diffllama] 
What lowercase name would you like to use for the module (folder) of this model?  [diffllama] 
What prefix (camel-cased) would you like to use for the model classes of this model (e.g. Roberta)?  [DiffLlama] 
What prefix (upper-cased) would you like to use for the constants relative to this model?  [DIFFLLAMA] 
What will be the name of the config class for this model?  [DiffLlamaConfig] 
Please give a checkpoint identifier (on the model Hub) for this new model (e.g. facebook/FacebookAI/roberta-base): 
Will your new model use the same processing class as llama (LlamaTokenizerFast) (yes/no)? yes
Should we add # Copied from statements when creating the new modeling file (yes/no)?  [yes]    
Should we add a version of your new model in all the frameworks implemented by llama (['pt']) (yes/no)?  [yes] 
The tests for symbolic tracing with torch.fx were disabled, you can add those once symbolic tracing works for your new model.

Opening a pull request on the transformers repository

Before you start adapting the auto-generated code, it is time to open a "work in progress (WIP)" pull request on 🤗 Transformers, e.g. "[WIP] Add DiffLlama". This lets you and the Hugging Face team work in parallel on integrating the model into 🤗 Transformers.

Follow these steps:
1. Create a branch with a descriptive name from the main branch:

git checkout -b add_diffllama

2. Commit the auto-generated code:

git add .
git commit

3. Fetch the current main branch and rebase onto it:

git fetch upstream
git rebase upstream/main

4. Push the changes to your account with:

git push -u origin add_diffllama

5. Once you are happy with it, go to your fork's page on GitHub and click "Pull request". Add the GitHub handles of some Hugging Face team members as reviewers so that the team is notified of future changes.

6. Turn the PR into a draft by clicking "Convert to draft" on the right-hand side of the GitHub pull request page.

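The HF team members explain things carefully inside the PR, so if you get stuck, try mentioning them. (I believe there is a list somewhere of whom to ping for what, e.g. @ArthurZucker for model additions.)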

Hardcore coding time

The generated model files are in src/transformers/models/diffllama/, but do not edit them; instead, create modular_diffllama.py. We will use Modular Transformers.

How to use it

When a class is almost identical to an existing one (e.g. it only differs in its config class), inheriting and writing pass is enough:

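# imports for the whole modular_diffllama.py (including the attention classes below)
import math
from typing import Optional, Tuple

import torch
from torch import nn

from ...cache_utils import Cache, StaticCache
from ...modeling_flash_attention_utils import _flash_attention_forward
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
from ...utils import is_flash_attn_greater_or_equal_2_10, logging
from ..gemma.modeling_gemma import GemmaForCausalLM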
from ..llama.modeling_llama import (
    LlamaDecoderLayer,
    LlamaForQuestionAnswering,
    LlamaForSequenceClassification,
    LlamaForTokenClassification,
    LlamaModel,
    LlamaPreTrainedModel,
    LlamaRMSNorm,
    LlamaRotaryEmbedding,
    apply_rotary_pos_emb,
    repeat_kv,
)
from ..mistral.modeling_mistral import MistralMLP
from .configuration_diffllama import DiffLlamaConfig


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "DiffLlamaConfig"


class DiffLlamaRMSNorm(LlamaRMSNorm):
    pass


ALL_LAYERNORM_LAYERS.append(DiffLlamaRMSNorm)


class DiffLlamaRotaryEmbedding(LlamaRotaryEmbedding):
    pass


class DiffLlamaMLP(MistralMLP):
    pass
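To see what "inherit and pass" gives you, here is a minimal check, assuming torch and a transformers version that exposes `LlamaRMSNorm` in `transformers.models.llama.modeling_llama`. The class `ToyRMSNorm` is a hypothetical stand-in following the same pattern as `DiffLlamaRMSNorm(LlamaRMSNorm): pass`; it behaves exactly like its parent, which computes RMSNorm, x / sqrt(mean(x^2) + eps) * weight.

```python
import torch
from transformers.models.llama.modeling_llama import LlamaRMSNorm


class ToyRMSNorm(LlamaRMSNorm):  # hypothetical, same pattern as DiffLlamaRMSNorm(LlamaRMSNorm): pass
    pass


torch.manual_seed(0)
x = torch.randn(2, 3, 8)
norm = ToyRMSNorm(8, eps=1e-6)

with torch.no_grad():
    # RMSNorm by hand: divide each vector by the root-mean-square of its last dimension,
    # then scale by the learned weight (initialised to ones).
    manual = norm.weight * (x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6))
    same = torch.allclose(norm(x), manual, atol=1e-6)

print(same)  # True
```

With Modular Transformers, the converter then writes out the inherited code under the new name in modeling_diffllama.py, so the generated file stays self-contained.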

Classes whose implementation actually differs are simply written out in full, as below.

The attention part (very long)
def lambda_init_fn(layer_idx):
    return 0.8 - 0.6 * math.exp(-0.3 * layer_idx)


class DiffLlamaAttention(nn.Module):
    """Differential multi-headed attention from the 'Differential Transformer' paper"""

    def __init__(self, config: DiffLlamaConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        # max_position_embeddings and rope_theta below are not used in this module
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)

        self.lambda_init = lambda_init_fn(layer_idx)
        self.lambda_q1 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,)))
        self.lambda_k1 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,)))
        self.lambda_q2 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,)))
        self.lambda_k2 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,)))
        self.groupnorm = nn.RMSNorm(2 * self.head_dim, eps=config.rms_norm_eps, elementwise_affine=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, target_len, _ = hidden_states.size()
        q_len = target_len

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)
        value_states = torch.cat(torch.chunk(value_states, 2, dim=1), dim=-1)
        value_states = value_states.repeat(1, 2, 1, 1)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1, dtype=torch.float32)).to(
            query_states.dtype
        )
        lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1, dtype=torch.float32)).to(
            query_states.dtype
        )
        lambda_full = lambda_1 - lambda_2 + self.lambda_init

        attn_output = torch.matmul(attn_weights, value_states)
        attn_output1, attn_output2 = torch.chunk(attn_output, 2, dim=1)

        attn_output = attn_output1 - lambda_full * attn_output2
        attn_output = (1 - self.lambda_init) * self.groupnorm(attn_output)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, -1)

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value


class DiffLlamaFlashAttention2(DiffLlamaAttention):
    """
    DiffLlama flash attention module. This module inherits from `DiffLlamaAttention` as the weights of the module stays
    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
    flash attention and deal with padding tokens in case the input contains any of them.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if isinstance(past_key_value, StaticCache):
            raise ValueError(
                "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
                "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
            )

        output_attentions = False

        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # Flash attention expects the layout batch_size x seq_length x num_heads x head_dim.
        # We first use batch_size x num_heads x seq_length x head_dim for RoPE and the KV cache,
        # then transpose back further below.
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
        # to be able to avoid many of these transpose/reshape/view.
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        dropout_rate = self.attention_dropout if self.training else 0.0

        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
        # therefore the input hidden states gets silently casted in float32. Hence, we need
        # cast them back in the correct dtype just to be sure everything works as expected.
        # This might slowdown training & inference so it is recommended to not cast the LayerNorms
        # in fp32. (DiffLlamaRMSNorm handles it correctly)

        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            # Handle the case where the model is quantized
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype

            logger.warning_once(
                f"The input hidden states seems to be silently casted in float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        value_states1, value_states2 = torch.chunk(value_states, 2, dim=2)
        value_states1 = value_states1.repeat(1, 1, 2, 1)
        value_states2 = value_states2.repeat(1, 1, 2, 1)

        attn_output1 = _flash_attention_forward(
            query_states,
            key_states,
            value_states1,
            attention_mask,
            q_len,
            position_ids=position_ids,
            dropout=dropout_rate,
            sliding_window=getattr(self, "sliding_window", None),
            use_top_left_mask=self._flash_attn_uses_top_left_mask,
            is_causal=self.is_causal,
        )

        attn_output2 = _flash_attention_forward(
            query_states,
            key_states,
            value_states2,
            attention_mask,
            q_len,
            position_ids=position_ids,
            dropout=dropout_rate,
            sliding_window=getattr(self, "sliding_window", None),
            use_top_left_mask=self._flash_attn_uses_top_left_mask,
            is_causal=self.is_causal,
        )

        attn_output = torch.cat([attn_output1, attn_output2], dim=-1)
        attn_output1, attn_output2 = torch.chunk(attn_output, 2, dim=2)

        lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1, dtype=torch.float32)).to(
            query_states.dtype
        )
        lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1, dtype=torch.float32)).to(
            query_states.dtype
        )
        lambda_full = lambda_1 - lambda_2 + self.lambda_init

        attn_output = attn_output1 - lambda_full * attn_output2
        attn_output = (1 - self.lambda_init) * self.groupnorm(attn_output)
        attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value


class DiffLlamaSdpaAttention(DiffLlamaAttention):
    """
    DiffLlama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
    `DiffLlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
    SDPA API.
    """

    # Adapted from DiffLlamaAttention.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if output_attentions:
            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
            logger.warning_once(
                "DiffLlamaModel is using DiffLlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
            )

        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)
        value_states = torch.cat(torch.chunk(value_states, 2, dim=1), dim=-1)
        value_states = value_states.repeat(1, 2, 1, 1)

        causal_mask = attention_mask
        if attention_mask is not None:
            causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]

        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
        # Reference: https://github.com/pytorch/pytorch/issues/112577.
        if query_states.device.type == "cuda" and causal_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
        is_causal = True if causal_mask is None and q_len > 1 else False

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=causal_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            is_causal=is_causal,
        )

        attn_output1, attn_output2 = torch.chunk(attn_output, 2, dim=1)

        lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1, dtype=torch.float32)).to(
            query_states.dtype
        )
        lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1, dtype=torch.float32)).to(
            query_states.dtype
        )
        lambda_full = lambda_1 - lambda_2 + self.lambda_init

        attn_output = attn_output1 - lambda_full * attn_output2
        attn_output = (1 - self.lambda_init) * self.groupnorm(attn_output)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, -1)
        attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value
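The eager forward pass above is long, so here is a stripped-down sketch of the same computation for intuition only. It assumes no KV cache, no GQA (one key/value head per query head), no dropout, and uses a hand-written RMSNorm without a learnable weight in place of `nn.RMSNorm(..., elementwise_affine=False)`. The helpers `diff_attention` and `rms_norm` are hypothetical names; only `lambda_init_fn` matches the code above. It shows the trick used in the eager code: value heads i and i + num_heads // 2 are paired into one head of width 2 * head_dim, and the outputs of the second half of the query heads, scaled by lambda, are subtracted from those of the first half.

```python
import math

import torch


def lambda_init_fn(layer_idx: int) -> float:
    # Same schedule as in modular_diffllama.py: deeper layers start with a larger lambda.
    return 0.8 - 0.6 * math.exp(-0.3 * layer_idx)


def rms_norm(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # RMSNorm over the last dimension without a learnable weight.
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)


def diff_attention(q, k, v, lambda_full, lambda_init):
    """q, k, v: (batch, num_heads, seq_len, head_dim); num_heads must be even."""
    bsz, num_heads, seq_len, head_dim = q.shape
    # Pair value heads i and i + num_heads // 2 into one head of width 2 * head_dim,
    # then tile so that every query head has a value head: (batch, num_heads, seq_len, 2 * head_dim).
    v = torch.cat(torch.chunk(v, 2, dim=1), dim=-1).repeat(1, 2, 1, 1)

    scores = q @ k.transpose(-2, -1) / math.sqrt(head_dim)
    causal = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
    weights = scores.masked_fill(causal, float("-inf")).softmax(dim=-1)

    out = weights @ v
    out1, out2 = torch.chunk(out, 2, dim=1)  # query heads [0, H/2) and [H/2, H)
    out = out1 - lambda_full * out2          # differential attention
    out = (1 - lambda_init) * rms_norm(out)  # per-head normalisation, then fixed scaling
    # (batch, num_heads // 2, seq_len, 2 * head_dim) -> (batch, seq_len, num_heads * head_dim)
    return out.transpose(1, 2).reshape(bsz, seq_len, num_heads * head_dim)


torch.manual_seed(0)
bsz, num_heads, seq_len, head_dim = 2, 4, 5, 8
q, k, v = (torch.randn(bsz, num_heads, seq_len, head_dim) for _ in range(3))

# lambda = exp(lq1 . lk1) - exp(lq2 . lk2) + lambda_init, with small random vectors as in __init__
lambda_q1, lambda_k1, lambda_q2, lambda_k2 = (torch.normal(0, 0.1, size=(head_dim,)) for _ in range(4))
lambda_init = lambda_init_fn(layer_idx=3)
lambda_full = torch.exp((lambda_q1 * lambda_k1).sum()) - torch.exp((lambda_q2 * lambda_k2).sum()) + lambda_init

out = diff_attention(q, k, v, lambda_full, lambda_init)
print(out.shape)  # torch.Size([2, 5, 32])
```

The output width equals num_heads * head_dim, the same as standard attention, which is why the existing o_proj shape can be kept unchanged.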

Once modular_diffllama.py is complete, you can generate modeling_diffllama.py automatically with python utils/modular_model_converter.py --files_to_parse src/transformers/models/diffllama/modular_diffllama.py.

Testing the model

At this point, the new model has been added. However, it is very likely that it does not yet fully conform to the required design. To make sure it is fully compatible with 🤗 Transformers, all the common tests must pass. The add-new-model-like command should already have generated a test file at tests/models/diffllama/test_modeling_diffllama.py. Run this test file and check that all the common tests pass:
pytest tests/models/diffllama/test_modeling_diffllama.py

Adding documentation

I am not very familiar with this part, so refer to section "11. Adding documentation" of the guide.

Refactoring the code

Finally, clean up the code.

make style
make quality

Once you have come this far, just follow the instructions given in the PR.

Closing thoughts

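You have created another model that anyone in the community can easily access! 🤯
I opened the PR on October 11, after hearing about the Differential Transformer paper, and while juggling other work I finally reached what I hope is the final review a few days ago. It is probably easier than developing a model from scratch, but it was a lot of work. On the other hand, adding a model helped me understand the transformers library I use every day. Please try adding your own model architecture too!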
