📚

BERTの実装を分解する

2023/02/04に公開

以下の書籍を読んだ際にBERTの実装をコメントを付けながら読み進めていきましたので、まとめます。新卒で入社した会社で深層学習を勉強していたときに大変お世話になった書籍です。自信を持っておすすめします。
https://www.amazon.co.jp/dp/4839970254/
今回はモデルの説明だけでも説明量が多い関係から、optimizerやDatasetLoader、推論部の説明は割愛します。ご了承ください。

BERTとは

Bidirectional Encoder Representations from Transformersの略です。Transformerからなる双方向エンコーダによる表現学習の手法です。論文の解説は以下のような詳細な解説記事があるため割愛します。
https://qiita.com/omiita/items/72998858efc19a368e50

動作環境

  • Ubuntu20.04
  • Python 3.8.10
  • torch 1.7.1

実装

BERTクラスを実装します.最終的に実装するBERTクラスは複数のクラスから構成されています。先に構成クラスの構造を図解します。huggingfaceで2018年時点で公開されていたpretrained-transforrmersを基に実装されています。

  • class BERT
    • class BertEmbeddings:単語・文情報埋め込み部
      • class BertLayerNorm
    • class BertEncoder:単語ベクトルのエンコード部
      • class BertLayer × N層
        • class BertAttention:self-attention計算部
          • class BertSelfAttention
          • class BertSelfOutput
            • class BertLayerNorm
        • class BertIntermediate:Feed-Forward Network1計算部
        • class BertOutput:Feed-Forward Network2計算部
          • class BertLayerNorm
    • class BertPooler:特徴ベクトルの抽出部

全体実装 model.py 521行
import math
import torch
import torch.nn.functional as F
from torch import nn


class BertLayerNorm(nn.Module):
    """レイヤ正規化クラス"""
    def __init__(self, hidden_size: float, eps: float=1e-6):
        super(BertLayerNorm, self).__init__()
        self.gamma = nn.Parameter(torch.ones(hidden_size))
        self.beta  = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps
        return

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """レイヤ正規化

        Args:
            x (torch.Tensor): 残差接続された後の特徴ベクトル(Post-LN方式)

        Returns:
            torch.Tensor: レイヤ正規化後特徴ベクトル
        """
        u = x.mean(-1, keepdim=True)
        s = (x - u).pow(2).mean(-1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.variance_epsilon)
        return self.gamma * x + self.beta


