🤖

ファインチューニングしたOpenCALMでチャット実装

2023/11/17に公開

前回の記事で作成したファインチューニング済みOpenCALMでチャットを実装しました
https://zenn.dev/tk1/articles/071fc9e5493c66

目次

  1. 概要
  2. フロントエンド
  3. サーバサイド
  4. サーバアプリ
  5. チャット
  6. チャット(OpenCALMChat)

概要

以下の要件で作ってみました

  • ブラウザ上から利用できるチャット
  • ファインチューニングしたモデルを使用
  • 生成中はローディング表示をし、生成した言葉をその都度表示していく
    (生成完了してからまとめて返すのではない)

今回の選定技術は以下の通りです

  • GoogleColab(GPU:V100)
  • 通信プロトコル websocket
  • フロントエンド javascript
  • サーバサイド python(FastAPI)
  • サーバアプリ uvicorn、ngrok

ColabのGPU、本当はA100を使いたいのですがここ最近全く割り振られず…
仕方なくV100です
そのため7Bのモデルはロードできず、今回は1Bのモデルを使ってます

フロントエンド

シンプルに1枚のHTML内でwebsocket通信をします

HTML長いのでこの中

html = """
<!DOCTYPE html>
<html>
  <head>
    <title>OpenCALM chat</title>
    <style>
      #chatArea {width:500px; min-height:300px; margin-left:auto; margin-right:auto; background:#b0d0f0}
      .playerBubbleArea {text-align:right}
      .aiBubbleArea {text-align:left}
      .playerBubble {display:inline-block; max-width:60%; margin:7px 5px; padding:0px 7px; text-align:left;
        background:#d0f0d0; border:solid; border-color:transparent; border-radius:15px}
      .aiBubble {display:inline-block; max-width:60%; margin:7px 5px; padding:0px 7px; text-align:left;
        background:white; border:solid; border-color:transparent; border-radius:15px}
      #chatForm {width:500px; margin:5px auto; text-align:right}
      #chatInput {width:250px; height:20px}
      #sendButton {width:50px; height:26px}
      @keyframes pulse {50% {background:transparent}}
      #loading {position:relative; display:inline-block; width:2px; height:2px; margin:0px 10px;
        background:#909090; border:solid; border-width:2px; border-radius:3px; border-color:transparent;
        animation:pulse 1000ms infinite; animation-delay:250ms;
        &::before, &::after {content:""; position:absolute; display:block; width:2px; height:2px;
          background:#909090; border:solid; border-width:2px; border-radius:3px; border-color:transparent;
          animation:pulse 1000ms infinite}
        &::before {right:7px; top:-2px}
        &::after{left:7px; top:-2px; animation-delay:500ms}}
    </style>
  </head>
  <body>
    <div id="chatArea"></div>
    <form id="chatForm" action="" onsubmit="sendChat(event)">
      <input type="text" id="chatInput" autocomplete="off" />
      <button id="sendButton">↑</button>
    </form>
    <script type="text/javascript">
      var ws = new WebSocket("wss://" + location.host + "/chat")
      var chatInput = document.getElementById('chatInput')
      var sendButton = document.getElementById('sendButton')
      var chatArea = document.getElementById('chatArea')
      var chatList = []
      var loading
      var receivedData = ''
      function sendChat(event) {
        event.preventDefault()
        var data = chatInput.value
        if (sendButton.disabled == false && data != '') {
          sendButton.disabled = true
          var chatBubbleArea = document.createElement('div')
          chatBubbleArea.className = 'playerBubbleArea'
          chatArea.appendChild(chatBubbleArea)
          var chatBubble = document.createElement('div')
          chatBubble.className = 'playerBubble'
          chatBubbleArea.appendChild(chatBubble)
          chatBubble.appendChild(document.createTextNode(data))
          chatList.push(['player', data])
          ws.send(JSON.stringify(arrangeData(chatList)))
          chatInput.value = ''
        }
      }
      ws.onmessage = function(event) {
        var data = JSON.parse(event.data)
        if (data.type == 'start') {
          var chatBubbleArea = document.createElement('div')
          chatBubbleArea.className = 'aiBubbleArea'
          chatArea.appendChild(chatBubbleArea)
          var chatBubble = document.createElement('div')
          chatBubble.className = 'aiBubble'
          chatBubbleArea.appendChild(chatBubble)
          loading = document.createElement('span')
          loading.id = 'loading'
          chatBubble.appendChild(loading)
        } else if (data.type == 'end') {
          loading.remove()
          chatList.push(['ai', receivedData])
          receivedData = ''
          sendButton.disabled = false
        } else if (data.type == 'answer') {
          loading.before(document.createTextNode(data.answer))
          receivedData += data
        }
      }
      function arrangeData(chatList) {
        result = {}
        for (let i=0; i<chatList.length; i++) {
          result[i] = {user:chatList[i][0], words:chatList[i][1]}
        }
        return result
      }
    </script>
  </body>
</html>
"""

