ファインチューニングしたOpenCALMでチャット実装
前回の記事で作成したファインチューニング済みOpenCALMでチャットを実装しました
目次
- 概要
- フロントエンド
- サーバサイド
- サーバアプリ
- チャット
- チャット(OpenCALMChat)
概要
以下の要件で作ってみました
- ブラウザ上から利用できるチャット
- ファインチューニングしたモデルを使用
- 生成中はローディング表示をし、生成した言葉をその都度表示していく
(生成完了してからまとめて返すのではない)
今回の選定技術は以下の通りです
- GoogleColab(GPU:V100)
- 通信プロトコル websocket
- フロントエンド javascript
- サーバサイド python(FastAPI)
- サーバアプリ uvicorn、ngrok
ColabのGPU、本当はA100を使いたいのですがここ最近全く割り振られず…
仕方なくV100です
そのため7Bのモデルはロードできず、今回は1Bのモデルを使ってます
フロントエンド
シンプルに1枚のHTML内でwebsocket通信をします
HTML長いのでこの中
html = """
<!DOCTYPE html>
<html>
<head>
<title>OpenCALM chat</title>
<style>
#chatArea {width:500px; min-height:300px; margin-left:auto; margin-right:auto; background:#b0d0f0}
.playerBubbleArea {text-align:right}
.aiBubbleArea {text-align:left}
.playerBubble {display:inline-block; max-width:60%; margin:7px 5px; padding:0px 7px; text-align:left;
background:#d0f0d0; border:solid; border-color:transparent; border-radius:15px}
.aiBubble {display:inline-block; max-width:60%; margin:7px 5px; padding:0px 7px; text-align:left;
background:white; border:solid; border-color:transparent; border-radius:15px}
#chatForm {width:500px; margin:5px auto; text-align:right}
#chatInput {width:250px; height:20px}
#sendButton {width:50px; height:26px}
@keyframes pulse {50% {background:transparent}}
#loading {position:relative; display:inline-block; width:2px; height:2px; margin:0px 10px;
background:#909090; border:solid; border-width:2px; border-radius:3px; border-color:transparent;
animation:pulse 1000ms infinite; animation-delay:250ms;
&::before, &::after {content:""; position:absolute; display:block; width:2px; height:2px;
background:#909090; border:solid; border-width:2px; border-radius:3px; border-color:transparent;
animation:pulse 1000ms infinite}
&::before {right:7px; top:-2px}
&::after{left:7px; top:-2px; animation-delay:500ms}}
</style>
</head>
<body>
<div id="chatArea"></div>
<form id="chatForm" action="" onsubmit="sendChat(event)">
<input type="text" id="chatInput" autocomplete="off" />
<button id="sendButton">↑</button>
</form>
<script type="text/javascript">
var ws = new WebSocket("wss://" + location.host + "/chat")
var chatInput = document.getElementById('chatInput')
var sendButton = document.getElementById('sendButton')
var chatArea = document.getElementById('chatArea')
var chatList = []
var loading
var receivedData = ''
function sendChat(event) {
event.preventDefault()
var data = chatInput.value
if (sendButton.disabled == false && data != '') {
sendButton.disabled = true
var chatBubbleArea = document.createElement('div')
chatBubbleArea.className = 'playerBubbleArea'
chatArea.appendChild(chatBubbleArea)
var chatBubble = document.createElement('div')
chatBubble.className = 'playerBubble'
chatBubbleArea.appendChild(chatBubble)
chatBubble.appendChild(document.createTextNode(data))
chatList.push(['player', data])
ws.send(JSON.stringify(arrangeData(chatList)))
chatInput.value = ''
}
}
ws.onmessage = function(event) {
var data = JSON.parse(event.data)
if (data.type == 'start') {
var chatBubbleArea = document.createElement('div')
chatBubbleArea.className = 'aiBubbleArea'
chatArea.appendChild(chatBubbleArea)
var chatBubble = document.createElement('div')
chatBubble.className = 'aiBubble'
chatBubbleArea.appendChild(chatBubble)
loading = document.createElement('span')
loading.id = 'loading'
chatBubble.appendChild(loading)
} else if (data.type == 'end') {
loading.remove()
chatList.push(['ai', receivedData])
receivedData = ''
sendButton.disabled = false
} else if (data.type == 'answer') {
loading.before(document.createTextNode(data.answer))
receivedData += data
}
}
function arrangeData(chatList) {
result = {}
for (let i=0; i<chatList.length; i++) {
result[i] = {user:chatList[i][0], words:chatList[i][1]}
}
return result
}
</script>
</body>
</html>
"""
長いのでざっくりまとめると…
- スタイルシートで良さげな見た目と、ローディングアニメーションを作成
- Javascript上で「var ws = new WebSocket("wss://" + location.host + "/chat")」でwebsocketインスタンス作成
- ボタンを押すと、sendChatメソッドが実行されて過去の会話+入力内容がjsonで送信される
- jsonでメッセージを受け取るたびに言葉が追加表示される
です
サーバサイド
まず必要なライブラリをインストールして、前回のモデルを読み込みます
!pip install transformers
!pip install accelerate
!pip install peft
from google.colab import drive
from peft import PeftModel, PeftConfig
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
drive.mount('/content/drive')
filename = '/content/drive/MyDrive/<保存したフォルダ&モデルファイル>'
config = PeftConfig.from_pretrained(filename)
model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path)
model = PeftModel.from_pretrained(model, filename)
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
その後、FastAPIでサーバサイドの処理を実装します
!pip install fastapi nest-asyncio
from fastapi import FastAPI, WebSocket
from fastapi.responses import HTMLResponse
import asyncio
import nest_asyncio
from threading import Thread
from transformers import TextIteratorStreamer
app = FastAPI()
nest_asyncio.apply()
@app.get("/", response_class=HTMLResponse)
async def root():
return html
@app.websocket("/chat")
async def chat_endpoint(websocket: WebSocket):
await websocket.accept()
while True:
json_obj = await websocket.receive_json()
prompt = make_prompt(json_obj.values())
await websocket.send_json(dict(type="start"))
for answer in generate(prompt):
#取得できず空白ならforに戻る
if not answer:
continue
await websocket.send_json(dict(type="answer", answer=answer))
#他スレをブロックしないように
await asyncio.sleep(0)
await websocket.send_json(dict(type="end"))
def make_prompt(chatList):
prompt = ""
for c in chatList:
if c["user"] == "player":
prompt += "Q:" + c["words"] + "\n"
elif c["user"] == "ai":
prompt += "A:" + c["words"] + "\n"
prompt += "A:"
return prompt
def generate(prompt: str):
streamer = TextIteratorStreamer(
tokenizer,
skip_prompt=True,
skip_special_tokens=True
)
thread = Thread(
target=model.generate,
kwargs=dict(
**tokenizer(prompt, return_tensors="pt").to(model.device),
streamer=streamer,
max_new_tokens=256,
do_sample=True,
temperature=0.7,
repetition_penalty=1.25,
pad_token_id=tokenizer.pad_token_id
)
)
thread.start()
return streamer
websocket通信内でメッセージを受け取ったら、プロンプトを成形し、modelのgenerateメソッドで生成を始めます
この時TextIteratorStreamerを引数streamerに指定することで、生成途中の出力を得ることができます
なお、生成は非同期で行うので、Threadを使って実行します
サーバアプリ
サーバアプリにはuvicornを使用し、ngrokでトンネリングして外部からアクセスできるようにします
!pip install pyngrok uvicorn[standard]
!ngrok config add-authtoken <ngrok公式サイトにログインして取得した認証トークン文字列>
from pyngrok import ngrok
import uvicorn
ngrok_tunnel = ngrok.connect(8000)
print('Public URL:', ngrok_tunnel.public_url)
uvicorn.run(app, port=8000)
実行後に出力されるURLにアクセスするとチャットが表示されます
チャット
1Bの小さいモデルということもあって、会話はあまり噛み合っていないですが
自分でファインチューニングしたモデルを使って、それっぽいWebUIチャットすることができました!
ちなみにこの開発中にOpenCALMのチャットモデルが出てました
7Bのみで今回試せなかったので、次回は試してみたいです
チャット(OpenCALMChat)
※2023/11/19追記※
A100が使えたのでOpenCALMのチャットモデルを試してみました
model_name = "cyberagent/calm2-7b-chat"
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained(model_name)
モデルをロードしたあとに
def make_prompt(chatList):
prompt = ""
for c in chatList:
if c["user"] == "player":
prompt += "USER: " + c["words"] + "\n"
elif c["user"] == "ai":
prompt += "ASSISTANT: " + c["words"] + "<|endoftext|>\n"
prompt += "ASSISTANT: "
return prompt
プロンプト作成メソッドを変えてフォーマットを合わせました
結果
最初は良かったんですが、途中からおかしくなってしまいました
(もしかしたら設定が間違ってた…?)
取り急ぎこんな結果でした
Discussion