class BertEmbeddings(nn.Module):
    '''単語埋め込みクラス'''
    def __init__(self, config: dict):
        super(BertEmbeddings, self).__init__()

        self.config = config

        self.src_word_enc = nn.Embedding(
            self.config.vocab_size,
            self.config.hidden_size,
            padding_idx=2)

        self.src_pos_enc = nn.Embedding(
            config.max_position_embeddings,
            config.hidden_size,
            padding_idx=2)

        self.token_type_embeddings = nn.Embedding(
            config.type_vocab_size,
            config.hidden_size)

        self.LayerNorm = BertLayerNorm(
            config.hidden_size,
            eps=1e-6)

        self.dropout = nn.Dropout(
            config.hidden_dropout_prob)
        return

    def forward(self, input_ids: torch.Tensor, token_type_ids: torch.Tensor=None) -> torch.Tensor:
        """単語埋め込み

        Args:
            input_ids (torch.Tensor): idに変換した入力単語
            token_type_ids (torch.Tensor, optional): _description_. Defaults to None.

        Returns:
            torch.Tensor: 単語埋め込みベクトル
        """
        # 単語埋め込み
        src_word_emb = self.src_word_enc(input_ids)

        # 埋め込み後ベクトルの中間層の次元で調整
        # ※Transformerの原著論文で実装されているが根拠不明
        if self.config.scale_emb:
            src_word_emb *= self.config.hidden_size ** 0.5

        # 文章情報埋め込み
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        # 単語位置情報埋め込み
        position_ids = torch.arange(
            input_ids.size(1),
            dtype=torch.long,
            device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        src_pos_emb = self.src_pos_enc(position_ids)

        # 全埋め込み結果を加算しレイヤ正規化とドロップアウトを適用
        embeddings = src_word_emb + src_pos_emb + token_type_embeddings
        embeddings = self.dropout(self.LayerNorm(embeddings))

        return embeddings


class BertLayer(nn.Module):
    '''Self-Attention + Feed-Forwardクラス'''
    def __init__(self, config: dict):
        super(BertLayer, self).__init__()

        self.src_attn = BertAttention(config)
        self.intrmed = BertIntermediate(config)
        self.src_out = BertOutput(config)
        return

    def forward(self, src_hidden: torch.Tensor, src_attn_mask: torch.Tensor=None, src_attn_flg=False) -> torch.Tensor:
        """self-attention層+全結合層+残差接続

        Args:
            src_hidden (torch.Tensor): 特徴ベクトル
            src_attn_mask (_type_, optional): マスク. Defaults to None.
            src_attn_flg (bool, optional): attentionマップを出力するか否か. Defaults to False.

        Returns:
            torch.Tensor: self-attentionを通過した特徴ベクトル
        """
        # self-attentionの計算,attention mapも返す
        if src_attn_flg:
            src_attn_out, src_attn_prb = self.src_attn(
                src_hidden,
                src_attn_mask,
                src_attn_flg)

            # Feed Forward,残差接続2
            intermediate_output = self.intrmed(src_attn_out)

            src_layer_output = self.src_out(
                intermediate_output,
                src_attn_out)

            return src_layer_output, src_attn_prb

        # self-attentionの計算結果のみ返す
        elif not src_attn_flg:
            src_attn_out = self.src_attn(
                src_hidden,
                src_attn_mask,
                src_attn_flg)

            # Feed Forward,残差接続2
            intermediate_output = self.intrmed(src_attn_out)

            src_layer_output = self.src_out(
                intermediate_output,
                src_attn_out)

            return src_layer_output


class BertAttention(nn.Module):
    '''Multi-head attention + 全結合 + FFN後の残差接続 クラス'''
    def __init__(self, config: dict):
        super(BertAttention, self).__init__()
        self.src_selfattn = BertSelfAttention(config)
        self.src_output = BertSelfOutput(config)
        return

    def forward(self, hidden_state: torch.Tensor, attention_mask: torch.Tensor, attention_show_flg=False) -> torch.Tensor:
        """Multi-head attention + 全結合 + FFN後の残差接続(FFN後残差接続)

        Args:
            hidden_state (torch.Tensor): 特徴ベクトル
            attention_mask (torch.Tensor): マスク
            attention_show_flg (bool, optional): attentionマップを出力するか否か. Defaults to False.

        Returns:
            torch.Tensor: self-attention通過後の特徴ベクトル
        """
        # self-attentionの計算,attention mapも返す
        if attention_show_flg:
            self_output, attention_probs = self.src_selfattn(
                hidden_state,
                attention_mask,
                attention_show_flg)
            # チャネル変換,FFN後の残差接続
            attention_output = self.src_output(self_output, hidden_state)

            return attention_output, attention_probs

        # self-attentionの計算結果のみ返す
        elif not attention_show_flg:
            self_output = self.src_selfattn(
                hidden_state,
                attention_mask,
                attention_show_flg)
            # チャネル変換,FFN後の残差接続
            attention_output = self.src_output(self_output, hidden_state)

            return attention_output


class BertSelfAttention(nn.Module):
    '''scaled dot-product attention(SDA) + multi-head attention(MHA)クラス'''
    def __init__(self, config: dict):
        super(BertSelfAttention, self).__init__()

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        return

    def split_heads(self, x: torch.Tensor) -> torch.Tensor:
        """MHA(multi-head attention)用にテンソルの形を変換
        [batch_size, seq_len, hidden] → [batch_size, 12, seq_len, hidden/12]

        Args:
            x (torch.Tensor): 線形変換 + 活性化関数通過後のquery又はkey又はvalue

        Returns:
            torch.Tensor: num_attention_heads毎に分割された特徴ベクトル
        """
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_state: torch.Tensor, attention_mask: torch.Tensor=None,
     attention_show_flg=False) -> torch.Tensor:
        """scaled dot-product attention(SDA) + multi-head attention(MHA)

        Args:
            hidden_state (torch.Tensor): multi-headに分割される前の特徴ベクトル
            attention_mask (torch.Tensor, optional): マスク. Defaults to None.
            attention_show_flg (bool, optional): attentionマップを出力するか否か. Defaults to False.

        Returns:
            torch.Tensor: SDAとMHAを経て分割されたheadsを統合した特徴ベクトル
        """

        # 入力を全結合層で特徴量変換 ※MHAの全ヘッドをまとめて変換
        mixed_query_layer = self.query(hidden_state)
        mixed_key_layer = self.key(hidden_state)
        mixed_value_layer = self.value(hidden_state)

        # multi-headに分割
        query_layer = self.split_heads(mixed_query_layer)
        key_layer = self.split_heads(mixed_key_layer)
        value_layer = self.split_heads(mixed_value_layer)

        # 特徴ベクトル同士の類似度を行列の内積で求める
        attention_scores = torch.matmul( query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        # padding部分をゼロにするマスク処理 ※後でSoftmaxをかける為マスクには-10000(=正規化時に一番小さくなる値)を代入
        attention_mask = (1-attention_mask) * -1e4
        attention_scores = attention_scores + attention_mask

        # 正規化後,ドロップアウト
        attention_probs = nn.Softmax(dim=-1)(attention_scores)
        attention_probs = self.dropout(attention_probs)

        # attention_probsとvalue_layerで行列の積
        context_layer = torch.matmul(attention_probs, value_layer)

        # multi-head Attentionのheadsを統合
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        # self-attention結果とattention mapも返す
        if attention_show_flg:
            return context_layer, attention_probs

        # self-attention結果のみ返す
        elif not attention_show_flg:
            return context_layer


class BertSelfOutput(nn.Module):
    """(線形変換 + 活性化関数) + 残差接続クラス(MHA後残差接続)"""
    def __init__(self, config: dict):
        super(BertSelfOutput, self).__init__()

        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-6)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        return

    def forward(self, hidden_state: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        """線形変換 + 活性化関数通過後の残差接続(MHA後残差接続)

        Args:
            hidden_state (torch.Tensor): SDA+MHA通過後の特徴ベクトル
            input_tensor (torch.Tensor): SDA+MHA通過前の特徴ベクトル

        Returns:
            torch.Tensor: 線形変換 + 活性化関数 + 残差接続後の特徴ベクトル
        """
        hidden_state = self.dropout(self.dense(hidden_state))
        hidden_state = self.LayerNorm(hidden_state + input_tensor)
        return hidden_state


def gelu(x: torch.Tensor) -> torch.Tensor:
    '''活性化関数gelu(Gaussian Error Linear Unit)
    0付近でなめらかなReLUのような関数)'''
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))


