🦀

Rust axum Websocketを使用したChatの実装

2022/07/01に公開

概要

RustWebSocketを使用したChatの実装を行うにあたり、tokio-rs/axumにてChat部分のExampleを参考にしました。

今回実現したかったこと

LINEのチャットルームのように、該当チャットルームのユーザーにのみリアルタイムでメッセージを配信することです。
Exampleの記事ではWebSocketに接続している全ユーザにメッセージがSendされてしまうため、特定のユーザーのみに配信することは考慮されていませんでした。

tokioのbroadcast(※注1)の利用はそのまま継続し、senderで配信する際にメッセージを送るべきユーザーの分岐を追加することで、特定のユーザーのみへの配信を叶えました。

※注1:tokioのbroadcastはあくまでもプロセス内のブロードキャストであり、実際に全員に配信されるわけではない

Rustでのバックエンド実装

コード内の★部分が該当箇所です。

main.rs
use axum::{
  extract::{
      ws::{Message, WebSocket, WebSocketUpgrade},
      Extension,
  },
  response::{IntoResponse},
  routing::get,
  Router,
};
use futures::{sink::SinkExt, stream::StreamExt};
use std::{
  collections::HashSet,
  net::SocketAddr,
  sync::{Arc, Mutex},
};
use tokio::sync::broadcast;
use serde_json::Value;
use tower_http::add_extension::AddExtensionLayer;
use serde::{Serialize, Deserialize};

#[derive(Debug)]
struct AppState {
  user_id_set: Arc<Mutex<HashSet<String>>>,
  tx: broadcast::Sender<String>,
}

#[tokio::main]
async fn main() {
  // ロギングの設定
  tracing_subscriber::fmt()
      .with_max_level(tracing::Level::DEBUG)
      .init();
  tracing::debug!("this is a tracing line");

  let user_id_set = Arc::new(Mutex::new(HashSet::new()));
  let (tx, _rx) = broadcast::channel(100);

  let app_state = Arc::new(AppState { user_id_set, tx });

  let app = Router::new()
      .route("/websocket", get(websocket_handler))
      .layer(AddExtensionLayer::new(app_state));

  let addr = SocketAddr::from(([127, 0, 0, 1], 3012));

  axum::Server::bind(&addr)
      .serve(app.into_make_service())
      .await
      .unwrap();
}

async fn websocket_handler(
  ws: WebSocketUpgrade,
  Extension(state): Extension<Arc<AppState>>,
) -> impl IntoResponse {
  ws.on_upgrade(|socket| websocket(socket, state))
}

#[derive(Debug, Deserialize, Serialize)]
enum ChatRoomTypeEnum {
  Initial,
  DirectChatRoomId,
  GroupChatRoomId
}

#[derive(Debug, Deserialize, Serialize,PartialEq)]
enum MessageTypeEnum {
  SetUserId,
  SendMessage,
  None
}
#[derive(Debug, Deserialize, Serialize)]
pub struct MessageStruct {
  _id: String,
  chat_room_id: u64,
  chat_room_type: String,
  created_at: String,
  send_user_ids: Vec<String>,
  text: String,
  message_type: MessageTypeEnum,
  user: User,
  user_id: String,
}


#[derive(Debug, Deserialize, Serialize)]
pub struct SendUserIds {
  send_user_ids: Option<Vec<String>>
}

#[derive(Debug, Deserialize, Serialize)]
pub struct User {
  _id: String
}

// メッセージの初回に送信
#[derive(Debug, Deserialize, Serialize)]
pub struct MessageType {
  message_type: String
}

