🔖

【webアプリ】fastAPIとReactを使ったwebsocketの雛形

に公開

自分でよく使う雛形をまとめました。

React側

  • TailwindCssとdaisyuiを使います
import React, { useState, useEffect } from 'react'

export const Home = () => {

  const [data, setData] = useState();
  const [isConnected, setIsConnected] = useState(false);
  const ws = React.useRef<WebSocket | null>(null);

  useEffect(() => {
    // WebSocketの接続を確立
    ws.current = new WebSocket('ws://localhost:8080/ws');
    
    // 接続が開いたときのイベントハンドラ
    ws.current.onopen = () => {
      console.log('WebSocket接続が確立されました');
      setIsConnected(true);
    };
    
    // メッセージを受信したときのイベントハンドラ
    ws.current.onmessage = (event) => {
      const data = JSON.parse(event.data);
      setData(data);
    };
    
    // エラーが発生したときのイベントハンドラ
    ws.current.onerror = (error) => {
      console.error('WebSocketエラー:', error);
    };
    
    // 接続が閉じたときのイベントハンドラ
    ws.current.onclose = () => {
      console.log('WebSocket接続が閉じられました');
      setIsConnected(false);
    };
    
    // コンポーネントのアンマウント時にWebSocket接続を閉じる
    return () => {
      if (ws.current) {
        ws.current.close();
      }
    };
  }, []);

  const start = async () => {
    const response = await fetch('http://localhost:8080/start?count=10')
    const data = await response.json()
    console.log(data)
  }

  const reset = async () => {
    const response = await fetch('http://localhost:8080/reset')
    const data = await response.json()
    console.log(data)
  }

  return (
    <div>
      <div className='text-2xl font-bold'>ウェブソケット接続テスト</div>

      <div className='flex gap-4'>
          <div className='p-4'>
              <div>
                接続状態: {isConnected ? <span className=' text-accent'>接続中</span> : <span className='text-warning'>未接続</span>}
              </div>

              <div className=' card-actions'>
                <button className='btn btn-primary' onClick={start}>スタート</button>
                <button className='btn btn-secondary' onClick={reset}>リセット</button>                
              </div>
          </div>

          <div className='flex-1'>
            <pre className=' mockup-code'> {JSON.stringify(data, null, 2)}</pre>
          </div>
      </div>

    </div>
  )
}

FastAPI側

import json
import asyncio

from langchain_core.messages import HumanMessage, BaseMessage

from fastapi import FastAPI, WebSocket, Request, Response, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware


app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

@app.get("/")
async def root():
    return {"message": "Hello World"}


from dataclasses import dataclass, field
from typing import List

@dataclass
class Sample:
    messages: List[BaseMessage] = field(default_factory=list)
    count: int = 0

    def to_dict(self):
        return {
            "messages": [msg.content for msg in self.messages],
            "count": self.count
        }

    def reset(self):
        self.messages = []
        self.count = 0

sample = Sample()

# sampleをカウントアップする
@app.get("/start")
async def start(count:int):
    for i in range(count):
        sample.messages.append(HumanMessage(content=f"こんにちは{i}"))
        sample.count += 1
        await asyncio.sleep(0.5)
    
    return Response(content=json.dumps(sample.to_dict()), media_type="application/json")

# sampleをリセットする
@app.get("/reset")
async def reset():
    sample.reset()
    return Response(content=json.dumps(sample.to_dict()), media_type="application/json")

# websocketでデータを送信
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    await websocket.accept()
    try:
        while True:            
            # 現在の状態を送信
            current_state = {
                "messages": [msg.content for msg in sample.messages],
                "count": sample.count
            }

            await websocket.send_json(current_state)
    except WebSocketDisconnect:
        print("Client disconnected")

if __name__ == "__main__":
    import uvicorn
    import os
    file_name = os.path.basename(__file__).replace(".py", "")
    uvicorn.run(f"{file_name}:app", host="0.0.0.0", port=8080, log_level="info", reload=True)

成果物

こんな感じで、サーバー側のsampleの状態をフロント側でリアルタイムに確認することができます

Discussion