💬

LSTMのEmbeddingレイヤーで、事前に学習されたword2vecを利用する方法

2023/06/03に公開

はじめに

LSTM  (Long short-term memory)は時系列データや文章の処理で用いられるニューラルネットワークの手法である。自然言語処理のタスクにおいて LSTM は、入力として単語のベクトル表現を受け取り内部のメモリを更新しながら次に現れる単語を予想するタスクを解きながら文全体の意味をベクトルに埋め込む。一方で、word2vec は単語の分散表現の学習手法である。word2vec は大量のテキストデータを与えることで、単語の関係性を捉え単語のベクトル空間を学習することができる。そのため、事前に word2vec を用いて単語の意味を学習させてから、それを LSTM 入力に利用することで LSTM のモデルの学習精度が向上させることができる。

本記事では TensorFlow で作成した LSTM のモデルにおいて、単語のベクトル表現の部分を事前に学習した word2vec のモデルに置き換えることでモデルの精度が向上することを確認する。

この記事のコードは以下に配置した。本記事では重要な部分を説明する。

データの解説

この記事では以下のデータセットを用いる。これは 2.5 万件の映画のレビューをポジティブ/ネガティブで分類したデータセットである。

https://www.tensorflow.org/datasets/catalog/imdb_reviews?hl=ja

import tensorflow as tf
import tensorflow_datasets as tfds

# Hyperparameters

VOCAB_SIZE = 10000
SEQUENCE_LENGTH = 250
EMBEDDING_DIM = 300
BATCH_SIZE = 64

# Load the data

raw_train_dataset, raw_val_dataset, raw_test_dataset = tfds.load('imdb_reviews', split=['train', 'test[:50%]', 'test[50%:]'], as_supervised=True)
for review, label in raw_train_dataset.take(2):
    print(review)
    print(label)

# tf.Tensor(b"This was an absolutely terrible movie. Don't be lured in by Christopher Walken or Michael Ironside. Both are great actors, but this must simply be their worst role in history. Even their great acting could not redeem this movie's ridiculous storyline. This movie is an early nineties US propaganda piece. The most pathetic scenes were those when the Columbian rebels were making their cases for revolutions. Maria Conchita Alonso appeared phony, and her pseudo-love affair with Walken was nothing but a pathetic emotional plug in a movie that was devoid of any real meaning. I am disappointed that there are movies like this, ruining actor's like Christopher Walken's good name. I could barely sit through it.", shape=(), dtype=string)
# tf.Tensor(0, shape=(), dtype=int64)
# tf.Tensor(b'I have been known to fall asleep during films, but this is usually due to a combination of things including, really tired, being warm and comfortable on the sette and having just eaten a lot. However on this occasion I fell asleep because the film was rubbish. The plot development was constant. Constantly slow and boring. Things seemed to happen, but with no explanation of what was causing them or why. I admit, I may have missed part of the film, but i watched the majority of it and everything just seemed to happen of its own accord without any real concern for anything else. I cant recommend this film at all.', shape=(), dtype=string)
# tf.Tensor(0, shape=(), dtype=int64)

前処理

文章中の各単語には、以下の方法で id を割り当てる。

# Preprocessing

def standardization(input_data):
    lowercase = tf.strings.lower(input_data)
    no_tag = tf.strings.regex_replace(input_data,"<[^>]+>","")
    output = tf.strings.regex_replace(no_tag, '[%s]' % re.escape(string.punctuation), '')
    return output

vectorize_layer = TextVectorization(
    standardize=standardization,
    max_tokens=VOCAB_SIZE,
    output_mode='int',
    output_sequence_length=SEQUENCE_LENGTH
)

vectorize_layer.adapt(
    raw_train_dataset.map(lambda text, label: text, num_parallel_calls=tf.data.AUTOTUNE, deterministic=False),
)

def vectorize_text(text, label):
    return vectorize_layer(text), label

train_dataset = raw_train_dataset.map(vectorize_text, num_parallel_calls=tf.data.AUTOTUNE).cache().shuffle(10000).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
val_dataset = raw_val_dataset.map(vectorize_text, num_parallel_calls=tf.data.AUTOTUNE).cache().shuffle(10000).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
test_dataset = raw_test_dataset.map(vectorize_text, num_parallel_calls=tf.data.AUTOTUNE).cache().batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

