🤗

【SIGNATE】BERTで医療論文を2値分類する(PyTorch BERT)

2021/10/18に公開

医学論文の自動仕分けチャレンジ

コンペの内容

https://signate.jp/competitions/471/
「医学論文の自動仕分けチャレンジ」 網羅的に収集された論文の中から、目的に沿った論文のみを抽出しよう!
というテーマのAIコンペがSIGNATE上で2021年7月28日(水) - 2021年10月4日(月) の間に開催されました。

結果

コンペの結果としては、<78位/637人投稿> となりました。

私にとっては初のNLPコンペであり、とても多くのことを学ぶことが出来たと思います。
実は今まではKerasを使用していたのですが、今回のコンペではPyTorchを使用しました。
NLP, PyTorchともに初学者ですが、だからこその目線で解説を残しておこうとおもいます!
本コンペで実装したものは、すべて↓にNotebookで掲載しています。雑なコードが多くありますが、参考になりましたら幸いです。
https://github.com/kubokoHappy/Classification_MedicalPapers_PyTorch

実装の解説

提供されたデータ

実際に使用した生データは公開することが出来ないため、ダミーのデータを記載しています。
論文の「Title」「Abstract」「ラベル 0 or 1」のデータが提供されていました。

Title Abstract Judgement
KUBOKO is crazy KUBOKO is always optimistic. He's name ... 0
Who is KUBOKO The true identity of KUBOKO is a human ... 0
KUBOKO's Favorite Food KUBOKO's favorite foods are small tomatoes ... 1
KUBOKO's Hobby KUBOKO's hobby is watching horror movies. 0

何をしたのか(概要)

  • 🤗Huggingface Transformersで提供されているmicrosoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltextをベースにして、医療論文の2値分類用にFine tuningしました。
  • Modelには、上記のBERTをベースとして、LSTM, Conv1D, Linear層を追加し、BERTの重みを最大限活かした予測ができるように工夫しています。
  • Datasetには、Argument(データ拡張)処理を実装し、学習データの文章をランダムに削除したり入れ替えることで過学習の抑制をしました。
  • ラベル1が全体のうちの 1/43 程度しかなかったこと、評価指標がラベル1の正解を高く評価する指標であることから、損失関数のラベル1に対する重みを130倍 (ヒューリスティックス) に設定した。

Dataset

PyTorchのDatasetクラスを継承して、CustomDatasetを作成しています。

今回はネットワークに入力するデータとして、「Title」と「Abstract」を単純に結合した「title_abstract」を選択しています。これはBERTへできる限り多くの情報を入力し、かつ単純化するためです。

工夫点としては、Dataset内にデータ拡張処理を実装しているところです。
学習ループを回す際に、毎回同一の文章を入力してしまうと、それぞれの文章全体に対してのラベルを学習してしまい、学習データに適合しすぎてしまいます(過学習)。
文章内での文の順番をランダムに変更したり、文を削除することで、過学習を防ごうという目的で実装しました。
画像データだと、データ拡張は当たり前に用いられています。

PyTorch初心者でDatasetについてよくわからない方は、↓のPyTorchチュートリアル(日本語翻訳版)を一度進めることをおすすめします!とてもわかりやすいです!!
https://yutaroogawa.github.io/pytorch_tutorials_jp/