def mish(x: torch.Tensor) -> torch.Tensor:
    '''活性化関数Mish(ImageNetコンペで高い精度を出した活性化関数, 2020)
    0付近でなめらかな関数'''
    return x * torch.tanh(F.softplus(x))


def Swish(x: torch.Tensor) -> torch.Tensor:
    '''活性化関数Swish(Mishの前身となった活性化関数, 0付近でなめらかな関数)'''
    return x * torch.sigmoid(x)


class BertIntermediate(nn.Module):
    '''線形変換 + 活性化関数クラス'''
    def __init__(self, config: dict):
        super(BertIntermediate, self).__init__()

        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        self.intermediate_act_fn = gelu
        # self.intermediate_act_fn = mish
        return

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """線形変換(中間層サイズに次元を拡張) + 活性化関数

        Args:
            hidden_states (torch.Tensor): SDA+MHA通過後の特徴ベクトル

        Returns:
            torch.Tensor: SDA+MHA通過後に次元拡張した特徴ベクトル
        """
        hidden_states = self.intermediate_act_fn(self.dense(hidden_states))
        return hidden_states


class BertOutput(nn.Module):
    """(線形変換(元の次元数に圧縮) + レイヤ正規化) + 残差接続クラス"""
    def __init__(self, config):
        super(BertOutput, self).__init__()

        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-6)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        return

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        """(線形変換(元の次元数に圧縮) + レイヤ正規化) + 残差接続

        Args:
            hidden_states (torch.Tensor): SDA+MHA通過+次元拡張した特徴ベクトル
            input_tensor (torch.Tensor) : SDA+MHA通過+次元圧縮した特徴ベクトル

        Returns:
            torch.Tensor: FFN通過後に残差接続された特徴ベクトル(FFN後残差接続)
        """
        hidden_states = self.dropout(self.dense(hidden_states))
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class BertEncoder(nn.Module):
    '''エンコーダクラス(指定した数だけBertLayerを何層積み重ねる)'''
    def __init__(self, config):
        super(BertEncoder, self).__init__()

        self.layer = nn.ModuleList(
            [BertLayer(config) for _ in range(config.num_hidden_layers)])
        return

    def forward(self, hidden_states: torch.Tensor,
                attention_mask: torch.Tensor=None,
                output_all_encoded_layers: bool=False,
                attention_show_flg: bool=False):
        """BERTのエンコード処理

        Args:
            hidden_states (torch.Tensor): 単語埋め込み後の特徴ベクトル
            attention_mask (torch.Tensor, optional): マスク. Defaults to None.
            output_all_encoded_layers (bool, optional): 全ての層の出力を返すか否か. Defaults to False.
            attention_show_flg (bool, optional): attentionマップを出力するか否か. Defaults to False.

        Returns:
            _type_: エンコードされた特徴ベクトル
        """

        all_encoder_layers = []

        # config.num_hidden_layers分だけBertLayerモジュールを繰り返し
        for layer_module in self.layer:

            # self-attention結果とattention mapを返す
            if attention_show_flg:
                hidden_states, attention_probs = layer_module(
                    hidden_states,
                    attention_mask,
                    attention_show_flg)

            # self-attention結果のみ返す
            elif not attention_show_flg:
                hidden_states = layer_module(
                    hidden_states,
                    attention_mask,
                    attention_show_flg)

            # config.num_hidden_layers分の全ての層を返す
            if output_all_encoded_layers:
                all_encoder_layers.append(hidden_states)

        # 最終層の結果のみ返す
        if not output_all_encoded_layers:
            all_encoder_layers.append(hidden_states)

        # self-attention結果とattention mapを返す
        if attention_show_flg:
            return all_encoder_layers, attention_probs

        # self-attention結果のみ返す
        elif not attention_show_flg:
            return all_encoder_layers