長いのでざっくりまとめると…

  • スタイルシートで良さげな見た目と、ローディングアニメーションを作成
  • Javascript上で「var ws = new WebSocket("wss://" + location.host + "/chat")」でwebsocketインスタンス作成
  • ボタンを押すと、sendChatメソッドが実行されて過去の会話+入力内容がjsonで送信される
  • jsonでメッセージを受け取るたびに言葉が追加表示される

です

サーバサイド

まず必要なライブラリをインストールして、前回のモデルを読み込みます

!pip install transformers
!pip install accelerate
!pip install peft

from google.colab import drive
from peft import PeftModel, PeftConfig
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

drive.mount('/content/drive')

filename = '/content/drive/MyDrive/<保存したフォルダ&モデルファイル>'

config = PeftConfig.from_pretrained(filename)
model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path)
model = PeftModel.from_pretrained(model, filename)
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)

その後、FastAPIでサーバサイドの処理を実装します

!pip install fastapi nest-asyncio

from fastapi import FastAPI, WebSocket
from fastapi.responses import HTMLResponse
import asyncio
import nest_asyncio
from threading import Thread
from transformers import TextIteratorStreamer

app = FastAPI()

nest_asyncio.apply()

@app.get("/", response_class=HTMLResponse)
async def root():
  return html

@app.websocket("/chat")
async def chat_endpoint(websocket: WebSocket):
  await websocket.accept()
  while True:
    json_obj = await websocket.receive_json()
    prompt = make_prompt(json_obj.values())

    await websocket.send_json(dict(type="start"))

    for answer in generate(prompt):
      #取得できず空白ならforに戻る
      if not answer:
        continue
      await websocket.send_json(dict(type="answer", answer=answer))
      #他スレをブロックしないように
      await asyncio.sleep(0)

    await websocket.send_json(dict(type="end"))

def make_prompt(chatList):
  prompt = ""
  for c in chatList:
    if c["user"] == "player":
      prompt += "Q:" + c["words"] + "\n"
    elif c["user"] == "ai":
      prompt += "A:" + c["words"] + "\n"
  prompt += "A:"
  return prompt

def generate(prompt: str):
  streamer = TextIteratorStreamer(
          tokenizer,
          skip_prompt=True,
          skip_special_tokens=True
      )
  thread = Thread(
      target=model.generate,
      kwargs=dict(
          **tokenizer(prompt, return_tensors="pt").to(model.device),
          streamer=streamer,
          max_new_tokens=256,
          do_sample=True,
          temperature=0.7,
          repetition_penalty=1.25,
          pad_token_id=tokenizer.pad_token_id
      )
  )
  thread.start()
  return streamer

websocket通信内でメッセージを受け取ったら、プロンプトを成形し、modelのgenerateメソッドで生成を始めます
この時TextIteratorStreamerを引数streamerに指定することで、生成途中の出力を得ることができます
なお、生成は非同期で行うので、Threadを使って実行します

サーバアプリ

サーバアプリにはuvicornを使用し、ngrokでトンネリングして外部からアクセスできるようにします

!pip install pyngrok uvicorn[standard]
!ngrok config add-authtoken <ngrok公式サイトにログインして取得した認証トークン文字列>

from pyngrok import ngrok
import uvicorn

ngrok_tunnel = ngrok.connect(8000)
print('Public URL:', ngrok_tunnel.public_url)

uvicorn.run(app, port=8000)

実行後に出力されるURLにアクセスするとチャットが表示されます

チャット


1Bの小さいモデルということもあって、会話はあまり噛み合っていないですが
自分でファインチューニングしたモデルを使って、それっぽいWebUIチャットすることができました!

ちなみにこの開発中にOpenCALMのチャットモデルが出てました
7Bのみで今回試せなかったので、次回は試してみたいです
https://huggingface.co/cyberagent/calm2-7b-chat

チャット(OpenCALMChat)

※2023/11/19追記※
A100が使えたのでOpenCALMのチャットモデルを試してみました

model_name = "cyberagent/calm2-7b-chat"
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained(model_name)

モデルをロードしたあとに

def make_prompt(chatList):
  prompt = ""
  for c in chatList:
    if c["user"] == "player":
      prompt += "USER: " + c["words"] + "\n"
    elif c["user"] == "ai":
      prompt += "ASSISTANT: " + c["words"] + "<|endoftext|>\n"
  prompt += "ASSISTANT: "
  return prompt

プロンプト作成メソッドを変えてフォーマットを合わせました

結果

最初は良かったんですが、途中からおかしくなってしまいました
(もしかしたら設定が間違ってた…?)

取り急ぎこんな結果でした

Discussion