rust + axum(0.6系) 本番サービス開発中の知見メモ
rustのweb framework Axumを使って本番サービスを開発するにあたり、
得た知見をメモしていきます。(まだまだ分かって無いことも多いので、コメント歓迎です)
(設計方針として、モノリシック寄り、Rustの恩恵を受けつつ、スピード感を重視した開発)
参考にさせていただいた記事やコード
axum のver 0.7 への migrationはこちら
axum中で利用しているhyperなどのモジュール1.0対応した axum ver0.7へのmigrationを実施した記録はこちら
handler内でエラー伝搬を使いたい
シンプルな実装としては、こんな感じで、routerのメソッドに渡す関数helloは、axum::response::IntoResponse
traitを実装した型を返す。
参照
実装
let api_routes = Router::new()
.route("/", get(controller::hello))
pub struct SimpleJson {
pub data: String,
}
pub async fn hello() -> impl IntoResponse {
Json(SimpleJson {
data: "okです。".into(),
})
}
では、hello内の例外処理はどうする?
素直に用意されてるコード使うなら、こんな感じ
pub struct SimpleJson {
pub data: String,
}
pub async fn hello() -> Result<impl IntoResponse, impl IntoResponse> {
let res = rand::random::<bool>();
if res {
Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(SimpleJson {
data: "エラー発生".into(),
}),
))
} else {
Ok((
StatusCode::OK,
Json(SimpleJson {
data: "okです。".into(),
}),
))
}
}
でもこれだと、エラー伝搬の"?" が使えないのでちょっと不便。
そこで、IntoResponseを実装したオリジナルErrorを作成して利用する。
参考:
実装例
(エラー伝搬の場合は強制的にInternal Server Errorを返す。Statusを自分で実装したい場合は、ApiErrorを自分で作って返す必要あり)
pub struct ApiError {
status: StatusCode,
response: Json<serde_json::Value>,
}
//ApiErrorには、Into<anyhow::Error>を実装しているエラー(基本的に全てのエラー)から変換できるようにしておく
impl<E> From<E> for ApiError
where
E: Into<anyhow::Error>,
{
fn from(original_error: E) -> Self {
Self {
status: StatusCode::INTERNAL_SERVER_ERROR,
response: axum::Json(serde_json::json!({
"error": format!("{:#?}", original_error.into())
})),
}
}
}
//ApiErrorは、Responseへの変換を行えるようにIntoResponseを実装しておく
impl IntoResponse for ApiError {
fn into_response(self) -> axum::response::Response {
(self.status, self.response).into_response()
}
}
//handlerメソッドの実装例
pub async fn hello() -> anyhow::Result<impl IntoResponse, ApiError> {
//Resultを返す何らかの処理。エラーだった場合はApiErrorに伝搬して終了。
try_something()?;
let res = rand::random::<bool>();
if res {
Err(ApiError {
status: StatusCode::BAD_REQUEST,
response: Json(serde_json::json!("エラーです。")),
}))
} else {
Ok((
StatusCode::OK,
Json(SimpleJson {
data: "okです。".into(),
}),
))
}
}
ちょっとリファクタリング
+ pub type ApiResult<T> = anyhow::Result<T, ApiError>;
- pub async fn hello() -> anyhow::Result<impl IntoResponse, ApiError> {
+ pub async fn hello() -> ApiResult<impl IntoResponse> {
JWTによる認証の実装
参考サイト
方針
jsonwebtoken crateを利用する。
Claims structを、jsonwebtoken::encode/decode を使ってシリアライズ化し、client<- ->server間でやり取りする。
まずは各種準備。
use jsonwebtoken::{encode, DecodingKey, EncodingKey, Header};
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Claims {
pub user_id: String,
exp: i64,
}
fn generate_jwt(user_id: String, key: &EncodingKey) -> anyhow::Result<String> {
let exp = (chrono::Utc::now() + chrono::Duration::days(30)).timestamp();
let claims = Claims { user_id, exp };
let token = encode(&Header::default(), &claims, key)?;
Ok(token)
}
fn verify_jwt(token: &str, key: &DecodingKey) -> anyhow::Result<Claims> {
let header = jsonwebtoken::decode_header(token)?;
let claims =
jsonwebtoken::decode::<Claims>(token, key, &jsonwebtoken::Validation::new(header.alg))?
.claims;
Ok(claims)
}
上記を使って、jwtの作成と、jwtを使った認証を行うhandlerを実装する
static KEYS: Lazy<Keys> = Lazy::new(|| {
let secret = env::var("JWT_SECRET").expect("JWT_SECRET must be set");
Keys::new(secret.as_bytes())
});
struct Keys {
encoding: EncodingKey,
decoding: DecodingKey,
}
impl Keys {
fn new(secret: &[u8]) -> Self {
Self {
encoding: EncodingKey::from_secret(secret),
decoding: DecodingKey::from_secret(secret),
}
}
}
//ユーザー登録し、jwt tokenを取得する handler
pub async fn registration(
Json(input): Json<serde_json::Value>,
) -> ApiResult<impl IntoResponse> {
let user_id = regist_user(input) //(省略。ユーザーデータをDBなどに登録して永続化し、idを取得する感じ)
let token = generate_jwt(user_id.to_string(), &KEYS.encoding)?;
//jwt tokenを返す
Ok(Json(serde_json::json!({ "token": token })))
}
//認証付きのrouter handler
pub async fn need_authenticated_action(
Extension(ref claims): Extension<Claims>,
) -> ApiResult<impl IntoResponse> {
Ok(Json(SimpleJson {
data: "okです".into(),
}))
}
認証処理の実装と、Claimsをhanlderに渡す方法
上記の認証が必要なneed_authenticated_actionは、
すでにrouterの設定の中で認証処理を行い、Claimsを取得している。
このhanlderに、Claimsを渡す方法はいくつかあるのだが、
比較的分かりやすく、かつ認証処理の入れ忘れをしづらい方法を採用。
middleware::from_fn(authenticate)をrouter layerとして挟むことで、そこにnestedされたhandlerをcallするときは、必ずauthenticateを呼び出してくれるようになる。
このようにhandlerの前にlayer(middleware)を挟む処理は、型を持たない他言語のweb frameworkでもよく使われる。
let mut router = Router::new();
// path が "/api " で始まるこれらのAPIは、認証不要なAPI群として設計する
let api_routes = Router::new()
.route("/", get(controller::top))
.route("/exam_list", get(exam::exam_list))
.route("/exam_one", get(exam::exam_one));
router = router.merge(Router::new().nest("/api", api_routes));
// path が "/api/user " で始まるAPIは、必ず認証が必要なAPI群として設計する
let user_api_routes = Router::new().route("/need_authenticated_action", get(need_authenticated_action));
router = router.merge(Router::new().nest("/api/user",
user_api_routes.route_layer(middleware::from_fn(auth::authenticate))));
authenticateにて、Authorization headerにセットされたjwt token文字列を取得し、
それをdecodeしてClaims structを復元する。
requestのextensionにclaimsをセットしておくことで、ここを通過した後のhandlerでは、
Extension(claims)を扱うことができるようになる。
pub async fn authenticate<B>(mut req: Request<B>, next: Next<B>) -> Result<Response, StatusCode> {
// extract the authorization header
let auth_header = req
.headers()
.get(axum::http::header::AUTHORIZATION)
.and_then(|header| header.to_str().ok());
let auth_header = if let Some(auth_header) = auth_header {
auth_header
} else {
return Err(StatusCode::UNAUTHORIZED);
};
if let Ok(claims) = verify_jwt(auth_header, &KEYS.decoding) {
// insert the current user into a request extension so the handler can
// extract it
println!("auth claims: {:?}", claims.clone());
req.extensions_mut().insert(claims);
Ok(next.run(req).await)
} else {
Err(StatusCode::UNAUTHORIZED)
}
}
認証付きhandler 他の実装方法
axum公式のsampleのように、
layerを使わず、Claimsに axum::extract::FromRequestParts traitを実装することで、
自動的にAuthorizationヘッダーからデータ取得・認証処理を行った上でClaims structに変換することもできる。
この場合、認証の有無をhandler側で付けるしかなく、
router側で一括してつけられない(工夫すればできるか?)なので、
認証の付け忘れが発生するリスクを回避するため、今回は採用しなかった。
FromRequestPartsを使う例
async fn need_authenticated_action(claims: Claims) -> Result<String, AuthError> {
// Send the protected data to the user
Ok(format!(
"Welcome to the protected area :)\nYour data:\n{}",
claims
))
}
#[async_trait]
impl<S> FromRequestParts<S> for Claims
where
S: Send + Sync,
{
type Rejection = AuthError;
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
// Extract the token from the authorization header
let TypedHeader(Authorization(bearer)) = parts
.extract::<TypedHeader<Authorization<Bearer>>>()
.await
.map_err(|_| AuthError::InvalidToken)?;
// Decode the user data
let token_data = decode::<Claims>(bearer.token(), &KEYS.decoding, &Validation::default())
.map_err(|_| AuthError::InvalidToken)?;
Ok(token_data.claims)
}
}
注釈
Extensionに渡せる型は、上記のClaimのように、Clone型をimplementしていれば、他の型でもlayer内で渡せる。
Database accessのためのインスタンスをどのように持ち回るか?
handler内だけで使うのであれば、以下のようにrouterを作るときにExtensionで持たせてあげて、
handlerで使うのが一番分かりやすい。
mongoを使う例
let mongo_uri = env::var("MONGO_URL").unwrap();
let mongo_database_name = env::var("MONGO_DB_NAME").unwrap();
let mongo_client = Client::with_uri_str(&mongo_uri).await.unwrap();
let mongo_database = lient.database(&mongo_database_name);
let app = router.layer(
ServiceBuilder::new()
.layer(HandleErrorLayer::new(|error: BoxError| async move {
if error.is::<tower::timeout::error::Elapsed>() {
Ok(StatusCode::REQUEST_TIMEOUT)
} else {
Err((
StatusCode::INTERNAL_SERVER_ERROR,
format!("Unhandled internal error: {}", error),
))
}
}))
.timeout(Duration::from_secs(10))
.layer(TraceLayer::new_for_http())
.layer(Extension(mongo_database))
.into_inner(),
);
pub async fn handler(
Extension(ref mongo_database): Extension<mongodb::Database>,
Json(input): Json<serde_json::Value>,
) -> impl IntoResponse {
let body_doc = bson::ser::to_document(&input).unwrap();
let data_collection = mongo_database.collection::<Document>("data");
data_collection.insert_one(body_doc, None);
しかし、handler外で使うときに不便なことがある。
例えば、権限が必要なAPIへのアクセス時、handlerに処理を渡す前に、authenticattionで認証認可のチェックを行う必要があるとき、
以下のようにmiddlewareのfrom_fnで実装する場合、うまく渡せない。
(渡せそうな感じもあるんですが・・・詳しい方教えてください)
※自己解決しました。↓追記した ※解決 を参照ください。
router.route_layer(middleware::from_fn(auth::authenticate))
pub async fn authenticate<B>(mut req: Request<B>, next: Next<B>) -> Result<Response, StatusCode> {
let mongo_database = ???????????????????????;
// extract the authorization header
let auth_header = req
.headers()
.get(axum::http::header::AUTHORIZATION)
.and_then(|header| header.to_str().ok());
let auth_header = if let Some(auth_header) = auth_header {
auth_header
} else {
return Err(StatusCode::UNAUTHORIZED);
};
if let Ok(claims) = verify_jwt(auth_header, &KEYS.decoding) {
// insert the current user into a request extension so the handler can
// extract it
println!("auth claims: {:?}", claims.clone());
let user_collection = mongo_database.get().unwrap().collection::<Document>("user");
if let Ok(oid) = claims.user_id.parse::<mongodb::bson::oid::ObjectId>() {
let filter = doc! {
"_id": oid,
};
let found_doc_option = user_collection
.find_one(
filter, // filter2,
None,
)
.await;
};
req.extensions_mut().insert(claims);
Ok(next.run(req).await)
} else {
Err(StatusCode::UNAUTHORIZED)
}
}
なので、handler以外の場所で使うために、
結局 Lazyでstaticにもたせる事で持ち回ることに。
pub static MONGO_DATABASE: OnceCell<mongodb::Database> = OnceCell::new();
pub async fn create_database() {
// Get a handle to a collection in the database.
let mongo_uri = env::var("MONGO_URL").unwrap();
let mongo_database = env::var("MONGO_DB_NAME").unwrap();
let client = Client::with_uri_str(&mongo_uri).await.unwrap();
MONGO_DATABASE.set(client.database(&mongo_database));
}
pub async fn authenticate<B>(mut req: Request<B>, next: Next<B>) -> Result<Response, StatusCode> {
let mongo_database = &crate::infrastructure::mongo::MONGO_DATABASE;
Extensionは扱いやすいので、これはこれで使う。
let mongo_database: &mongodb::Database = mongo::MONGO_DATABASE.get().unwrap();
let app = router.layer(
ServiceBuilder::new()
.layer(HandleErrorLayer::new(|error: BoxError| async move {
if error.is::<tower::timeout::error::Elapsed>() {
Ok(StatusCode::REQUEST_TIMEOUT)
} else {
Err((
StatusCode::INTERNAL_SERVER_ERROR,
format!("Unhandled internal error: {}", error),
))
}
}))
.timeout(Duration::from_secs(10))
.layer(TraceLayer::new_for_http())
.layer(Extension(mongo_database))
.layer(CompressionLayer::new())
.into_inner(),
);
pub async fn handler(
Extension(mongo_database): Extension<&mongodb::Database>,
) -> impl IntoResponse {
let exam_collection = mongo_database.collection::<Document>("exam");
let cursor = exam_collection.find(None, None).await.unwrap();
※解決
上記の問題自己解決しました。 axumのlayerは、callした順番を逆にたどって呼ばれていきます。
したがって、後に追加したlayerでセットしたExtensionは、その前に追加したlayer内で利用することができます。
例: 認証が必要なapiの実装
認証が必要なapiは、authentication系のlayerを呼び出すのが一般的です。
このlayer内で、databaseにアクセスして、tokenをもとに認証情報が正しいか否か、チェックするとします。
ここでconnectionインスタンスが必要になりますが、
authentication layerの後で、connectionインスタンスをExtensionでセットするlayerを仕込んでおけば、
authentication layer内でconnectionが呼び出せるわけです。
pub async fn authenticate<B>(mut req: Request<B>, next: Next<B>) -> Result<Response, StatusCode> {
let database = req.extensions().get::<Database>().ok_or_else(|| StatusCode::UNAUTHORIZED)?;
/* databaseを使ってDBの処理 */
余談 Extensionに入れたインスタンスはどのように使われる?
Extensionに、参照ではなく実態をセットしたインスタンスは、どのように使われるのかが気になりました。
具体的にいうと、Extensionでセットしたdatabaseが並列処理の中で複数のapi callが行われたとき、
waitが発生すると困るなと思い調べてみたところ、
毎回cloneされて使われていました。
Cloneを自分で実装してログを挟み、どのタイミングで呼ばれるかをチェック
#[derive(Debug)]
pub struct Database {
pub database: mongodb::Database,
pub auth_org: AuthorizedOrganization,
}
impl Clone for Database {
fn clone(&self) -> Self {
println!("call database clone!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!");
Database{database: self.database.clone(), auth_org: self.auth_org.clone()}
}
}
api 処理中のpanic handling
panic発生時、適切なresponseを返すよう、handlingしたい。
結論としては、
tower_http::catch_panic::CatchPanicLayerを使う。
ResponseForPanicには、デフォルトの CatchPanicLayer::new() も用意されているが、
自分で実装する場合の例はこんな感じ。
use tower_http::catch_panic::{CatchPanicLayer, ResponseForPanic};
// Compose the routes
let app = router.layer(
ServiceBuilder::new()
.layer(CatchPanicLayer::custom(SamplePanicHandler{}))
// .layer(CatchPanicLayer::new())
SamplePanicHandlerの実装例
use http_body::Full;
use tower_http::catch_panic::{CatchPanicLayer, ResponseForPanic};
use axum::response::{IntoResponse, Response},
#[derive(Debug, Default, Clone, Copy)]
struct SamplePanicHandler {}
impl ResponseForPanic for PanicHandler {
type ResponseBody = Full<bytes::Bytes>;
fn response_for_panic(
&mut self,
err: Box<dyn std::any::Any + Send + 'static>,
) -> hyper::Response<Self::ResponseBody> {
let mut error_msg = String::new();
if let Some(s) = err.downcast_ref::<String>() {
tracing::error!("Service panicked: {}", s);
} else if let Some(s) = err.downcast_ref::<&str>() {
tracing::error!("Service panicked: {}", s);
} else {
tracing::error!(
"Service panicked but `CatchPanic` was unable to downcast the panic info"
);
};
let mut res = Response::new(Full::from("Service panicked."));
res.headers_mut().insert(
header::CONTENT_TYPE,
header::HeaderValue::from_static("text/plain; charset=utf-8"),
);
// res.(StatusCode::INTERNAL_SERVER_ERROR);
*res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
res
}
}
もっとシンプルにも書ける。
(CatchPanicLayer::customの引数の型Tは、where T: ResponseForPanic だけど、関数も指定できるようになってる)
use tower_http::catch_panic::{CatchPanicLayer, ResponseForPanic};
// Compose the routes
let app = router.layer(
ServiceBuilder::new()
.layer(CatchPanicLayer::custom(handle_panic))
use axum::response::{IntoResponse, Response},
fn handle_panic(err: Box<dyn std::any::Any + Send + 'static>) -> axum::response::Response<Body> {
let details = if let Some(s) = err.downcast_ref::<String>() {
s.clone()
} else if let Some(s) = err.downcast_ref::<&str>() {
s.to_string()
} else {
"Unknown panic message".to_string()
};
let body = serde_json::json!({
"error": {
"kind": "panic",
"details": details,
}
});
let body = serde_json::to_string(&body).unwrap();
Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.header(header::CONTENT_TYPE, "application/json")
.body(Body::from(body))
.unwrap()
}
reqwestの responseからaxumのresponseへ変換したい
axum apiをproxyとして経由し、別URLにアクセスする場合など。
impl IntoResponse を満たす返り値にするため、
(http::status::StatusCode, http::header::HeaderMap, bytes::Bytes)
を返す
※axum_coreの pub trait IntoResponse 参照
pub async fn sample_api(
Extension(mongo_database): Extension<&mongodb::Database>,
Query(query): Query<serde_json::Value>,
) -> ApiResult<impl IntoResponse> {
let res = reqwest::Client::builder().build().unwrap()
.get(
"https://xxxxx.xxxxx.com/"
)
.query(&query)
.timeout(Duration::from_secs(45))
.send()
.await?;
let status_code = res.status();
Ok((
status_code,
res.headers().clone(),
res.bytes().await.unwrap(),
))
}
axumでSSE(Server Send Event)実装
axumで、SSEのサーバー側実装のやり方。
client側の接続が切れた時の対応を含めた実装方法はこちら。
DropをimplementしたGuard structを作成して、streamを返す処理を書く。
こうすることで、client側の接続が切れると自動的にdrop functionが呼ばれるので
clientの接続状態を管理したりできる。
参照:
use axum::response::sse::{Event, Sse, KeepAlive};
async fn sse_handler() -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
struct Guard {
// whatever state you need here
}
impl Drop for Guard {
fn drop(&mut self) {
tracing::info!("stream closed");
}
}
let stream = async_stream::stream! {
let _guard = Guard {};
let mut interval = tokio::time::interval(Duration::from_secs(1));
loop {
interval.tick().await;
yield Ok(Event::default().data("hi"));
}
// `_guard` is dropped
};
Sse::new(stream)
}
axumでget apiとして設定
.route("/sse_handler", get(connection::sse_handler))
rustでclient側実装する場合
rustでclient側実装例はこちら。
こちらのcrateを利用
こちらの例ではErr時にes.close()しているが、
closeしなければ、自動的に再接続しに行ってくれます。
let mut es =reqwest_eventsource:: EventSource::get("http://localhost:8080/sse_handler");
while let Some(event) = es.next().await {
match event {
Ok(reqwest_eventsource::Event::Open) => println!("Connection Open!"),
Ok(reqwest_eventsource::Event::Message(message)) => {
println!("Message: {:#?}", message)
},
Err(err) => {
println!("Error: {}", err);
es.close();
}
}
}
リアルタイム接続通信はwebsocket が有名ですが、
こちらも、とても簡単な印象です。
加えて、従来のhttp接続なので、firewallなどの環境による影響が少ないのがメリットです。
接続数が大きくなった場合のサーバー側の制限などには要注意。
注意点
chromeのsseでテストしたところ、
axumの上記実装で、接続はできるのですが、データの受信がなぜかリアルタイムにできませんでした。
rustのsse clientとの通信は問題なく動く事を確認したので、もしかするとchromeとの通信に必要なヘッダーがsse responseに不足している等、
chrome側の要件を満たしていない可能性があります。
(深掘りする時間が無いので、取り急ぎこの問題は保留に)
→
自己解決しました。
もともと、
.layer(CompressionLayer::new())
で、すべての処理に対して圧縮をかけていたことで、ブラウザはsseのeventSourceが認識できなくなっていた。
そこで、圧縮対象からsseの処理(headerのcontent-typeがtext/event-streamの場合)を除外することで、問題発生しなくなりました。
use tower_http::compression::{
Compression, CompressionLayer
predicate::{Predicate, NotForContentType, DefaultPredicate},
};
let app = Router::new()
.route("/sse", get(sse_handler))
.layer(CompressionLayer::new().compress_when(DefaultPredicate::new()
.and(NotForContentType::new("text/event-stream"))))
詳細はこちらでやり取りしています
さらに、sse通信を行う際、中継点でbuffering(メッセージをある程度溜めて一括処理する仕組み)されて、
リアルタイムでメッセージが流れてこない場合があります。
例:proxyにnginxを使っている場合
自動的にbufferingされてしまうので、
nginxがbufferingされないよう、ヘッダーにx-accel-buffering=noを追加する必要が有ります。
axum側では、以下のように対応できます。
※axum のsseモジュールは、headerを編集する機能が無いので、
responseを取得して自分でheaderを追加しています。
さらに、axumのcontrollerは、impl IntoResponseである必要があり、responseをそのまま返せないので、
IntoResponseを実装したwrapperを使っています。(もっといいやり方あったらコメントください)
struct ResponseWrapper{
response:axum::response::Response,
}
impl IntoResponse for ResponseWrapper {
fn into_response(self) -> axum::response::Response {
self.response
}
}
.route("/machine_connect_status", get(controller::sse_connect))
pub async fn sse_connect(
Query(query): Query<serde_json::Value>,
// ) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
) -> ApiResult<impl IntoResponse> {
// make stream
// let stream = ...(省略)
let mut res = Sse::new(stream).keep_alive(
axum::response::sse::KeepAlive::default()
.interval(Duration::from_secs(1)),
).into_response();
let mut headers = res.headers_mut();
headers.insert(hyper::http::header::CONNECTION, hyper::http::HeaderValue::from_static("keep-alive"));
// HeaderName::from_staticは、lowercaseでないとエラーになるっぽいので注意
headers.insert(hyper::http::header::HeaderName::from_static("x-accel-buffering"), hyper::http::HeaderValue::from_static("no")); //for nginx
Ok(ResponseWrapper{response: res})
)
axum serverでwebsocket実装
websocketのserver側実装例です。
仕様として、clientからのメッセージが1分間受信されないか(pingなどを常時打つことで死活確認を行っています)、
closeメッセージを受信したら、websocketの受信・送信streamを終了します。
公式URLには、timeout処理が記載されていなかったので、tokio::time::timeout
を使って実装してみました。
公式
.route("/websocket_connection", get(controller::ws_machine_connect_status))
#[tracing::instrument]
pub async fn ws_machine_connect_status(
ws: WebSocketUpgrade,
Query(query): Query<serde_json::Value>,) -> ApiResult<impl IntoResponse> {
let machine_id = query.get("machine_id").unwrap().as_str().unwrap().to_string();
Ok(ResponseWrapper{response: ws.on_upgrade(|socket| ws_machine_connect_status_socket(socket, machine_id))})
}
async fn ws_machine_connect_status_socket(socket: WebSocket, machine_id: String) {
let (mut sender, mut receiver) = socket.split();
let sender_handle = tokio::spawn(write(sender, machine_id.clone()));
let reciever_handle = tokio::spawn(read(receiver, sender_handle.abort_handle(), machine_id.clone()));
}
async fn read(mut receiver: SplitStream<WebSocket>, sender_handle: tokio::task::AbortHandle, machine_id: String) {
//60秒pingが無ければ、接続が切れたと判定して処理を終了する
let timeout = Duration::from_secs(60);
while let result = tokio::time::timeout(timeout, receiver.next()).await {
match result {
// A new message has been received
Ok(Some(Ok(message))) => match message {
axum::extract::ws::Message::Close(frame) => {
break;
}
_ => {
println!("処理を継続");
}
},
// An error occurred while trying to read a message
Ok(Some(Err(_e))) => {
println!("_e: {:?}", _e);
break;
}
Ok(None) => {
println!("receive nothing.");
}
// Timeout occurred
Err(_e ) => {
println!("timeout! {:?}", _e);
break;
}
}
}
//receive処理を抜けるタイミングで、sender側の処理も終了する
sender_handle.abort();
}
async fn write(mut sender: SplitSink<WebSocket, Message>, machine_id: String) {
let mut interval = tokio::time::interval(Duration::from_secs(1));
loop {
interval.tick().await;
let res = sender.send(Message::Text(format!("hello, machine_id:{:?}", machine_id.as_str()))).await;
}
}
ファイルアップロードに関する実装例
axumの、multipart featrueを利用します
axum = {version="0.6.18", features = ["multipart"]}
ファイル上限サイズの設定
use axum::extract::DefaultBodyLimit;
// ........
.route("/upload", post(controller::upload));
.layer(DefaultBodyLimit::max(2 * 1000 * 1000 * 1000))//max length: 2GB
// ........
use axum::extract::Multipart;
pub async fn upload(
mut multipart: Multipart
) -> ApiResult<impl IntoResponse> {
while let Some(mut field) = multipart.next_field().await.unwrap() {
let name = field.name().unwrap().to_string();
let data = field.bytes().await.unwrap();
println!("Length of `{}` is {} bytes", name, data.len());
}
Ok(())
}
ファイルダウンロードに関する実装例(WIP)
関連資料
#[tracing::instrument]
#[debug_handler]
pub async fn get_uploaded_file(
Query(query): Query<serde_json::Value>,
) -> ApiResult<impl IntoResponse> {
let id: String = query.get("id").unwrap().to_string();
if let Some((file_name, temp_file)) = store::pop_uploaded_file(id.as_str()) {
let tmp_file_stream = tokio::fs::File::open(temp_file).await.unwrap();
let stream = tokio_util::io::ReaderStream::new(tmp_file_stream);
let body = hyper::Body::wrap_stream(stream);
let mut response = Response::builder().body(boxed(body)).unwrap();
let mut headers = response.headers_mut();
headers.insert(
hyper::header::CONTENT_TYPE,
hyper::header::HeaderValue::from_static("application/octet-stream"),
);
headers.insert(
hyper::header::CONTENT_DISPOSITION,
hyper::header::HeaderValue::from_str(format!("attachment; filename=\"{}\"", file_name).as_str()).unwrap(),
);
Ok(ResponseWrapper{response})
} else {
Ok(ResponseWrapper{response: Json(serde_json::json!({"message": "error"})).into_response()})
}
}
handlerのエラー事例①
#↓ここでコンパイルエラー
.route(
"/password_reset",
post(controller::user::password_reset),
),
306 | post(controller::user::password_reset),
| ---- ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ the trait `Handler<_, _, _>` is not implemented for fn item `fn(Extension<db::Database>, axum::Json<serde_json::Value>) -> impl futures_util::Future<Output = Result<impl IntoResponse, ApiError>> {password_reset}`
| |
| required by a bound introduced by this call
|
= help: the following other types implement trait `Handler<T, S, B>`:
<MethodRouter<S, B> as Handler<(), S, B>>
<axum::handler::Layered<L, H, T, S, B, B2> as Handler<T, S, B2>>
note: required by a bound in `post`
--> /home/masatoyuna/.cargo/registry/src/index.crates.io-6f17d22bba15001f/axum-0.6.20/src/routing/method_routing.rs:407:1
|
#[tracing::instrument]
pub async fn password_reset(
Extension(ref database): Extension<crate::infrastructure::db::Database>,
Json(post_data): Json<serde_json::Value>,
) -> ApiResult<impl IntoResponse> {
// ランダムなバイト列を生成(value)
let mut rng = rand::thread_rng();
let value_bytes: [u8; 32] = rng.gen();
// トークン生成
let hash = gwutil::crypt::hash_password(value_bytes)?;
// 省略 /////////////////////////////////////
Ok(Json(serde_json::json!({ "result": true })))
}
原因
rand crateの
rand::thread_rng()部分が原因。
thread_rngは、システムからsedを得て、遅延初期化されたスレッドローカルの乱数生成器を取得し、ThreadRngインスタンスを返す。
このThreadRngが、Rcを含むため(スレッドセーフでない)、非同期関数の中で使われることによりコンパイルエラーになっている(?)っぽい。エラーメッセージから判断が難しいのでchatGPTと相談しながら突き止めた。
pub struct ThreadRng {
// Rc is explicitly !Send and !Sync
rng: Rc<UnsafeCell<ReseedingRng<Core, OsRng>>>,
}
対策
ランダム生成時、rand::thread_rng()の代わりに rand::StdRng::from_entropy(); を使う。
use rand::{Rng, rngs::StdRng, SeedableRng};
#[tracing::instrument]
pub async fn password_reset(
Extension(ref database): Extension<crate::infrastructure::db::Database>,
Json(post_data): Json<serde_json::Value>,
) -> ApiResult<impl IntoResponse> {
// ランダムなバイト列を生成(value)
let mut rng = StdRng::from_entropy();
let value_bytes: [u8; 32] = rng.gen();
// トークン生成
let hash = gwutil::crypt::hash_password(value_bytes)?;
// 省略 /////////////////////////////////////
Ok(Json(serde_json::json!({ "result": true })))
}
ServiceBuilderのlayerは、通常の逆になる
routerにlayerでチェインしていくと、後ろから読み込まれていく
router
.layer(CatchPanicLayer::custom(handle_panic)) //layer1
.layer(middleware::from_fn(
controller::layer::tracing::request_id,
)) //layer2
ServiceBuilderでlayerをチェインしていくと、頭から読み込まれていく
router.layer(
ServiceBuilder::new()
.layer(CatchPanicLayer::custom(handle_panic)) //layer1
.layer(middleware::from_fn(
controller::layer::tracing::request_id,
)) //layer2
...
)
なので、例えばextensionに値をセットする場合の順序に気をつける必要がある。