class BertPooler(nn.Module):
    """入力文章の1単語目[cls]の特徴量を線形変換して保持するクラス"""
    def __init__(self, config):
        super(BertPooler, self).__init__()

        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()
        return

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """先頭の特徴ベクトル取得 + 線形変換 + 活性化関数

        Args:
            hidden_states (torch.Tensor): BERTでエンコード済みの特徴ベクトル

        Returns:
            torch.Tensor: 先頭の特徴ベクトルのみを加工して得た特徴ベクトル
        """
        # 先頭の特徴ベクトルを取得
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.activation(self.dense(first_token_tensor))
        return pooled_output


class BertModel(nn.Module):
    '''BERTクラス'''
    def __init__(self, config):
        super(BertModel, self).__init__()

        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config)

    def forward(self, input_ids: torch.Tensor,
                attention_mask: torch.Tensor,
                token_type_ids: torch.Tensor=None,
                output_all_encoded_layers: bool=False,
                attention_show_flg: bool=False):
        """アテンションマスク作成 + 単語埋め込み + BERTによるエンコード

        Args:
            input_ids (torch.Tensor): 入力トークン列
            attention_mask (torch.Tensor): アテンション用マスク
            token_type_ids (torch.Tensor, optional): 文情報. Defaults to None.
            output_all_encoded_layers (bool, optional): 全ての層の出力を返すか否か. Defaults to False.
            attention_show_flg (bool, optional): attentionマップを出力するか否か. Defaults to False.

        Returns:
            _type_: BERTでエンコードされた特徴ベクトル
        """

        # attentionのマスクが無ければ作成
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        # 文の1文目、2文目のidが無ければ作成
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        # multi-head Attention用にマスクを変形 [minibatch, 1, 1, seq_length]
        extended_attention_mask = attention_mask.unsqueeze(1)
        extended_attention_mask = extended_attention_mask.to(dtype=torch.float32)

        # 単語埋め込み
        embedding_output = self.embeddings(input_ids, token_type_ids)

        # self-attention結果とattention mapを返す
        if attention_show_flg:
            encoded_layers, attention_probs = self.encoder(
                embedding_output,
                extended_attention_mask,
                output_all_encoded_layers,
                attention_show_flg)

        # self-attention結果のみ返す
        elif not attention_show_flg:
            encoded_layers = self.encoder(
                embedding_output,
                extended_attention_mask,
                output_all_encoded_layers,
                attention_show_flg)

        # 最終層の1文目の特徴量のみ取り出す ※誤り訂正では使わない
        pooled_output = self.pooler(encoded_layers[-1])

        # 最終層のself-attentionのみ返す
        if not output_all_encoded_layers:
            encoded_layers = encoded_layers[-1]

        # self-attention結果とattention mapを返す
        if attention_show_flg == True:
            return encoded_layers, attention_probs

        # self-attention結果のみ返す
        elif attention_show_flg == False:
            return encoded_layers
  • import
    平方根の計算に使うmathライブラリとモデル構築に使うtorchライブラリを読み込みます。
import math
import torch
import torch.nn.functional as F
from torch import nn

BERTクラスが持つ機能について、上から順に説明します。

  • class BERTModel
    • class BertEmbeddings:単語・文情報埋め込み部
      • class BertLayerNorm
    • class BertEncoder:単語ベクトルのエンコード部
      • class BertLayer × N層
        • class BertAttention:self-attention計算部
          • class BertSelfAttention
          • class BertSelfOutput
            • class BertLayerNorm
        • class BertIntermediate:Feed-Forward Network1計算部
        • class BertOutput:Feed-Forward Network2計算部
          • class BertLayerNorm
    • class BertPooler:特徴ベクトルの抽出部

  • class BertEmbeddings:単語・文情報埋め込み部

入力は分かち書き、id化されバッチ毎に取り出されたテンソルです。モデルの規模上GPUに転送済みであることを想定しています。

単語埋め込み層と位置埋め込み層、文の順序埋め込み層を定義します。入力された単語idをそれぞれのベクトルに埋め込んだ出力を加算します。単語埋め込み層の出力だけモデルサイズで調整されています。根拠は読み解けませんでしたが、論文に沿って値をスケールしています。

加算された単語ベクトルはレイヤ正規化層とドロップアウトを経て出力されます。ミニバッチ毎に平均と分散を求め、その後再度平均と分散をいくらシフトするかを学習するバッチ正規化に対して、レイヤ正規化は、1つの単語ベクトルの隠れ層の値で平均と分散を計算し、その後再度平均と分散をいくらシフトするかを学習します。これで可変長の入力に対しても安定した平均と分散が求められます。

レイヤ正規化の詳細な解説はこちらの記事をご参照下さい。
https://data-analytics.fun/2020/07/16/understanding-layer-normalization/

