🚤

CNN1DとLSTMとTransformerを映画感想文の判定をして比較してみる。

2023/06/27に公開

はじめに

CNN1D を用いたモデル、LSTM を用いたモデル、Transformer を用いたモデルの三種類のモデルについて映画の感想文の判定を行う。

データの説明

この記事では以下のデータセットを用いる。これは 2.5 万件の映画のレビューをポジティブ/ネガティブで分類したデータセットである。

https://www.tensorflow.org/datasets/catalog/imdb_reviews?hl=ja

このデータの前処理については前の記事で解説した。

モデル

CNN1D

CNN1D を用いたモデルを構築した。

# CNN1D

model = Sequential([
    Input(shape = (SEQUENCE_LENGTH,)),
    Embedding(VOCAB_SIZE, EMBEDDING_DIM),
    Conv1D(EMBEDDING_DIM, 3, activation='relu', padding='same'),
    Conv1D(EMBEDDING_DIM, 3, activation='relu', padding='same'),
    GlobalMaxPool1D(),
    Dense(EMBEDDING_DIM, activation='relu'),
    Dropout(0.5),
    Dense(1, activation='sigmoid')
])
model.summary()
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 embedding (Embedding)       (None, 256, 512)          15360000

 conv1d (Conv1D)             (None, 256, 512)          786944

 conv1d_1 (Conv1D)           (None, 256, 512)          786944

 global_max_pooling1d (Globa  (None, 512)              0
 lMaxPooling1D)

 dense (Dense)               (None, 512)               262656

 dropout (Dropout)           (None, 512)               0

 dense_1 (Dense)             (None, 1)                 513

=================================================================
Total params: 17,197,057
Trainable params: 17,197,057
Non-trainable params: 0
_________________________________________________________________

LSTM

双方向 LSTM でモデルを構築した。

# LSTM

model = Sequential([
    Input(shape = (SEQUENCE_LENGTH,)),
    Embedding(VOCAB_SIZE, EMBEDDING_DIM, mask_zero=True),
    Bidirectional(LSTM(EMBEDDING_DIM, dropout=0.2)),
    Dropout(0.5),
    Dense(1, activation='sigmoid')
])

model.summary()
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 embedding (Embedding)       (None, 256, 512)          15360000

 bidirectional (Bidirectiona  (None, 1024)             4198400
 l)

 dropout (Dropout)           (None, 1024)              0

 dense (Dense)               (None, 1)                 1025

=================================================================
Total params: 19,559,425
Trainable params: 19,559,425
Non-trainable params: 0
_________________________________________________________________

Transformer

Transformer を用いたモデルを構築した。

# Transformer

model = Sequential([
    Input(shape = (SEQUENCE_LENGTH,)),
    PositionalEmbedding(SEQUENCE_LENGTH, VOCAB_SIZE, EMBEDDING_DIM),
    Dropout(DROPOUT_RATE),
    EncoderTransformer(NUM_HEADS),
    EncoderTransformer(NUM_HEADS),
    GlobalAveragePooling1D(),
    Dropout(DROPOUT_RATE),
    Dense(1, activation='sigmoid')
])

model.summary()

EncoderTransformerはこの論文[1]の Encoder の transformer である。

_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 positional_embedding (Posit  (None, 256, 512)         15360000
 ionalEmbedding)

 dropout (Dropout)           (None, 256, 512)          0

 encoder_transformer (Encode  (None, 256, 512)         6302720
 rTransformer)

 encoder_transformer_1 (Enco  (None, 256, 512)         6302720
 derTransformer)

 global_average_pooling1d (G  (None, 512)              0
 lobalAveragePooling1D)

 dropout_5 (Dropout)         (None, 512)               0

 dense_4 (Dense)             (None, 1)                 513

=================================================================
Total params: 27,965,953
Trainable params: 27,965,953
Non-trainable params: 0
_________________________________________________________________

おわりに

CNN1D を用いた最も簡単なモデルが学習も早いし推論も早く一番いいモデルとなった。ハイパーパラメータの探索してないことと、データ量が少ないことが原因だと思う。特に Transformer はデータ量が大きい時に生きるのかなと思う。bert を転移学習するのが良いのかな。

また、実社会でこういったモデルを活用する場合、ROC 曲線や F 値などを可視化しないとモデルの良し悪しを判定できない。kubeflow とか mlflow とかを勉強してこの辺を可視化して、モデルを適切に選択できるようになりたい。そのためには k8s 勉強しないと。

脚注
  1. https://arxiv.org/abs/1706.03762 ↩︎

Discussion