🦀
Rust axum Websocketを使用したChatの実装
概要
Rust
でWebSocket
を使用したChat
の実装を行うにあたり、tokio-rs/axumにてChat
部分のExample
を参考にしました。
今回実現したかったこと
LINE
のチャットルームのように、該当チャットルームのユーザーにのみリアルタイムでメッセージを配信することです。
Example
の記事ではWebSocket
に接続している全ユーザにメッセージがSend
されてしまうため、特定のユーザーのみに配信することは考慮されていませんでした。
tokioのbroadcast(※注1)の利用はそのまま継続し、sender
で配信する際にメッセージを送るべきユーザーの分岐を追加することで、特定のユーザーのみへの配信を叶えました。
※注1:tokio
のbroadcastはあくまでもプロセス内のブロードキャストであり、実際に全員に配信されるわけではない
Rustでのバックエンド実装
コード内の★部分が該当箇所です。
main.rs
use axum::{
extract::{
ws::{Message, WebSocket, WebSocketUpgrade},
Extension,
},
response::{IntoResponse},
routing::get,
Router,
};
use futures::{sink::SinkExt, stream::StreamExt};
use std::{
collections::HashSet,
net::SocketAddr,
sync::{Arc, Mutex},
};
use tokio::sync::broadcast;
use serde_json::Value;
use tower_http::add_extension::AddExtensionLayer;
use serde::{Serialize, Deserialize};
#[derive(Debug)]
struct AppState {
user_id_set: Arc<Mutex<HashSet<String>>>,
tx: broadcast::Sender<String>,
}
#[tokio::main]
async fn main() {
// ロギングの設定
tracing_subscriber::fmt()
.with_max_level(tracing::Level::DEBUG)
.init();
tracing::debug!("this is a tracing line");
let user_id_set = Arc::new(Mutex::new(HashSet::new()));
let (tx, _rx) = broadcast::channel(100);
let app_state = Arc::new(AppState { user_id_set, tx });
let app = Router::new()
.route("/websocket", get(websocket_handler))
.layer(AddExtensionLayer::new(app_state));
let addr = SocketAddr::from(([127, 0, 0, 1], 3012));
axum::Server::bind(&addr)
.serve(app.into_make_service())
.await
.unwrap();
}
async fn websocket_handler(
ws: WebSocketUpgrade,
Extension(state): Extension<Arc<AppState>>,
) -> impl IntoResponse {
ws.on_upgrade(|socket| websocket(socket, state))
}
#[derive(Debug, Deserialize, Serialize)]
enum ChatRoomTypeEnum {
Initial,
DirectChatRoomId,
GroupChatRoomId
}
#[derive(Debug, Deserialize, Serialize,PartialEq)]
enum MessageTypeEnum {
SetUserId,
SendMessage,
None
}
#[derive(Debug, Deserialize, Serialize)]
pub struct MessageStruct {
_id: String,
chat_room_id: u64,
chat_room_type: String,
created_at: String,
send_user_ids: Vec<String>,
text: String,
message_type: MessageTypeEnum,
user: User,
user_id: String,
}
#[derive(Debug, Deserialize, Serialize)]
pub struct SendUserIds {
send_user_ids: Option<Vec<String>>
}
#[derive(Debug, Deserialize, Serialize)]
pub struct User {
_id: String
}
// メッセージの初回に送信
#[derive(Debug, Deserialize, Serialize)]
pub struct MessageType {
message_type: String
}
async fn websocket(stream: WebSocket, state: Arc<AppState>) {
let (mut sender, mut receiver) = stream.split();
let mut user_id: String = String::new();
let mut message_type: String = String::new();
let message_type_enum: MessageTypeEnum;
while let Some(Ok(message)) = receiver.next().await {
if let Message::Text(message_text) = message {
parse_message_type_and_user_id(&state, &mut user_id, &mut message_type, message_text);
message_type_enum = {
if message_type == "SetUserId" {
MessageTypeEnum::SetUserId
} else {
MessageTypeEnum::None
}
};
if message_type_enum == MessageTypeEnum::SetUserId {
break;
} else {
return;
}
}
}
let mut rx = state.tx.subscribe();
let message_type = MessageType {
message_type: message_type.clone()
};
let msg = serde_json::to_string(&message_type).unwrap();
let _ = state.tx.send(msg);
let mut send_task = {
let user_id_set = state.user_id_set.lock().unwrap().clone();
tokio::spawn(async move {
while let Ok(msg) = rx.recv().await {
// ★①メッセージを送るべきuser_idをsend_user_ids変数に格納
// ★②websocketに接続しているuser_id一覧(user_id_set)の中のuser_idとメッセージを送るべきuser_idが一致していれば、メッセージを送信
// ★③send_user_idのうち、1件のuser_idにメッセージを送信したら、フロントでchat_room_idが一致する全員にwebsocketを配信する
let send_user_ids: SendUserIds = serde_json::from_str(&msg).unwrap();
match send_user_ids.send_user_ids {
Some(ids) => {
for send_user_id in ids {
if user_id_set.contains(&send_user_id) {
let clone_msg = msg.clone();
if sender.send(Message::Text(clone_msg)).await.is_err() {
break;
}
// 1度送れば、フロントでchat_room_idが一致する人にsetMessageする
break;
}
}
},
None => {
if sender.send(Message::Text(msg)).await.is_err() {
break;
}
}
};
}
})};
let tx = state.clone().tx.clone();
let mut recv_task = tokio::spawn(async move {
while let Some(Ok(Message::Text(text))) = receiver.next().await {
let result = parse_result(text).await.unwrap();
match result {
Some(res) => {
if res.message_type == MessageTypeEnum::SendMessage {
let msg = serde_json::to_string(&res).unwrap();
let _ = tx.send(msg);
}
},
None => {
}
}
}
});
tokio::select! {
_ = (&mut send_task) => recv_task.abort(),
_ = (&mut recv_task) => send_task.abort(),
};
}
async fn parse_result(message_text: String) -> anyhow::Result<Option<MessageStruct>> {
// 項目
let mut _id = String::from("");
let mut chat_room_id:u64 = 0;
let mut chat_room_type = String::from("");
let mut created_at = String::from("");
let mut message_type:MessageTypeEnum = MessageTypeEnum::None;
let mut send_user_ids:Vec<String> = vec![];
let mut text = String::from("");
let mut user = User {
_id : String::from("")
};
let mut copy_id = String::from("");
let mut user_id = String::new();
let messages: Value = serde_json::from_str(&message_text).unwrap();
let message = messages[0].clone();
// _idの取り出し
if let Value::String(id) = &message["_id"] {
_id.push_str(id.as_str());
copy_id.push_str(id.as_str());
}
// chat_room_idの取り出し
if let Value::Number(chat_room_id_number) = &message["chat_room_id"] {
match chat_room_id_number.as_u64() {
Some(num) => {
chat_room_id = num;
},
None => {
}
}
}
// chat_room_typeの取り出し
if let Value::String(chat_room_type_string) = &message["chat_room_type"] {
chat_room_type.push_str(chat_room_type_string);
}
// created_atの取り出し
if let Value::String(created_at_string) = &message["created_at"] {
created_at.push_str(created_at_string);
}
// message_typeの取り出し
if let Value::String(message_type_string) = &message["message_type"] {
message_type = {
if message_type_string == "SendMessage" {
MessageTypeEnum::SendMessage
} else {
MessageTypeEnum::None
}
};
}
// send_user_idsの取り出し
if let Value::Array(send_user_ids_vec) = &message["send_user_ids"] {
for list in send_user_ids_vec {
if let Value::String(user_id) = list {
send_user_ids.push(user_id.to_string());
}
}
}
// textの取り出し
if let Value::String(text_string) = &message["text"] {
text.push_str(text_string.as_str());
}
// userの取り出し
if let Value::Object(user_obj) = &message["user"] {
let user_list = user_obj.get(&String::from("_id"));
if let Some(sub_user_list) = user_list {
if let Value::String(main_user_list) = sub_user_list {
user = User {
_id : main_user_list.to_string()
};
}
}
}
// user_idの取り出し
if let Value::String(user_id_string) = &message["user_id"] {
user_id.push_str(user_id_string.as_str());
}
// resultの整形
let result = MessageStruct {
_id: _id,
chat_room_id: chat_room_id,
chat_room_type:chat_room_type,
created_at: created_at,
send_user_ids: send_user_ids,
text: text,
message_type: message_type,
user: user,
user_id: user_id.clone(),
};
return Ok(Some(result))
}
fn parse_message_type_and_user_id(state: &AppState, user_id_text: &mut String, message_type_text: &mut String, message_text: String) {
let mut user_id = String::new();
let messages: Value = serde_json::from_str(&message_text).unwrap();
let message = messages[0].clone();
// user_idの取り出し
if let Value::String(user_id_string) = &message["user_id"] {
user_id.push_str(user_id_string.as_str());
}
let mut user_id_set = state.user_id_set.lock().unwrap();
if !user_id_set.contains(&user_id) {
user_id_set.insert(user_id.to_owned());
user_id_text.push_str(&user_id);
}
// message_typeの取り出し
if let Value::String(message_type_string) = &message["message_type"] {
message_type_text.push_str(&message_type_string);
}
}
Cargo.toml
[dependencies]
axum = { version = "0.5.9", features = ["ws"] }
futures = "0.3"
tokio = { version = "1", features = ["full"] }
tower = { version = "0.4", features = ["util"] }
tracing = "0.1"
tracing-subscriber = "0.2"
tower-http = {version = "0.3.4", features = ["full"]}
serde_json = "1.0"
serde = { version = "1.0", features = ["derive"] }
anyhow = "1.0"
React Native Expoでのフロント実装(WebSocket導入部分)
chat.tsx
import { sock } from "../../websocket"
...
useEffect(() => {
const handler = e => {
const newMessage = JSON.parse(e.data)
if (isMounted.current) {
// チャット画面に遷移してきた際にのみ実行
if (newMessage["message_type"] === "SetUserId") {
return
} else {
// newMessage["message_type"] === "SendMessage"
// メッセージを送った際に実行
// ユーザーが開いているチャットルームに一致する場合のみメッセージを表示する
const messageDirectChatRoomId = newMessage.chat_room_type === "DirectChatRoomId" ? newMessage.chat_room_id : null;
const messageGroupChatRoomId = newMessage.chat_room_type === "GroupChatRoomId" ? newMessage.chat_room_id : null;
if ((directChatRoomId !== null && directChatRoomId === messageDirectChatRoomId) || (groupChatRoomId !== null && groupChatRoomId === messageGroupChatRoomId)) {
setMessages(previousMessages => GiftedChat.append(previousMessages, newMessage))
}
}
}
}
sock.addEventListener("message", handler)
return () => {
sock.removeEventListener("message", handler)
}
}, [userId])
websocket.js
// 当記事ではngrokを利用しています
// URLは環境にあったものをご利用ください
export const sock = new WebSocket("wss://d1ea-61-120-204-212.jp.ngrok.io/websocket");
詰まったところ
①tokio::spawn
の中でMutex
のlock
が出来ない
main.rs
// NG例
let mut send_task = tokio::spawn(async move {
let user_id_set = state.user_id_set.lock().unwrap().clone();
...
});
理由: std::sync::MutexGuard
型がSend
ではないためエラーが起こる。
mutex
のロックを別のスレッドに送ることはできないため、エラーになっている。
解決策: tokio::spawn
の外でMutex
のlock
を行う
main.rs
// OK例
let mut send_task = {
let user_id_set = state.user_id_set.lock().unwrap().clone();
tokio::spawn(async move {
...
})};
Special thanks
Chat
の実装やロギングに関して、@megumish_unsafeさんにアドバイスいただきました。教えていただきありがとうございました!!!
Discussion