async fn websocket(stream: WebSocket, state: Arc<AppState>) {
  let (mut sender, mut receiver) = stream.split();

  let mut user_id: String = String::new();
  let mut message_type: String = String::new();
  let message_type_enum: MessageTypeEnum;

  while let Some(Ok(message)) = receiver.next().await {
      if let Message::Text(message_text) = message {
          parse_message_type_and_user_id(&state, &mut user_id, &mut message_type, message_text);
          message_type_enum = {
              if message_type == "SetUserId" {
                  MessageTypeEnum::SetUserId
              } else {
                  MessageTypeEnum::None
              }
          };
          if message_type_enum == MessageTypeEnum::SetUserId {
              break;
          } else {
              return;
          }
      }
  }
  let mut rx = state.tx.subscribe();
  
  let message_type = MessageType {
      message_type: message_type.clone()
  };

  let msg = serde_json::to_string(&message_type).unwrap();
  let _ = state.tx.send(msg);
  let mut send_task = {
      let user_id_set = state.user_id_set.lock().unwrap().clone();
      tokio::spawn(async move {
      while let Ok(msg) = rx.recv().await {
                // ★①メッセージを送るべきuser_idをsend_user_ids変数に格納
          // ★②websocketに接続しているuser_id一覧(user_id_set)の中のuser_idとメッセージを送るべきuser_idが一致していれば、メッセージを送信
          // ★③send_user_idのうち、1件のuser_idにメッセージを送信したら、フロントでchat_room_idが一致する全員にwebsocketを配信する
          let send_user_ids: SendUserIds = serde_json::from_str(&msg).unwrap();
          match send_user_ids.send_user_ids {
              Some(ids) => {
                  for send_user_id in ids {
                      if user_id_set.contains(&send_user_id) {
                          let clone_msg = msg.clone();
                          if sender.send(Message::Text(clone_msg)).await.is_err() {
                              break;
                          }
                          // 1度送れば、フロントでchat_room_idが一致する人にsetMessageする
                          break;
                      }
                  }
              },
              None => {
                  if sender.send(Message::Text(msg)).await.is_err() {
                      break;
                  }
              }
          };

      }
  })};

  let tx = state.clone().tx.clone();

  let mut recv_task = tokio::spawn(async move {
      while let Some(Ok(Message::Text(text))) = receiver.next().await {
          let result = parse_result(text).await.unwrap();
          match result {
              Some(res) => {
                  if res.message_type == MessageTypeEnum::SendMessage {
                      let msg = serde_json::to_string(&res).unwrap();
                      let _ = tx.send(msg);
                  }
              }, 
              None => {
              }
          }
      }
  });

  tokio::select! {
      _ = (&mut send_task) => recv_task.abort(),
      _ = (&mut recv_task) => send_task.abort(),
  };
}

async fn parse_result(message_text: String) -> anyhow::Result<Option<MessageStruct>> {
  // 項目
  let mut _id = String::from("");
  let mut chat_room_id:u64 = 0;
  let mut chat_room_type =  String::from("");
  let mut created_at = String::from("");
  let mut message_type:MessageTypeEnum = MessageTypeEnum::None;
  let mut send_user_ids:Vec<String> = vec![];
  let mut text = String::from("");
  let mut user = User {
      _id : String::from("")
  };
  let mut copy_id = String::from("");
  let mut user_id = String::new();

  let messages: Value = serde_json::from_str(&message_text).unwrap();
  let message = messages[0].clone();

  // _idの取り出し
  if let Value::String(id) = &message["_id"] {
      _id.push_str(id.as_str());
      copy_id.push_str(id.as_str());
  }
  
  // chat_room_idの取り出し
  if let Value::Number(chat_room_id_number) = &message["chat_room_id"] {
      match chat_room_id_number.as_u64() {
          Some(num) => {
              chat_room_id = num;
          },
          None => {

          }
      }
  }

  // chat_room_typeの取り出し
  if let Value::String(chat_room_type_string) = &message["chat_room_type"] {
      chat_room_type.push_str(chat_room_type_string);
  }

  // created_atの取り出し
  if let Value::String(created_at_string) = &message["created_at"] {
      created_at.push_str(created_at_string);
  }

  // message_typeの取り出し
  if let Value::String(message_type_string) = &message["message_type"] {
      message_type = {
          if message_type_string == "SendMessage" {
              MessageTypeEnum::SendMessage
          } else {
              MessageTypeEnum::None
          }
      };
  }

  // send_user_idsの取り出し
  if let Value::Array(send_user_ids_vec) = &message["send_user_ids"] {
      for list in send_user_ids_vec {
          if let Value::String(user_id) = list {
              send_user_ids.push(user_id.to_string());
          }
      }
  }

  // textの取り出し
  if let Value::String(text_string) = &message["text"] {
      text.push_str(text_string.as_str());
  }

  // userの取り出し
  if let Value::Object(user_obj) = &message["user"] {
      let user_list = user_obj.get(&String::from("_id"));
      if let Some(sub_user_list) = user_list {
          if let Value::String(main_user_list) = sub_user_list {
              user = User {
                  _id : main_user_list.to_string()
              };
          }
      }
  }

  // user_idの取り出し
  if let Value::String(user_id_string) = &message["user_id"] {
      user_id.push_str(user_id_string.as_str());
  }

  // resultの整形
  let result = MessageStruct {
      _id: _id,
      chat_room_id: chat_room_id,
      chat_room_type:chat_room_type,
      created_at: created_at,
      send_user_ids: send_user_ids,
      text: text,
      message_type: message_type,
      user: user,
      user_id: user_id.clone(),
  };
  return Ok(Some(result))
}