class BertEmbeddings(nn.Module):
    '''単語埋め込みクラス'''
    def __init__(self, config: dict):
        super(BertEmbeddings, self).__init__()

        self.config = config

        self.src_word_enc = nn.Embedding(
            self.config.vocab_size,
            self.config.hidden_size,
            padding_idx=2)

        self.src_pos_enc = nn.Embedding(
            config.max_position_embeddings,
            config.hidden_size,
            padding_idx=2)

        self.token_type_embeddings = nn.Embedding(
            config.type_vocab_size,
            config.hidden_size)

        self.LayerNorm = BertLayerNorm(
            config.hidden_size,
            eps=1e-6)

        self.dropout = nn.Dropout(
            config.hidden_dropout_prob)
        return

    def forward(self, input_ids: torch.Tensor, token_type_ids: torch.Tensor=None) -> torch.Tensor:
        """単語埋め込み

        Args:
            input_ids (torch.Tensor): idに変換した入力単語
            token_type_ids (torch.Tensor, optional): _description_. Defaults to None.

        Returns:
            torch.Tensor: 単語埋め込みベクトル
        """
        # 単語埋め込み
        src_word_emb = self.src_word_enc(input_ids)

        # 埋め込み後ベクトルの中間層の次元で調整
        # ※Transformerの原著論文で実装されている
        if self.config.scale_emb:
            src_word_emb *= self.config.hidden_size ** 0.5

        # 文章情報埋め込み
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        # 単語位置情報埋め込み
        position_ids = torch.arange(
            input_ids.size(1),
            dtype=torch.long,
            device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        src_pos_emb = self.src_pos_enc(position_ids)

        # 全埋め込み結果を加算しレイヤ正規化とドロップアウトを適用
        embeddings = src_word_emb + src_pos_emb + token_type_embeddings
        embeddings = self.dropout(self.LayerNorm(embeddings))

        return embeddings

  • class BertEncoder:単語ベクトルのエンコード部

単語ベクトルをエンコードしていきます。BertLayerクラスで後述するMulti-head attention + Feed-Forward Networkブロックを何層重ねるか指定します。

入力は先程変換した単語ベクトルとpaddingの学習を防ぐマスクと、全ての層の出力を返すかを選択するフラグです。デフォルトでは最終層の出力のみ返すように実装されています。

class BertEncoder(nn.Module):
    '''エンコーダクラス(指定した数だけBertLayerを何層積み重ねる)'''
    def __init__(self, config):
        super(BertEncoder, self).__init__()

        self.layer = nn.ModuleList(
            [BertLayer(config) for _ in range(config.num_hidden_layers)])
        return

    def forward(self, hidden_states: torch.Tensor,
                attention_mask: torch.Tensor=None,
                output_all_encoded_layers: bool=False,
                attention_show_flg: bool=False):
        """BERTのエンコード処理

        Args:
            hidden_states (torch.Tensor): 単語埋め込み後の特徴ベクトル
            attention_mask (torch.Tensor, optional): マスク. Defaults to None.
            output_all_encoded_layers (bool, optional): 全ての層の出力を返すか否か. Defaults to False.
            attention_show_flg (bool, optional): attentionマップを出力するか否か. Defaults to False.

        Returns:
            _type_: エンコードされた特徴ベクトル
        """

        all_encoder_layers = []

        # config.num_hidden_layers分だけBertLayerモジュールを繰り返し
        for layer_module in self.layer:

            # self-attentionの結果とattention mapを返す
            if attention_show_flg:
                hidden_states, attention_probs = layer_module(
                    hidden_states,
                    attention_mask,
                    attention_show_flg)

            # self-attentionの結果のみ返す
            elif not attention_show_flg:
                hidden_states = layer_module(
                    hidden_states,
                    attention_mask,
                    attention_show_flg)

            # config.num_hidden_layers分の全ての層を返す
            if output_all_encoded_layers:
                all_encoder_layers.append(hidden_states)

        # 最終層の結果のみ返す
        if not output_all_encoded_layers:
            all_encoder_layers.append(hidden_states)

        # self-attentionの結果とattention mapを返す
        if attention_show_flg:
            return all_encoder_layers, attention_probs

        # self-attentionの結果のみ返す
        elif not attention_show_flg:
            return all_encoder_layers

  • class BertAttention:self-attention計算部
     scaled dot-product attention(SDA)とmulti-head attention(MHA)の実装です。単語ベクトルを分割するヘッド数に基づいて1ヘッドあたりの単語ベクトルのサイズを計算しておきます。さらに、単語ベクトルをquery, key, valueに変換するための全結合層も用意します。
class BertSelfAttention(nn.Module):
    '''scaled dot-product attention(SDA) + multi-head attention(MHA)クラス'''
    def __init__(self, config: dict):
        super(BertSelfAttention, self).__init__()

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        return

    def split_heads(self, x: torch.Tensor) -> torch.Tensor:
        """MHA(multi-head attention)用にテンソルの形を変換
        [batch_size, seq_len, hidden] → [batch_size, 12, seq_len, hidden/12]

        Args:
            x (torch.Tensor): 線形変換 + 活性化関数通過後のquery又はkey又はvalue

        Returns:
            torch.Tensor: num_attention_heads毎に分割された特徴ベクトル
        """
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_state: torch.Tensor, attention_mask: torch.Tensor=None,
     attention_show_flg=False) -> torch.Tensor:
        """scaled dot-product attention(SDA) + multi-head attention(MHA)

        Args:
            hidden_state (torch.Tensor): multi-headに分割される前の特徴ベクトル
            attention_mask (torch.Tensor, optional): マスク. Defaults to None.
            attention_show_flg (bool, optional): attentionマップを出力するか否か. Defaults to False.

        Returns:
            torch.Tensor: SDAとMHAを経て分割されたheadsを統合した特徴ベクトル
        """

        # 入力を全結合層で特徴量変換 ※MHAの全ヘッドをまとめて変換
        mixed_query_layer = self.query(hidden_state)
        mixed_key_layer = self.key(hidden_state)
        mixed_value_layer = self.value(hidden_state)

        # multi-headに分割
        query_layer = self.split_heads(mixed_query_layer)
        key_layer = self.split_heads(mixed_key_layer)
        value_layer = self.split_heads(mixed_value_layer)

        # 特徴ベクトル同士の類似度を行列の内積で求める
        attention_scores = torch.matmul( query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        # padding部分をゼロにするマスク処理 ※後でSoftmaxをかける為マスクには-10000(=正規化時に一番小さくなる値)を代入
        attention_mask = (1-attention_mask) * -1e4
        attention_scores = attention_scores + attention_mask

        # 正規化後,ドロップアウト
        attention_probs = nn.Softmax(dim=-1)(attention_scores)
        attention_probs = self.dropout(attention_probs)

        # attention_probsとvalue_layerで行列の積
        context_layer = torch.matmul(attention_probs, value_layer)

        # multi-head Attentionのheadsを統合
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        # self-attention結果とattention mapも返す
        if attention_show_flg:
            return context_layer, attention_probs

        # self-attention結果のみ返す
        elif not attention_show_flg:
            return context_layer

このクラスで単語ベクトルは2つのクラスに入力され、以下の手順で計算されます。一つはBertSelfAttentionクラスです。BertSelfAttentionクラスでは以下の処理を実行します。

1.単語ベクトルをquery,key,valueへ変換
2.複数ヘッドへ分割
3.queryとkeyで単語ベクトル同士の類似度を行列の内積から計算(attention score)
4.padding部の学習を防ぐためのマスク
5.valueとattention scoreで行列の積を取り単語ベクトルの重み付き加重和を計算
6.複数ヘッドを単一のヘッドへ統合

class BertSelfAttention(nn.Module):
    '''scaled dot-product attention(SDA) + multi-head attention(MHA)クラス'''
    def __init__(self, config: dict):
        super(BertSelfAttention, self).__init__()

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        return

    def split_heads(self, x: torch.Tensor) -> torch.Tensor:
        """MHA(multi-head attention)用にテンソルの形を変換
        [batch_size, seq_len, hidden] → [batch_size, 12, seq_len, hidden/12]

        Args:
            x (torch.Tensor): 線形変換 + 活性化関数通過後のquery又はkey又はvalue

        Returns:
            torch.Tensor: num_attention_heads毎に分割された特徴ベクトル
        """
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_state: torch.Tensor, attention_mask: torch.Tensor=None,
     attention_show_flg=False) -> torch.Tensor:
        """scaled dot-product attention(SDA) + multi-head attention(MHA)

        Args:
            hidden_state (torch.Tensor): multi-headに分割される前の特徴ベクトル
            attention_mask (torch.Tensor, optional): マスク. Defaults to None.
            attention_show_flg (bool, optional): attentionマップを出力するか否か. Defaults to False.

        Returns:
            torch.Tensor: SDAとMHAを経て分割されたheadsを統合した特徴ベクトル
        """

        # 入力を全結合層で特徴量変換 ※MHAの全ヘッドをまとめて変換
        mixed_query_layer = self.query(hidden_state)
        mixed_key_layer = self.key(hidden_state)
        mixed_value_layer = self.value(hidden_state)

        # multi-headに分割
        query_layer = self.split_heads(mixed_query_layer)
        key_layer = self.split_heads(mixed_key_layer)
        value_layer = self.split_heads(mixed_value_layer)

        # 特徴ベクトル同士の類似度を行列の内積で求める
        attention_scores = torch.matmul( query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        # padding部分をゼロにするマスク処理 ※後でSoftmaxをかける為マスクには-10000(=正規化時に一番小さくなる値)を代入
        attention_mask = (1-attention_mask) * -1e4
        attention_scores = attention_scores + attention_mask

        # 正規化後,ドロップアウト
        attention_probs = nn.Softmax(dim=-1)(attention_scores)
        attention_probs = self.dropout(attention_probs)

        # attention_probsとvalue_layerで行列の積
        context_layer = torch.matmul(attention_probs, value_layer)

        # multi-head Attentionのheadsを統合
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        # self-attention結果とattention mapも返す
        if attention_show_flg:
            return context_layer, attention_probs

        # self-attention結果のみ返す
        elif not attention_show_flg:
            return context_layer

このクラスを通った後、BertSelfOutputクラスに入力され以下の計算を行います。

1.入出力が同じ次元数の全結合層
2.ドロップアウト
3.残差接続
4.レイヤ正規化

class BertSelfOutput(nn.Module):
   """(線形変換 + 活性化関数) + 残差接続クラス(MHA後残差接続)"""
   def __init__(self, config: dict):
       super(BertSelfOutput, self).__init__()

       self.dense = nn.Linear(config.hidden_size, config.hidden_size)
       self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-6)
       self.dropout = nn.Dropout(config.hidden_dropout_prob)
       return

   def forward(self, hidden_state: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
       """線形変換 + 活性化関数通過後の残差接続(MHA後残差接続)

       Args:
           hidden_state (torch.Tensor): SDA+MHA通過後の特徴ベクトル
           input_tensor (torch.Tensor): SDA+MHA通過前の特徴ベクトル

       Returns:
           torch.Tensor: 線形変換 + 活性化関数 + 残差接続後の特徴ベクトル
       """
       hidden_state = self.dropout(self.dense(hidden_state))
       hidden_state = self.LayerNorm(hidden_state + input_tensor)
       return hidden_state

  • class BertIntermediate:Feed-Forward Network1計算部

Multi-head attentionと全結合層、残差接続を通った単語ベクトルは、Feed-Forward層に入力されます。ここでは出力次元数を拡張した全結合層と活性関数を適用するのみです。活性化関数にはgeluが適用されていますが、ReLUと比較したときの特徴はゼロ付近でなめらかにゼロに近づいていく関数だという点です。
 正直他にもmishやswishなど様々な活性化関数が提案されており実装してみましたが、タスク達成度に大きく貢献しているかは明らかにできませんでした。タスクとデータセットに応じて使い分けるものだと認識しています。

class BertIntermediate(nn.Module):
    '''線形変換 + 活性化関数クラス'''
    def __init__(self, config: dict):
        super(BertIntermediate, self).__init__()

        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        self.intermediate_act_fn = gelu
        # self.intermediate_act_fn = mish
        return

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """線形変換(中間層サイズに次元を拡張) + 活性化関数

        Args:
            hidden_states (torch.Tensor): SDA+MHA通過後の特徴ベクトル

        Returns:
            torch.Tensor: SDA+MHA通過後に次元拡張した特徴ベクトル
        """
        hidden_states = self.intermediate_act_fn(self.dense(hidden_states))
        return hidden_states

  • class BertOutput:Feed-Forward Network2計算部

出力次元数を拡張された単語ベクトルはこのクラスで再度、全結合層を経て元の次元数へ圧縮されます。その後、レイヤ正規化、ドロップアウトを経て元のサイズの単語ベクトルのサイズで出力されます。憶測ですが、この出力次元数の拡張と圧縮は、高い表現力の獲得とパラメータの削減、再度BertLayer層を積み重ねるためのサイズ調整などの気持ちが含まれていると思いました。

class BertOutput(nn.Module):
    """(線形変換(元の次元数に圧縮) + レイヤ正規化) + 残差接続クラス"""
    def __init__(self, config):
        super(BertOutput, self).__init__()

        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-6)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        return

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        """(線形変換(元の次元数に圧縮) + レイヤ正規化) + 残差接続

        Args:
            hidden_states (torch.Tensor): SDA+MHA通過+次元拡張した特徴ベクトル
            input_tensor (torch.Tensor) : SDA+MHA通過+次元圧縮した特徴ベクトル

        Returns:
            torch.Tensor: FFN通過後に残差接続された特徴ベクトル(FFN後残差接続)
        """
        hidden_states = self.dropout(self.dense(hidden_states))
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states