class TextClassificationDataset(Dataset):
    def __init__(self, df, tokenizer, use_col='title_abstract', 
                 token_max_length=512, argument=False, upsample_pos_n=1):
        """
        Custom Dataset

        Attributes
        ----------
        df : DataFrame
            元のデータを保持しているDataFrame
        tokenizer : tokenizer
            Transformersからダウンロードしたトークナイザー
            これを使用してtext から BERTへ入力できるTokensへと変換する
        use_col : str
            使用する列名
            多くの情報を扱いたいため、Defaultは「title_abstract」
        token_max_length : 
            BERTに入力するtokenサイズ
        argument : bool
            True -> データ拡張あり
        upsample_pos_n : int
            ラベル1(Judgement:1)をアップサンプリングする倍率
            1以上の場合は、指定した倍率でアップサンプリングする
        """

        if upsample_pos_n > 1:
            df_pos = df.loc[df.judgement==1]
            df_pos = pd.concat([df_pos for i in range(int(upsample_pos_n))], axis=0).reset_index(drop=True)
            df_neg = df.loc[df.judgement==0]
            self.df = pd.concat([df_pos, df_neg], axis=0).reset_index(drop=True)
        else:
            self.df = df
        
        self.tokenizer = tokenizer
        self.argument = argument
        self.use_col = use_col

    def text_argument(self, text, drop_min_seq=3, seq_sort=True):
        """
        Text Argument(データ拡張)処理

        Parameters
        ----------
        text : str
            データセットの各要素(テキスト)を挿入
        drop_min_seq : int
            文章がdrop_min_seq個以上ある場合('.'で分割)にデータ拡張を実施
        seq_sort : bool
            ランダムに選択した要素を、元の並び順にをソートするか

        Returns : str
            拡張済みのテキスト
        """
        seq_list = text.split('. ')
        seq_len = len(seq_list)
        if seq_len >= drop_min_seq:
            orig_idx_list = list(range(0, seq_len))
            idx_list = random.sample(orig_idx_list, random.randint(round(seq_len * 0.7), seq_len))
            if seq_sort:
                idx_list = sorted(idx_list)
            insert_idx_list = random.sample(orig_idx_list, random.randint(0, seq_len//3))
            for x in insert_idx_list:
                idx = random.randint(0, len(idx_list))
                idx_list.insert(idx, x)
            seq_list = [seq_list[i] for i in idx_list]
        text = '. '.join(seq_list)
        return text

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        
        text = self.df.loc[idx, self.use_col]

        if self.argument:
            text = self.text_argument(text, drop_min_seq=3, seq_sort=True)

        token = self.tokenizer.encode_plus(
            text,
            padding = 'max_length', max_length = hps.token_max_length, truncation = True,
            return_attention_mask=True, return_tensors='pt'
        )

        sample = dict(
            input_ids=token['input_ids'][0],
            attention_mask=token['attention_mask'][0]
        )
        
        label = torch.tensor(self.df.loc[idx, 'judgement'], dtype=torch.float32)
        return sample, label

Model

🤗Huggingface Transformersで提供されているmicrosoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltextをベースにして、医療論文の2値分類用にFine tuningしました。
Modelには、上記のBERTをベースとして、LSTM, Conv1D, Linear層を追加し、BERTの重みを最大限活かした予測ができるように工夫しています。

Modelを実装する際には、AI SHIFTさんの↓の記事を参考にさせていただきました。BERTの出力層にLSTMやCNNを活用することで精度改善につながるといった内容です。
https://www.ai-shift.co.jp/techblog/2145

class BertLstmExModel(nn.Module):
    def __init__(self, hidden_size, config, use_hidden_n=10):
        super().__init__()
        
        self.bert = transformers.AutoModel.from_pretrained(hps.model_name, config=bert_config)
        self.hidden_size = hidden_size
        self.use_hidden_n = use_hidden_n
        self.lstm = nn.LSTM(self.hidden_size, self.hidden_size, batch_first=True, bidirectional=True)
        self.leakyrelu = nn.LeakyReLU()
        self.dropout = nn.Dropout(p=0.3)
        self.conv1d = nn.Conv1d(in_channels=self.use_hidden_n, out_channels=1, kernel_size=3, padding='same')
        self.regressor = nn.Linear(self.hidden_size*2, 1)
        
    
    def forward(self, input_ids, attention_mask):
        # ベースとしているPubMedBERTからの出力
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        hidden_states_list = [outputs['hidden_states'][-1*i] for i in range(1, self.use_hidden_n+1)]
        self.lstm.flatten_parameters()  # メモリを圧縮している(しないと警告がでる為)
        # use_hidden_n 個独立したLSTMネットワークを作成している。
        out_list = [
            self.dropout(
                self.leakyrelu(
                    self.lstm(hidden_state, None)[0]
                )[:, -1, :]
            ).view(-1, 1, self.hidden_size*2)  # (batch, use_hidden_n, hidden_size*2)
        for hidden_state in hidden_states_list]

        # それぞれのLSTMネットワークからの出力を縦に結合している
        out = torch.cat(out_list, dim=1)

        # 結合された重みに1次元の畳込みをしている
        out = self.dropout(self.leakyrelu(self.conv1d(out)))
        out = out.view(out.size(0), -1)

        logits = torch.flatten(self.regressor(out))
        return logits

Optimizer, LR_Scheduler

OptimizerにはAdamWを使用しており、学習率は層ごとの大きなくくりごとに設定しています。
全体としてみれば、層の下にいくにつれて学習率を小さく設定しています。
これは、BERTの上の層は汎用的な情報を学習しているとかんがえられるためであるからです。

lr_schedulerには、Transformersで提供されているLR_Schedulerを使用しています。
https://huggingface.co/transformers/main_classes/optimizer_schedules.html#transformers.get_linear_schedule_with_warmup
全体の1/10に達した時点を学習率の頂点として、そこからは直線状に減衰させています。

def model_setup(model, dataloaders):
    optimizer = optim.AdamW(
        params=[
            {'params': model.bert.embeddings.parameters(), 'lr': 1e-5},
            {'params': model.bert.encoder.parameters(), 'lr': 2e-5},
            {'params': model.bert.pooler.parameters(), 'lr': 3e-5},
            {'params': model.lstm.parameters(), 'lr': 5e-4},
            {'params': model.conv1d.parameters(), 'lr': 5e-4},
            {'params': model.regressor.parameters(), 'lr': 5e-4}
        ]
    )
    num_warmup_steps = round(hps.num_epochs * len(dataloaders['train']) * 0.1)
    num_training_steps = round(hps.num_epochs * len(dataloaders['train']))
    print(f"InitLR:{hps.initial_lr} / num_warmup_steps:{num_warmup_steps} / num_training_steps:{num_training_steps}")
    lr_scheduler = transformers.get_linear_schedule_with_warmup(optimizer=optimizer, num_warmup_steps=num_warmup_steps, 
                                                                num_training_steps=num_training_steps, last_epoch=-1)

    return (optimizer, lr_scheduler)

Critetion (損失関数)

損失関数には、BinaryCrossEntropyを使用しました。
PyTorchでこれを実装するには、torch.nn.BCELosstorch.nn.BCEWithLogitsLossの2種類があります。
今回は「BCEWithLogitsLoss」を使用したのですが、「BCELoss」と比べた際のメリットとして公式ページに次のように記載されていました。

This loss combines a Sigmoid layer and the BCELoss in one single class. This version is more numerically stable than using a plain Sigmoid followed by a BCELoss as, by combining the operations into one layer, we take advantage of the log-sum-exp trick for numerical stability.
(和訳)この損失はシグモイド層とBCELossを1つのクラスにまとめたものです。このバージョンは、単純なSigmoidとBCELossの組み合わせよりも数値的に安定しています。

また、「BCEWithLogitsLoss」では損失計算時に重みを設定することが出来ます。
今回与えられたデータでは、全体に対してラベル1が1/43程度と、不均衡なデータとなっていました。
加えて、本コンペの評価指標が「FBetaScore(β=7)」となっており、Recall(再現率)をより重視することから、ラベル1をより重要視した損失関数を設定する必要がありました。
そのため、「class_1_weight = 130」としてラベル1(Judgement 1)の重みを130と大きく設定しています。(130にたどり着いたのは、手動のパラメータチューニングによるもので、数式から導き出したものではありません。)

hps.class_1_weight = 130  # ラベル1(Judgement 1)の損失計算時の重みをラベル0の130倍に設定
pos_weight = torch.tensor([hps.class_1_weight for i in range(input_ids.size(0))]).to(device)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

まとめ

冒頭にも記載していますが、本記事で紹介しているソースコードはすべてGithubで公開しています。
https://github.com/kubokoHappy/Classification_MedicalPapers_PyTorch
今後はKaggleにも積極的に挑戦しようと考えています。チームにご招待してくださる方いらっしゃいましたらお声がけお待ちしております🙃

雑な部分も多いかと思いますが、少しでも皆様の参考になる部分があると嬉しいです!
もしよろしければ「いいね」や「サポート」「Twitterフォロー」も宜しくお願いいたします!!
https://twitter.com/AI_kuboko

参考

https://huggingface.co/microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext
https://signate.jp/competitions/471
https://www.ai-shift.co.jp/techblog/2138
https://www.ai-shift.co.jp/techblog/2145
https://www.ai-shift.co.jp/techblog/2170

Discussion