Open12

rust + axum(0.6系) 本番サービス開発中の知見メモ

yunayuna

rustのweb framework Axumを使って本番サービスを開発するにあたり、
得た知見をメモしていきます。(まだまだ分かって無いことも多いので、コメント歓迎です)

(設計方針として、モノリシック寄り、Rustの恩恵を受けつつ、スピード感を重視した開発)

https://docs.rs/axum/latest/axum/

参考にさせていただいた記事やコード
https://zenn.dev/hkdord/articles/axum-request-id
https://zenn.dev/tanakh/articles/realworld-zum-yew-shuttle
https://github.com/tanakh/axum-yew-shuttle-realworld-example-app

axum のver 0.7 への migrationはこちら

axum中で利用しているhyperなどのモジュール1.0対応した axum ver0.7へのmigrationを実施した記録はこちら
https://zenn.dev/myuna/scraps/fbb044ccfcaea2

yunayuna

handler内でエラー伝搬を使いたい

シンプルな実装としては、こんな感じで、routerのメソッドに渡す関数helloは、axum::response::IntoResponse traitを実装した型を返す。

参照
https://github.com/tokio-rs/axum/blob/main/axum-core/src/response/into_response.rs#L126

実装

router.rs
let api_routes = Router::new()
        .route("/", get(controller::hello))
controller.rs
pub struct SimpleJson {
    pub data: String,
}

pub async fn hello() -> impl IntoResponse {
    Json(SimpleJson {
        data: "okです。".into(),
    })
}

では、hello内の例外処理はどうする?
素直に用意されてるコード使うなら、こんな感じ

controller.rs
pub struct SimpleJson {
    pub data: String,
}

pub async fn hello() -> Result<impl IntoResponse, impl IntoResponse> {
    let res = rand::random::<bool>();
    if res {
        Err((
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(SimpleJson {
                data: "エラー発生".into(),
            }),
        ))
    } else {
        Ok((
            StatusCode::OK,
            Json(SimpleJson {
                data: "okです。".into(),
            }),
        ))
    }
}

でもこれだと、エラー伝搬の"?" が使えないのでちょっと不便。
そこで、IntoResponseを実装したオリジナルErrorを作成して利用する。

参考:
https://github.com/tokio-rs/axum/blob/main/examples/anyhow-error-response/src/main.rs

実装例

(エラー伝搬の場合は強制的にInternal Server Errorを返す。Statusを自分で実装したい場合は、ApiErrorを自分で作って返す必要あり)

controller.rs
pub struct ApiError {
    status: StatusCode,
    response: Json<serde_json::Value>,
}

//ApiErrorには、Into<anyhow::Error>を実装しているエラー(基本的に全てのエラー)から変換できるようにしておく
impl<E> From<E> for ApiError
where
    E: Into<anyhow::Error>,
{
    fn from(original_error: E) -> Self {
        Self {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            response: axum::Json(serde_json::json!({
                "error": format!("{:#?}", original_error.into())
            })),
        }
    }
}

//ApiErrorは、Responseへの変換を行えるようにIntoResponseを実装しておく
impl IntoResponse for ApiError {
    fn into_response(self) -> axum::response::Response {
        (self.status, self.response).into_response()
    }
}

//handlerメソッドの実装例
pub async fn hello() -> anyhow::Result<impl IntoResponse, ApiError> {
    //Resultを返す何らかの処理。エラーだった場合はApiErrorに伝搬して終了。
    try_something()?;

    let res = rand::random::<bool>();
    if res {
        Err(ApiError {
            status: StatusCode::BAD_REQUEST,
            response: Json(serde_json::json!("エラーです。")),
        }))
    } else {
        Ok((
            StatusCode::OK,
            Json(SimpleJson {
                data: "okです。".into(),
            }),
        ))
    }
}

ちょっとリファクタリング

controller.rs
+ pub type ApiResult<T> = anyhow::Result<T, ApiError>;