以上の処理でBERTによる入力文章の単語ベクトルのエンコードは終了です。出力は[バッチサイズ,入力トークン数,単語ベクトルの次元数]となります。後述するクラスは文章分類などにBERTを使う際に必要となるクラスです。


  • class BertPooler:特徴ベクトルの抽出部

BERTを使ってクラス分類などのタスクを解く際に用いるのがこのクラスです。本来であれば入力文章をエンコードした全ての単語ベクトルを利用して分類問題を解くのが直感的のように思いますが、BERTの論文ではこの内の最初のトークン(前処理の分かち書きで[CLS]トークンが入っています)のベクトルのみを取り出して全結合層を適用して求めたいクラス分類数の出力に変換します。

class BertPooler(nn.Module):
    """入力文章の1単語目[cls]の特徴量を線形変換して保持するクラス"""
    def __init__(self, config):
        super(BertPooler, self).__init__()

        self.dense = nn.Linear(config.hidden_size, config.num_class)
        self.activation = nn.Tanh()
        return

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """先頭の特徴ベクトル取得 + 線形変換 + 活性化関数

        Args:
            hidden_states (torch.Tensor): BERTでエンコード済みの特徴ベクトル

        Returns:
            torch.Tensor: 先頭の特徴ベクトルのみを加工して得た特徴ベクトル
        """
        # 先頭の特徴ベクトルを取得
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.activation(self.dense(first_token_tensor))
        return pooled_output

