🤖

日本語RoBERTaをGoogle Colabで試す

3 min read

BEATを改良したRoBERTaのモデルをオープンソースで利用できるようです。

https://corp.rinna.co.jp/news/2021-8-25pressrelease/

このモデルでは、文章内でマスクされた部分を単語の前後から推論できるみたいです。
面白そう😁

Colabで動かしてみます。

簡単にRoBERTaを動かす

HuggingFaceにある、日本語RoBERTaのプロジェクトページにAPIがあるので簡単に試せます。

https://huggingface.co/rinna/japanese-roberta-base

Huggingface Inference API only supports typing [MASK] in the input string and produces less robust predictions.

APIでは若干精度低いらしい。

MASK文章

[CLS]ヤマザキパンの祭りは[MASK]に開かれる。

結果

パン祭りは日曜日開催!

Google Colabで動かす

ほとんどHuggingFaceの記事通り。

https://huggingface.co/rinna/japanese-roberta-base

試したColab ノートブック

https://colab.research.google.com/drive/1lbZYhhlh8wUpuuegwks5a-vkUiQbpAlA?usp=sharing

ざっくりとコード

# Huggingface Transformersのインストール
!pip install transformers

# Sentencepieceのインストール
!pip install sentencepiece

text_origにマスク入り文章を入力。

from transformers import T5Tokenizer, RobertaForMaskedLM

tokenizer = T5Tokenizer.from_pretrained("rinna/japanese-roberta-base")
tokenizer.do_lower_case = True  # due to some bug of tokenizer config loading

model = RobertaForMaskedLM.from_pretrained("rinna/japanese-roberta-base")

# original text
text_orig = "ヤマザキパンの祭りは[MASK]に開かれる。"
# text_orig = "くお~!!ぶつかる~!!ここでアクセル全開、[MASK]を右に!"
# text = "4年に1度オリンピックは開かれる。"

# prepend [CLS]
text = "[CLS]" + text_orig

# tokenize
tokens = tokenizer.tokenize(text)
print(tokens)

print('mask index :' , tokens.index('[MASK]'))
# tokens.index('オリンピック')

mask indexが取れるので、masked_idxに設定。

# mask a token
masked_idx = 9 # ここにMASKのindexを入れる

推論する。

tokens[masked_idx] = tokenizer.mask_token
# print(tokens)  # output: ['[CLS]', '▁4', '年に', '1', '度', '[MASK]', 'は', '開かれる', '。']

# convert to ids
token_ids = tokenizer.convert_tokens_to_ids(tokens)
# print(token_ids)  # output: [4, 1602, 44, 24, 368, 6, 11, 21583, 8]

# convert to tensor
import torch
token_tensor = torch.tensor([token_ids])

# get the top 10 predictions of the masked token
model = model.eval()
with torch.no_grad():
    outputs = model(token_tensor)
    predictions = outputs[0][0, masked_idx].topk(10)

print(text_orig)

for i, index_t in enumerate(predictions.indices):
    index = index_t.item()
    token = tokenizer.convert_ids_to_tokens([index])[0]
    print(i, token)

結果

毎週金曜日開催!
実際APIより良さそうな候補が出てます。

終わりに

ド素人なんでMASK文章の補完技術が、実際どのように応用するのかあまり思いつきません。
特定の分野の書籍をファインチューニングさせて、「一番優れた方法は[MASK]である。」とか専門家チャットボットとして利用するとか?

(BERTですが)応用的に使ってるのは、センター試験の英語を解いてる記事が面白かったです。

https://zenn.dev/hellorusk/articles/f9e6c503dc54e2

あとはHuggingFace初めて知りましたが、APIで簡単に試せたりして便利ですね。
Google Colabとの連携も簡単でした。

Discussion

ログインするとコメントできます