- pub async fn hello() -> anyhow::Result<impl IntoResponse, ApiError> {
+ pub async fn hello() -> ApiResult<impl IntoResponse> {
yunayuna

JWTによる認証の実装

参考サイト
https://github.com/tokio-rs/axum/blob/main/examples/jwt/src/main.rs

https://github.com/tanakh/axum-yew-shuttle-realworld-example-app/blob/2083ffe59b5d580115abfd13483b3f742a3fabf5/backend/src/auth.rs

方針

jsonwebtoken crateを利用する。
Claims structを、jsonwebtoken::encode/decode を使ってシリアライズ化し、client<- ->server間でやり取りする。

まずは各種準備。

use jsonwebtoken::{encode,  DecodingKey, EncodingKey, Header};

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Claims {
    pub user_id: String,
    exp: i64,
}

fn generate_jwt(user_id: String, key: &EncodingKey) -> anyhow::Result<String> {
    let exp = (chrono::Utc::now() + chrono::Duration::days(30)).timestamp();

    let claims = Claims { user_id, exp };

    let token = encode(&Header::default(), &claims, key)?;

    Ok(token)
}
fn verify_jwt(token: &str, key: &DecodingKey) -> anyhow::Result<Claims> {
    let header = jsonwebtoken::decode_header(token)?;

    let claims =
        jsonwebtoken::decode::<Claims>(token, key, &jsonwebtoken::Validation::new(header.alg))?
            .claims;
    Ok(claims)
}

上記を使って、jwtの作成と、jwtを使った認証を行うhandlerを実装する


static KEYS: Lazy<Keys> = Lazy::new(|| {
    let secret = env::var("JWT_SECRET").expect("JWT_SECRET must be set");
    Keys::new(secret.as_bytes())
});

struct Keys {
    encoding: EncodingKey,
    decoding: DecodingKey,
}

impl Keys {
    fn new(secret: &[u8]) -> Self {
        Self {
            encoding: EncodingKey::from_secret(secret),
            decoding: DecodingKey::from_secret(secret),
        }
    }
}


//ユーザー登録し、jwt tokenを取得する handler
pub async fn registration(
    Json(input): Json<serde_json::Value>,
) -> ApiResult<impl IntoResponse> {
    let user_id = regist_user(input) //(省略。ユーザーデータをDBなどに登録して永続化し、idを取得する感じ)
    let token = generate_jwt(user_id.to_string(), &KEYS.encoding)?;

    //jwt tokenを返す
    Ok(Json(serde_json::json!({ "token": token })))
}

//認証付きのrouter handler
pub async fn need_authenticated_action(
    Extension(ref claims): Extension<Claims>,
) -> ApiResult<impl IntoResponse> {
  Ok(Json(SimpleJson {
            data: "okです".into(),
        }))
}

認証処理の実装と、Claimsをhanlderに渡す方法

上記の認証が必要なneed_authenticated_actionは、
すでにrouterの設定の中で認証処理を行い、Claimsを取得している。
このhanlderに、Claimsを渡す方法はいくつかあるのだが、
比較的分かりやすく、かつ認証処理の入れ忘れをしづらい方法を採用。

middleware::from_fn(authenticate)をrouter layerとして挟むことで、そこにnestedされたhandlerをcallするときは、必ずauthenticateを呼び出してくれるようになる。

このようにhandlerの前にlayer(middleware)を挟む処理は、型を持たない他言語のweb frameworkでもよく使われる。


let mut router = Router::new();

    // path が "/api " で始まるこれらのAPIは、認証不要なAPI群として設計する
    let api_routes = Router::new()
        .route("/", get(controller::top))
        .route("/exam_list", get(exam::exam_list))
        .route("/exam_one", get(exam::exam_one));

    router = router.merge(Router::new().nest("/api", api_routes));


    // path が "/api/user " で始まるAPIは、必ず認証が必要なAPI群として設計する
    let user_api_routes = Router::new().route("/need_authenticated_action", get(need_authenticated_action));

    router = router.merge(Router::new().nest("/api/user", 
         user_api_routes.route_layer(middleware::from_fn(auth::authenticate))));

authenticateにて、Authorization headerにセットされたjwt token文字列を取得し、
それをdecodeしてClaims structを復元する。
requestのextensionにclaimsをセットしておくことで、ここを通過した後のhandlerでは、
Extension(claims)を扱うことができるようになる。

pub async fn authenticate<B>(mut req: Request<B>, next: Next<B>) -> Result<Response, StatusCode> {
    // extract the authorization header
    let auth_header = req
        .headers()
        .get(axum::http::header::AUTHORIZATION)
        .and_then(|header| header.to_str().ok());

    let auth_header = if let Some(auth_header) = auth_header {
        auth_header
    } else {
        return Err(StatusCode::UNAUTHORIZED);
    };

    if let Ok(claims) = verify_jwt(auth_header, &KEYS.decoding) {
        // insert the current user into a request extension so the handler can
        // extract it
        println!("auth claims: {:?}", claims.clone());
        req.extensions_mut().insert(claims);

        Ok(next.run(req).await)
    } else {
        Err(StatusCode::UNAUTHORIZED)
    }
}

認証付きhandler 他の実装方法

axum公式のsampleのように、
https://github.com/tokio-rs/axum/blob/main/examples/jwt/src/main.rs

layerを使わず、Claimsに axum::extract::FromRequestParts traitを実装することで、
自動的にAuthorizationヘッダーからデータ取得・認証処理を行った上でClaims structに変換することもできる。

この場合、認証の有無をhandler側で付けるしかなく、
router側で一括してつけられない(工夫すればできるか?)なので、
認証の付け忘れが発生するリスクを回避するため、今回は採用しなかった。

FromRequestPartsを使う例

async fn need_authenticated_action(claims: Claims) -> Result<String, AuthError> {
    // Send the protected data to the user
    Ok(format!(
        "Welcome to the protected area :)\nYour data:\n{}",
        claims
    ))
}

#[async_trait]
impl<S> FromRequestParts<S> for Claims
where
    S: Send + Sync,
{
    type Rejection = AuthError;

    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        // Extract the token from the authorization header
        let TypedHeader(Authorization(bearer)) = parts
            .extract::<TypedHeader<Authorization<Bearer>>>()
            .await
            .map_err(|_| AuthError::InvalidToken)?;
        // Decode the user data
        let token_data = decode::<Claims>(bearer.token(), &KEYS.decoding, &Validation::default())
            .map_err(|_| AuthError::InvalidToken)?;

        Ok(token_data.claims)
    }
}

注釈

Extensionに渡せる型は、上記のClaimのように、Clone型をimplementしていれば、他の型でもlayer内で渡せる。

yunayuna

Database accessのためのインスタンスをどのように持ち回るか?

handler内だけで使うのであれば、以下のようにrouterを作るときにExtensionで持たせてあげて、
handlerで使うのが一番分かりやすい。

mongoを使う例

router.rs
let mongo_uri = env::var("MONGO_URL").unwrap();
let mongo_database_name = env::var("MONGO_DB_NAME").unwrap();
let mongo_client = Client::with_uri_str(&mongo_uri).await.unwrap();
let mongo_database = lient.database(&mongo_database_name);

let app = router.layer(
        ServiceBuilder::new()
            .layer(HandleErrorLayer::new(|error: BoxError| async move {
                if error.is::<tower::timeout::error::Elapsed>() {
                    Ok(StatusCode::REQUEST_TIMEOUT)
                } else {
                    Err((
                        StatusCode::INTERNAL_SERVER_ERROR,
                        format!("Unhandled internal error: {}", error),
                    ))
                }
            }))
            .timeout(Duration::from_secs(10))
            .layer(TraceLayer::new_for_http())
            .layer(Extension(mongo_database))
            .into_inner(),
    );
handler.rs
pub async fn handler(
    Extension(ref mongo_database): Extension<mongodb::Database>,
    Json(input): Json<serde_json::Value>,
) -> impl IntoResponse {
    let body_doc = bson::ser::to_document(&input).unwrap();
    let data_collection = mongo_database.collection::<Document>("data");
    data_collection.insert_one(body_doc, None);

しかし、handler外で使うときに不便なことがある。
例えば、権限が必要なAPIへのアクセス時、handlerに処理を渡す前に、authenticattionで認証認可のチェックを行う必要があるとき、
以下のようにmiddlewareのfrom_fnで実装する場合、うまく渡せない。
(渡せそうな感じもあるんですが・・・詳しい方教えてください)

※自己解決しました。↓追記した ※解決 を参照ください。

router.rs
router.route_layer(middleware::from_fn(auth::authenticate))
auth.rs
pub async fn authenticate<B>(mut req: Request<B>, next: Next<B>) -> Result<Response, StatusCode> {
    let mongo_database = ???????????????????????;

    // extract the authorization header
    let auth_header = req
        .headers()
        .get(axum::http::header::AUTHORIZATION)
        .and_then(|header| header.to_str().ok());

    let auth_header = if let Some(auth_header) = auth_header {
        auth_header
    } else {
        return Err(StatusCode::UNAUTHORIZED);
    };

    if let Ok(claims) = verify_jwt(auth_header, &KEYS.decoding) {
        // insert the current user into a request extension so the handler can
        // extract it
        println!("auth claims: {:?}", claims.clone());

        let user_collection = mongo_database.get().unwrap().collection::<Document>("user");
        if let Ok(oid) = claims.user_id.parse::<mongodb::bson::oid::ObjectId>() {
            let filter = doc! {
                "_id": oid,
            };

            let found_doc_option = user_collection
                .find_one(
                    filter, // filter2,
                    None,
                )
                .await;

        };

        req.extensions_mut().insert(claims);

        Ok(next.run(req).await)
    } else {
        Err(StatusCode::UNAUTHORIZED)
    }
}

なので、handler以外の場所で使うために、
結局 Lazyでstaticにもたせる事で持ち回ることに。

mongo.rs
pub static MONGO_DATABASE: OnceCell<mongodb::Database> = OnceCell::new();

pub async fn create_database() {
    // Get a handle to a collection in the database.
    let mongo_uri = env::var("MONGO_URL").unwrap();
    let mongo_database = env::var("MONGO_DB_NAME").unwrap();

    let client = Client::with_uri_str(&mongo_uri).await.unwrap();
    MONGO_DATABASE.set(client.database(&mongo_database));
}
auth.rs
pub async fn authenticate<B>(mut req: Request<B>, next: Next<B>) -> Result<Response, StatusCode> {
     let mongo_database = &crate::infrastructure::mongo::MONGO_DATABASE;

Extensionは扱いやすいので、これはこれで使う。


let mongo_database: &mongodb::Database = mongo::MONGO_DATABASE.get().unwrap();

let app = router.layer(
        ServiceBuilder::new()
            .layer(HandleErrorLayer::new(|error: BoxError| async move {
                if error.is::<tower::timeout::error::Elapsed>() {
                    Ok(StatusCode::REQUEST_TIMEOUT)
                } else {
                    Err((
                        StatusCode::INTERNAL_SERVER_ERROR,
                        format!("Unhandled internal error: {}", error),
                    ))
                }
            }))
            .timeout(Duration::from_secs(10))
            .layer(TraceLayer::new_for_http())
            .layer(Extension(mongo_database))
            .layer(CompressionLayer::new())
            .into_inner(),
    );
handler.rs

pub async fn handler(
    Extension(mongo_database): Extension<&mongodb::Database>,
) -> impl IntoResponse {
    let exam_collection = mongo_database.collection::<Document>("exam");
    let cursor = exam_collection.find(None, None).await.unwrap();

※解決

上記の問題自己解決しました。 axumのlayerは、callした順番を逆にたどって呼ばれていきます。
したがって、後に追加したlayerでセットしたExtensionは、その前に追加したlayer内で利用することができます。

例: 認証が必要なapiの実装

認証が必要なapiは、authentication系のlayerを呼び出すのが一般的です。
このlayer内で、databaseにアクセスして、tokenをもとに認証情報が正しいか否か、チェックするとします。
ここでconnectionインスタンスが必要になりますが、
authentication layerの後で、connectionインスタンスをExtensionでセットするlayerを仕込んでおけば、
authentication layer内でconnectionが呼び出せるわけです。

auth.rs
pub async fn authenticate<B>(mut req: Request<B>, next: Next<B>) -> Result<Response, StatusCode> {
    let database = req.extensions().get::<Database>().ok_or_else(|| StatusCode::UNAUTHORIZED)?;
    /* databaseを使ってDBの処理 */

余談 Extensionに入れたインスタンスはどのように使われる?

Extensionに、参照ではなく実態をセットしたインスタンスは、どのように使われるのかが気になりました。

具体的にいうと、Extensionでセットしたdatabaseが並列処理の中で複数のapi callが行われたとき、
waitが発生すると困るなと思い調べてみたところ、
毎回cloneされて使われていました。

Cloneを自分で実装してログを挟み、どのタイミングで呼ばれるかをチェック

#[derive(Debug)]
pub struct Database {
    pub database: mongodb::Database,
    pub auth_org: AuthorizedOrganization,
}
impl Clone for Database {
    fn clone(&self) -> Self {
        println!("call database clone!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!");
        Database{database: self.database.clone(), auth_org: self.auth_org.clone()}
    }
}
yunayuna

api 処理中のpanic handling

panic発生時、適切なresponseを返すよう、handlingしたい。

結論としては、
tower_http::catch_panic::CatchPanicLayerを使う。
ResponseForPanicには、デフォルトの CatchPanicLayer::new() も用意されているが、
自分で実装する場合の例はこんな感じ。

use tower_http::catch_panic::{CatchPanicLayer, ResponseForPanic};

    // Compose the routes
    let app = router.layer(
        ServiceBuilder::new()
            .layer(CatchPanicLayer::custom(SamplePanicHandler{}))
            // .layer(CatchPanicLayer::new())

SamplePanicHandlerの実装例

use http_body::Full;
use tower_http::catch_panic::{CatchPanicLayer, ResponseForPanic};
use axum::response::{IntoResponse, Response},

#[derive(Debug, Default, Clone, Copy)]
struct SamplePanicHandler {}
impl ResponseForPanic for PanicHandler {
    type ResponseBody = Full<bytes::Bytes>;

    fn response_for_panic(
        &mut self,
        err: Box<dyn std::any::Any + Send + 'static>,
    ) -> hyper::Response<Self::ResponseBody> {
        let mut error_msg = String::new();
        if let Some(s) = err.downcast_ref::<String>() {
            tracing::error!("Service panicked: {}", s);
        } else if let Some(s) = err.downcast_ref::<&str>() {
            tracing::error!("Service panicked: {}", s);
        } else {
            tracing::error!(
                 "Service panicked but `CatchPanic` was unable to downcast the panic info"
             );
        };

        let mut res = Response::new(Full::from("Service panicked."));
        res.headers_mut().insert(
            header::CONTENT_TYPE,
            header::HeaderValue::from_static("text/plain; charset=utf-8"),
        );
        // res.(StatusCode::INTERNAL_SERVER_ERROR);
        *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;

        res
    }
}

もっとシンプルにも書ける。
(CatchPanicLayer::customの引数の型Tは、where T: ResponseForPanic だけど、関数も指定できるようになってる)

use tower_http::catch_panic::{CatchPanicLayer, ResponseForPanic};

    // Compose the routes
    let app = router.layer(
        ServiceBuilder::new()
            .layer(CatchPanicLayer::custom(handle_panic))
use axum::response::{IntoResponse, Response},

fn handle_panic(err: Box<dyn std::any::Any + Send + 'static>) -> axum::response::Response<Body> {
    let details = if let Some(s) = err.downcast_ref::<String>() {
        s.clone()
    } else if let Some(s) = err.downcast_ref::<&str>() {
        s.to_string()
    } else {
        "Unknown panic message".to_string()
    };

    let body = serde_json::json!({
        "error": {
            "kind": "panic",
            "details": details,
        }
    });
    let body = serde_json::to_string(&body).unwrap();

    Response::builder()
        .status(StatusCode::INTERNAL_SERVER_ERROR)
        .header(header::CONTENT_TYPE, "application/json")
        .body(Body::from(body))
        .unwrap()
}
yunayuna

reqwestの responseからaxumのresponseへ変換したい

axum apiをproxyとして経由し、別URLにアクセスする場合など。

impl IntoResponse を満たす返り値にするため、
(http::status::StatusCode, http::header::HeaderMap, bytes::Bytes)を返す
※axum_coreの pub trait IntoResponse 参照
https://github.com/tokio-rs/axum/blob/main/axum-core/src/response/into_response.rs

pub async fn sample_api(
    Extension(mongo_database): Extension<&mongodb::Database>,
    Query(query): Query<serde_json::Value>,
) -> ApiResult<impl IntoResponse> {

let res = reqwest::Client::builder().build().unwrap()
        .get(
            "https://xxxxx.xxxxx.com/"
        )
        .query(&query)
        .timeout(Duration::from_secs(45))
        .send()
        .await?;

    let status_code = res.status();
    Ok((
        status_code,
        res.headers().clone(),
        res.bytes().await.unwrap(),
    ))
}
yunayuna

axumでSSE(Server Send Event)実装

axumで、SSEのサーバー側実装のやり方。

client側の接続が切れた時の対応を含めた実装方法はこちら。
DropをimplementしたGuard structを作成して、streamを返す処理を書く。

こうすることで、client側の接続が切れると自動的にdrop functionが呼ばれるので
clientの接続状態を管理したりできる。

参照:
https://github.com/tokio-rs/axum/discussions/1060

event_source.rs
use axum::response::sse::{Event, Sse, KeepAlive};

async fn sse_handler() -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
    struct Guard {
        // whatever state you need here
    }

    impl Drop for Guard {
        fn drop(&mut self) {
            tracing::info!("stream closed");
        }
    }

    let stream = async_stream::stream! {
        let _guard = Guard {};
        let mut interval = tokio::time::interval(Duration::from_secs(1));
        loop {
            interval.tick().await;
            yield Ok(Event::default().data("hi"));
        }
        // `_guard` is dropped
    };

    Sse::new(stream)
}

axumでget apiとして設定

route.rs
.route("/sse_handler", get(connection::sse_handler))

rustでclient側実装する場合

rustでclient側実装例はこちら。
こちらのcrateを利用
https://github.com/jpopesculian/reqwest-eventsource

こちらの例ではErr時にes.close()しているが、
closeしなければ、自動的に再接続しに行ってくれます。

let mut es =reqwest_eventsource:: EventSource::get("http://localhost:8080/sse_handler");
        while let Some(event) = es.next().await {
            match event {
                Ok(reqwest_eventsource::Event::Open) => println!("Connection Open!"),
                Ok(reqwest_eventsource::Event::Message(message)) => {
                    println!("Message: {:#?}", message)
                },
                Err(err) => {
                    println!("Error: {}", err);
                    es.close();
                }
            }
        }

リアルタイム接続通信はwebsocket が有名ですが、
こちらも、とても簡単な印象です。

加えて、従来のhttp接続なので、firewallなどの環境による影響が少ないのがメリットです。
接続数が大きくなった場合のサーバー側の制限などには要注意。

注意点

chromeのsseでテストしたところ、
axumの上記実装で、接続はできるのですが、データの受信がなぜかリアルタイムにできませんでした。
rustのsse clientとの通信は問題なく動く事を確認したので、もしかするとchromeとの通信に必要なヘッダーがsse responseに不足している等、
chrome側の要件を満たしていない可能性があります。
(深掘りする時間が無いので、取り急ぎこの問題は保留に)


自己解決しました。
もともと、
.layer(CompressionLayer::new())
で、すべての処理に対して圧縮をかけていたことで、ブラウザはsseのeventSourceが認識できなくなっていた。

そこで、圧縮対象からsseの処理(headerのcontent-typeがtext/event-streamの場合)を除外することで、問題発生しなくなりました。

use tower_http::compression::{
    Compression, CompressionLayer
    predicate::{Predicate, NotForContentType, DefaultPredicate},
};

let app = Router::new()
        .route("/sse", get(sse_handler))
        .layer(CompressionLayer::new().compress_when(DefaultPredicate::new()
                  .and(NotForContentType::new("text/event-stream"))))

詳細はこちらでやり取りしています
https://github.com/tokio-rs/axum/discussions/2034

さらに、sse通信を行う際、中継点でbuffering(メッセージをある程度溜めて一括処理する仕組み)されて、
リアルタイムでメッセージが流れてこない場合があります。

例:proxyにnginxを使っている場合

自動的にbufferingされてしまうので、
nginxがbufferingされないよう、ヘッダーにx-accel-buffering=noを追加する必要が有ります。

axum側では、以下のように対応できます。

※axum のsseモジュールは、headerを編集する機能が無いので、
responseを取得して自分でheaderを追加しています。

さらに、axumのcontrollerは、impl IntoResponseである必要があり、responseをそのまま返せないので、
IntoResponseを実装したwrapperを使っています。(もっといいやり方あったらコメントください)

models
struct ResponseWrapper{ 
    response:axum::response::Response,
}
impl IntoResponse for ResponseWrapper {
    fn into_response(self) -> axum::response::Response {
        self.response
    }
}
router.rs
.route("/machine_connect_status", get(controller::sse_connect))
controller.rs
pub async fn sse_connect(
    Query(query): Query<serde_json::Value>,
// ) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
) -> ApiResult<impl IntoResponse> {
    // make stream
    // let stream = ...(省略)

    let mut res = Sse::new(stream).keep_alive(
        axum::response::sse::KeepAlive::default()
            .interval(Duration::from_secs(1)),
    ).into_response();
    let mut headers = res.headers_mut();
    headers.insert(hyper::http::header::CONNECTION, hyper::http::HeaderValue::from_static("keep-alive"));
     // HeaderName::from_staticは、lowercaseでないとエラーになるっぽいので注意
    headers.insert(hyper::http::header::HeaderName::from_static("x-accel-buffering"), hyper::http::HeaderValue::from_static("no")); //for nginx 
    Ok(ResponseWrapper{response: res})
)
yunayuna

axum serverでwebsocket実装

websocketのserver側実装例です。

仕様として、clientからのメッセージが1分間受信されないか(pingなどを常時打つことで死活確認を行っています)、
closeメッセージを受信したら、websocketの受信・送信streamを終了します。

公式URLには、timeout処理が記載されていなかったので、tokio::time::timeoutを使って実装してみました。

公式
https://docs.rs/axum/latest/axum/extract/ws/

router.rs

.route("/websocket_connection", get(controller::ws_machine_connect_status))

controller.rs
#[tracing::instrument]
pub async fn ws_machine_connect_status(
    ws: WebSocketUpgrade, 
    Query(query): Query<serde_json::Value>,) -> ApiResult<impl IntoResponse> {
    let machine_id = query.get("machine_id").unwrap().as_str().unwrap().to_string();
    Ok(ResponseWrapper{response: ws.on_upgrade(|socket| ws_machine_connect_status_socket(socket, machine_id))})
}


async fn ws_machine_connect_status_socket(socket: WebSocket, machine_id: String) {
    let (mut sender, mut receiver) = socket.split();

    let sender_handle = tokio::spawn(write(sender, machine_id.clone()));
    let reciever_handle = tokio::spawn(read(receiver, sender_handle.abort_handle(), machine_id.clone()));
}
async fn read(mut receiver: SplitStream<WebSocket>, sender_handle: tokio::task::AbortHandle, machine_id: String) {
    //60秒pingが無ければ、接続が切れたと判定して処理を終了する
    let timeout = Duration::from_secs(60);

    while let result = tokio::time::timeout(timeout, receiver.next()).await {
        match result {
            // A new message has been received
            Ok(Some(Ok(message))) => match message {
                axum::extract::ws::Message::Close(frame) => {
                    break;
                }
                _ => {
                    println!("処理を継続");
                }
            },
            // An error occurred while trying to read a message
            Ok(Some(Err(_e))) => {
                println!("_e: {:?}", _e);
                break;
            }
            Ok(None) => {
                println!("receive nothing.");
            }
            // Timeout occurred
            Err(_e ) => {
                println!("timeout! {:?}", _e);
                break;
            }
        }
    }
    //receive処理を抜けるタイミングで、sender側の処理も終了する
    sender_handle.abort();

}

async fn write(mut sender: SplitSink<WebSocket, Message>, machine_id: String) {
    let mut interval = tokio::time::interval(Duration::from_secs(1));
    loop {
        interval.tick().await;

        let res = sender.send(Message::Text(format!("hello, machine_id:{:?}", machine_id.as_str()))).await;
    }
}

yunayuna

ファイルアップロードに関する実装例

axumの、multipart featrueを利用します

Cargo.toml
axum = {version="0.6.18", features = ["multipart"]}

ファイル上限サイズの設定

router.rs
use axum::extract::DefaultBodyLimit;

// ........

.route("/upload", post(controller::upload));
.layer(DefaultBodyLimit::max(2 * 1000 * 1000 * 1000))//max length: 2GB

// ........
controller.rs
use axum::extract::Multipart;

pub async fn upload(
    mut multipart: Multipart
) -> ApiResult<impl IntoResponse> {

    while let Some(mut field) = multipart.next_field().await.unwrap() {
        let name = field.name().unwrap().to_string();
        let data = field.bytes().await.unwrap();

        println!("Length of `{}` is {} bytes", name, data.len());
    }
    Ok(())
}
yunayuna

ファイルダウンロードに関する実装例(WIP)

関連資料
https://github.com/tokio-rs/axum/discussions/1638
https://docs.rs/axum/latest/axum/body/struct.StreamBody.html
https://docs.rs/tokio-util/latest/tokio_util/io/struct.ReaderStream.html
https://stackoverflow.com/questions/73325707/how-to-return-contents-as-a-file-download-in-axum

controller.rs
#[tracing::instrument]
#[debug_handler]
pub async fn get_uploaded_file(
    Query(query): Query<serde_json::Value>,
) -> ApiResult<impl IntoResponse> {
    let id: String = query.get("id").unwrap().to_string();

    if let Some((file_name, temp_file))  = store::pop_uploaded_file(id.as_str()) {
        let tmp_file_stream = tokio::fs::File::open(temp_file).await.unwrap();
        let stream = tokio_util::io::ReaderStream::new(tmp_file_stream);
        let body = hyper::Body::wrap_stream(stream);
        let mut response = Response::builder().body(boxed(body)).unwrap();

        let mut headers = response.headers_mut();
        headers.insert(
            hyper::header::CONTENT_TYPE,
            hyper::header::HeaderValue::from_static("application/octet-stream"),
        );
        headers.insert(
            hyper::header::CONTENT_DISPOSITION,
            hyper::header::HeaderValue::from_str(format!("attachment; filename=\"{}\"", file_name).as_str()).unwrap(),
        );

        Ok(ResponseWrapper{response})
    } else {
        Ok(ResponseWrapper{response: Json(serde_json::json!({"message": "error"})).into_response()})
    }
}

yunayuna

handlerのエラー事例①

router.rs
#↓ここでコンパイルエラー
                .route(
                    "/password_reset",
                    post(controller::user::password_reset),
                ),



306 |                     post(controller::user::password_reset),
    |                     ---- ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ the trait `Handler<_, _, _>` is not implemented for fn item `fn(Extension<db::Database>, axum::Json<serde_json::Value>) -> impl futures_util::Future<Output = Result<impl IntoResponse, ApiError>> {password_reset}`
    |                     |
    |                     required by a bound introduced by this call
    |
    = help: the following other types implement trait `Handler<T, S, B>`:
              <MethodRouter<S, B> as Handler<(), S, B>>
              <axum::handler::Layered<L, H, T, S, B, B2> as Handler<T, S, B2>>
note: required by a bound in `post`
   --> /home/masatoyuna/.cargo/registry/src/index.crates.io-6f17d22bba15001f/axum-0.6.20/src/routing/method_routing.rs:407:1
    |
handler.rs
#[tracing::instrument]
pub async fn password_reset(
    Extension(ref database): Extension<crate::infrastructure::db::Database>,
    Json(post_data): Json<serde_json::Value>,
) -> ApiResult<impl IntoResponse> {

        // ランダムなバイト列を生成(value)
        let mut rng = rand::thread_rng();
        let value_bytes: [u8; 32] = rng.gen();
        // トークン生成
        let hash = gwutil::crypt::hash_password(value_bytes)?;

        // 省略 /////////////////////////////////////

        Ok(Json(serde_json::json!({ "result": true })))
    
}

原因

rand crateの
rand::thread_rng()部分が原因。
thread_rngは、システムからsedを得て、遅延初期化されたスレッドローカルの乱数生成器を取得し、ThreadRngインスタンスを返す。

このThreadRngが、Rcを含むため(スレッドセーフでない)、非同期関数の中で使われることによりコンパイルエラーになっている(?)っぽい。エラーメッセージから判断が難しいのでchatGPTと相談しながら突き止めた。

pub struct ThreadRng {
    // Rc is explicitly !Send and !Sync
    rng: Rc<UnsafeCell<ReseedingRng<Core, OsRng>>>,
}

対策

ランダム生成時、rand::thread_rng()の代わりに rand::StdRng::from_entropy(); を使う。

handler.rs
use rand::{Rng, rngs::StdRng, SeedableRng};

#[tracing::instrument]
pub async fn password_reset(
    Extension(ref database): Extension<crate::infrastructure::db::Database>,
    Json(post_data): Json<serde_json::Value>,
) -> ApiResult<impl IntoResponse> {

        // ランダムなバイト列を生成(value)
        let mut rng = StdRng::from_entropy();
        let value_bytes: [u8; 32] = rng.gen();
        // トークン生成
        let hash = gwutil::crypt::hash_password(value_bytes)?;

        // 省略 /////////////////////////////////////

        Ok(Json(serde_json::json!({ "result": true })))
    
}
yunayuna

ServiceBuilderのlayerは、通常の逆になる

routerにlayerでチェインしていくと、後ろから読み込まれていく

router
 .layer(CatchPanicLayer::custom(handle_panic)) //layer1
 .layer(middleware::from_fn(
                controller::layer::tracing::request_id,
            )) //layer2

ServiceBuilderでlayerをチェインしていくと、頭から読み込まれていく

router.layer(
        ServiceBuilder::new()
            .layer(CatchPanicLayer::custom(handle_panic)) //layer1
            .layer(middleware::from_fn(
                controller::layer::tracing::request_id,
            )) //layer2
 ...
)

なので、例えばextensionに値をセットする場合の順序に気をつける必要がある。