先頭のトークンだけでタスクが解ける理由として考えられるのは、これまで単語ベクトルが通過してきたmulti-head attentionによるheadの分割と統合や、残差接続の積み重ねによって全てのトークンに入力文章の情報が埋め込まれているためだと理解しています。
 このあたりはBERTの学習がどのように進んでいるのか、という話になると思いますが、明確な解釈はまだ持っていないです。ご存知の方がいましたら教えていただけると幸いです。
 以下のBERTVizというライブラリを使ってBERTのself-attentionの重みを可視化したこともありますが、12層×12ヘッド=144個のself-attentionの重みを人間が解釈するのは困難だと感じました。
https://github.com/jessevig/bertviz

一方でBERTではなくTransformerですが、学習過程でどんな作用が支配的であるかを分析された以下のような論文もあります。非常に勉強になります。
https://www.anlp.jp/proceedings/annual_meeting/2021/pdf_dir/A7-2.pdf


  • class BERTModel

最終的に今回実装したBERTModelは以下になります。入力トークンを単語ベクトルに変換し、multi-head attentionとself-attentionで単語ベクトルをエンコードし、分類タスクを解くなら先頭のトークンのみを選んで全結合層を適用する、という流れになります。

class BertModel(nn.Module):
    '''BERTクラス'''
    def __init__(self, config):
        super(BertModel, self).__init__()

        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config)

    def forward(self, input_ids: torch.Tensor,
                attention_mask: torch.Tensor,
                token_type_ids: torch.Tensor=None,
                output_all_encoded_layers: bool=False,
                attention_show_flg: bool=False):
        """アテンションマスク作成 + 単語埋め込み + BERTによるエンコード

        Args:
            input_ids (torch.Tensor): 入力トークン列
            attention_mask (torch.Tensor): アテンション用マスク
            token_type_ids (torch.Tensor, optional): 文情報. Defaults to None.
            output_all_encoded_layers (bool, optional): 全ての層の出力を返すか否か. Defaults to False.
            attention_show_flg (bool, optional): attentionマップを出力するか否か. Defaults to False.

        Returns:
            _type_: BERTでエンコードされた特徴ベクトル
        """

        # attentionのマスクが無ければ作成
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        # 文の1文目、2文目のidが無ければ作成
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        # multi-head Attention用にマスクを変形 [minibatch, 1, 1, seq_length]
        extended_attention_mask = attention_mask.unsqueeze(1)
        extended_attention_mask = extended_attention_mask.to(dtype=torch.float32)

        # 単語埋め込み
        embedding_output = self.embeddings(input_ids, token_type_ids)

        # self-attention結果とattention mapを返す
        if attention_show_flg:
            encoded_layers, attention_probs = self.encoder(
                embedding_output,
                extended_attention_mask,
                output_all_encoded_layers,
                attention_show_flg)

        # self-attention結果のみ返す
        elif not attention_show_flg:
            encoded_layers = self.encoder(
                embedding_output,
                extended_attention_mask,
                output_all_encoded_layers,
                attention_show_flg)

        # 最終層の1文目の特徴量のみ取り出す ※誤り訂正では使わない
        pooled_output = self.pooler(encoded_layers[-1])

        # 最終層のself-attentionのみ返す
        if not output_all_encoded_layers:
            encoded_layers = encoded_layers[-1]

        # self-attention結果とattention mapを返す
        if attention_show_flg == True:
            return encoded_layers, attention_probs

        # self-attention結果のみ返す
        elif attention_show_flg == False:
            return encoded_layers

まとめ

自然言語処理で数年前にデファクトスタンダートとなったBERTを実装しました。現在はこの技術を基に自然言語処理は目覚ましい発展を遂げ、特に生成モデルは大きな進化を遂げています。

現在はhugginfaceなどで学習済みのモデルがすぐに使える時代で非常に便利になりましたが、当時はhuggingfaceはまだ作られたばかりで、やっぱり自分で実装するのが一番でした。今回のコメント付けを通して理解も深められたので、またできることが一つ増やせたと思います。

Discussion