文章は全て小文字に変換し html タグを取り除き句読点を除く前処理を行う。また語彙の数に制限を作り文章の長さも一定に揃える。

モデルの解説

次のようなモデルを作成した。今回はBidirectionalを用いることで文章を双方向から LSTM に学習させベクトル化する。

# Model

model = Sequential([
    Input(shape = (SEQUENCE_LENGTH,)),
    Embedding(VOCAB_SIZE, EMBEDDING_DIM),
    Bidirectional(LSTM(EMBEDDING_DIM, dropout=0.2)),
    Dense(1, activation='sigmoid')
])
model.summary()

はじめに、Embedding レイヤーをランダムに初期化した場合の結果を示す。

Test accuracy : 0.8256000280380249
Test loss : 0.9147175550460815

word2vec を利用

次に Embedding レイヤーを word2vec で事前に学習したものに置き換える。今回、word2vec の分散表現は gensim のものを用いる。gensim に存在する単語はそのままベクトルを用いて、存在しない単語についてはランダムな値で初期化する。

import numpy as np
import gensim.downloader as api
from gensim.models import Word2Vec

# Pretrained Embeddings
word2vec = api.load('word2vec-google-news-300')
pretrained_embeddings = np.array([
    word2vec[word] if word in word2vec else np.random.normal(loc=0, scale=1, size=(EMBEDDING_DIM, ))
    for i, word in enumerate(vectorize_layer.get_vocabulary())
])

これを用いて次のようなモデルを作成し学習を行なった。

# Model

model2 = Sequential([
    Input(shape = (SEQUENCE_LENGTH,)),
    Embedding(VOCAB_SIZE, EMBEDDING_DIM, embeddings_initializer=tf.keras.initializers.Constant(pretrained_embeddings)),
    Bidirectional(LSTM(EMBEDDING_DIM, dropout=0.2)),
    Dense(1, activation='sigmoid')
])
model2.summary()

結果を以下に示す。

Test accuracy : 0.8596000075340271
Test loss : 0.6397333145141602

最終的な精度はわずかに上昇した。学習初期は word2vec を用いない方が精度が良い。

おわりに

word2vec で事前に学習した分散表現を利用して LSTM のモデルを学習した結果、わずかに精度が向上した。わずかしか精度が向上しなかった理由としては、LSTM は次の単語を予想し word2vec は周囲の単語を当てるというタスクを解くため、これらのタスクが似通っていることが原因と考えられる。これらは似通ったタスクを解くため、自然言語に対する新たな仮定を与えているわけではないから精度も向上しないと考えられる。一方で、少しだけ精度が向上した理由としては、データセット中に殆ど現れない単語について、事前に別のデータセットで学習した分散表現の情報使って初期化しているので、その情報を使って正解を予想しているので精度が上がったと考えられる。ほとんど現れない単語の分散表現は学習が難しいが、そもそもテストデータにもほとんど現れないので、この情報を知ってることによる精度向上はわずかになる。この考察から、LSTM のモデルについて Embedding レイヤーを Word2Vec を用いて学習するのは、学習データが少ししかなく、単語の数も多いので分散表現を 1 から学習することができない場合に有効であると考えられる。実社会ではこういうパターンも多いと思うのでこの手法を使っている人は多そう。

本実験の感想としては、画像処理の Data Augmentation[1]や ResNet を転移学習させる方法[2]などよりは精度の向上がみられなくて残念であった。画像処理の場合はラベルを不変に保つ画像の変換を考え、それを実装することも容易いが、自然言語処理の場合はラベルを普遍に保つような単語や文章の変換は考えづらいし実装も難しいので工夫が難しいなと感じた。だからこそ、自然言語処理の場合は word2vec のような教師なしの学習法があるのかもしれない。

また、TensorFlow には Wrapper[3]や RNN の Base class[4]などの機能があり、これらをカスタマイズすることでオリジナルの RNN のネットワークも作成が可能なことがわかった。RNN は自前で実装するのが難しそうなイメージがあったが、これを使えば楽できそう。

脚注
  1. https://zenn.dev/derbuihan/articles/d856427e52faea ↩︎

  2. https://zenn.dev/derbuihan/articles/cc23e8d3be1570 ↩︎

  3. https://www.tensorflow.org/api_docs/python/tf/keras/layers/Wrapper ↩︎

  4. https://www.tensorflow.org/api_docs/python/tf/keras/layers/RNN ↩︎

Discussion