fn parse_message_type_and_user_id(state: &AppState, user_id_text: &mut String, message_type_text: &mut String, message_text: String) {
  let mut user_id = String::new();

  let messages: Value = serde_json::from_str(&message_text).unwrap();
  let message = messages[0].clone();
  // user_idの取り出し
  if let Value::String(user_id_string) = &message["user_id"] {
      user_id.push_str(user_id_string.as_str());
  }

  let mut user_id_set = state.user_id_set.lock().unwrap();

  if !user_id_set.contains(&user_id) {
      user_id_set.insert(user_id.to_owned());

      user_id_text.push_str(&user_id);
  }

  // message_typeの取り出し
  if let Value::String(message_type_string) = &message["message_type"] {
      message_type_text.push_str(&message_type_string);
  }

}
Cargo.toml
[dependencies]
axum = { version = "0.5.9", features = ["ws"] }
futures = "0.3"
tokio = { version = "1", features = ["full"] }
tower = { version = "0.4", features = ["util"] }
tracing = "0.1"
tracing-subscriber = "0.2"
tower-http = {version = "0.3.4", features = ["full"]}
serde_json = "1.0"
serde = { version = "1.0", features = ["derive"] }
anyhow = "1.0"

React Native Expoでのフロント実装(WebSocket導入部分)

chat.tsx
import { sock } from "../../websocket"
...
useEffect(() => {
  const handler = e => {
    const newMessage = JSON.parse(e.data)
    if (isMounted.current) {
      // チャット画面に遷移してきた際にのみ実行
      if (newMessage["message_type"] === "SetUserId") {
        return
      } else {
        // newMessage["message_type"] === "SendMessage"
        // メッセージを送った際に実行
        // ユーザーが開いているチャットルームに一致する場合のみメッセージを表示する
        const messageDirectChatRoomId = newMessage.chat_room_type === "DirectChatRoomId" ? newMessage.chat_room_id : null;
        const messageGroupChatRoomId = newMessage.chat_room_type === "GroupChatRoomId" ? newMessage.chat_room_id : null;
        if ((directChatRoomId !== null && directChatRoomId === messageDirectChatRoomId) || (groupChatRoomId !== null && groupChatRoomId === messageGroupChatRoomId)) {
          setMessages(previousMessages => GiftedChat.append(previousMessages, newMessage))
        }

      }
    }
  }
  sock.addEventListener("message", handler)
  return () => {
    sock.removeEventListener("message", handler)
  }
}, [userId])
websocket.js
// 当記事ではngrokを利用しています
// URLは環境にあったものをご利用ください
export const sock = new WebSocket("wss://d1ea-61-120-204-212.jp.ngrok.io/websocket");

詰まったところ

tokio::spawnの中でMutexlockが出来ない

main.rs
// NG例
let mut send_task =  tokio::spawn(async move {
  let user_id_set = state.user_id_set.lock().unwrap().clone();
  ...
});

理由: std::sync::MutexGuard型がSendではないためエラーが起こる。
mutexのロックを別のスレッドに送ることはできないため、エラーになっている。
解決策: tokio::spawnの外でMutexlockを行う

main.rs
// OK例
let mut send_task = {
  let user_id_set = state.user_id_set.lock().unwrap().clone();
  tokio::spawn(async move {
  ...
})};

Special thanks

Chatの実装やロギングに関して、@megumish_unsafeさんにアドバイスいただきました。教えていただきありがとうございました!!!

Discussion