🐙

GPT-2とGoogle ColabでGradioに入門する

2023/02/19に公開

gradioについて少し知っておこうと考え使ってみたところ、思っていた以上に簡単でした。
超入門的な内容ですが、せっかくなのでGPT-2をgradioを通して動かしてみたという記事です。

環境

Google Colabを利用しています。
ランタイムにはTPUを選択しましたが、GPUでもいいかもしれません。

必要なものをインストール

Google colabで必要なものをインストールします。

!pip install transformers gradio

GPT-2

GPT-2のトークナイザーとモデルをインポートします。

from transformers import GPT2Tokenizer, GPT2LMHeadModel

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

作成したプロンプトをモデルに入力して、出力を得ます。

text = "I am a"

# プロンプトをトークンにエンコード
input_ids = tokenizer.encode(text, return_tensors="pt")

# モデルに入力して、出力のトークンを取得
output_ids = model.generate(input_ids, do_sample=True, max_length=20, pad_token_id=50256)

# 出力をテキストにデコード
output_text = tokenizer.decode(output_ids[0])
print(output_text)

出力

I am a journalist and you are a writer. What would you be as a writer? Spencer: You're better as a journalist if your main job is to document the way in which people in the general public are reacting. It

デフォルトのままだとmax_lengthの数だけ出力するので、余計な文章も吐き出しますがひとまずは動いてくれます。

Gradio

あとは上記のコードを関数化してgradioのインターフェースに渡してあげるだけです。
gradioのインポートします。

import gradio as gr

先ほどのGPT-2による推論の部分を関数にします。

def inference(prompt):
	input_ids = tokenizer.encode(prompt, return_tensors="pt")
	output_ids = model.generate(input_ids, do_sample=True, max_length=20, pad_token_id=50256)
	output_text = tokenizer.decode(output_ids[0])
	return output_text

gradioのInterfaceに引数として渡して起動します。

demo = gr.Interface(fn=inference,
                    inputs="text",
                    outputs="text")

demo.launch()

次のようなリンクが出力されます。

Running on https://localhost:7863/


非常にお手軽にWebUIを作ることができました。
これはみんなが使うわけですね...

Discussion