🦀

Rust の Web フレームワーク Axum のミドルウェアと戯れる

2022/12/04に公開約10,200字2件のコメント

はじめに

これは Rust Advent Calendar 2022 4 日目の記事です。API サーバーを作る時にリクエストごとに一意な ID を割り振りたいケースがあると思います。Rust の Web フレームワーク Axum でそういったリクエスト ID を実現する方法を調べました。またアクセスログを出力するのもギリギリ間に合ったのでその方法も調べました。

Axum でリクエスト ID を付与する&アクセスログを出力する

ここからは

  1. Axum で各 HTTP リクエストにユニークな ID を付与する
  2. 各 HTTP リクエストの際にアクセスログを出力する
    ことを目標に進めていきます。

Axum とは

GitHub リポジトリはこちら。Rust の非同期ランタイムで有名な tokio で作られていて、Tokio, Tower, Hyper を裏で使っています。example が充実していて、やりたいことと似ているものを探してコピペすれば最初の取っ掛かりで苦労することはあまりないです。

素朴に実装するには

Axum の README に書いている example を眺めると、ルーターにパスとメソッドとそのメソッドの時に呼ばれるハンドラーを登録していることがわかります。

#[tokio::main]
async fn main() {
    tracing_subscriber::fmt::init();

    let app = Router::new()
        // `GET /` だと root を呼ぶ
        .route("/", get(root))
        // `POST /users` だと create_user を呼ぶ
        .route("/users", post(create_user));

    let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
    tracing::debug!("listening on {}", addr);
    axum::Server::bind(&addr)
        .serve(app.into_make_service())
        .await
        .unwrap();
}

ということはそれぞれのハンドラーにリクエスト ID を割り振る処理を追加すればいいのでは?と最初に思いつきます。これでもうまくいきますができれば他の方法を使いたいでしょう。なぜなら全てのハンドラーにリクエスト ID を割り振る処理があることを確認するのは大変(コードレビューのタイミングでチェックするとかになりそう)ですし、ハンドラーの責務はリクエストに対してビジネスロジックに基づいてレスポンスを返すことで、リクエスト ID を割り振るのは関心の範囲外だからです。

ミドルウェア

ではどうすればいいかというと他の言語やフレームワークで言うところのミドルウェアを使えばいいです。Go 言語のウェブアプリケーションフレームワーク Echo には ミドルウェアはありますし、Python のウェブアプリケーションフレームワーク FastAPI にもミドルウェアがあります。ミドルウェアの処理が終わってからハンドラーの処理が行われるので、ミドルウェアの中でどのハンドラーでも共通の処理(今回の場合はリクエスト ID の付与ですがほかにもロギングなどが考えられます)を定義することができます。

仕組みとしては単純で例えば Echo でカスタムのミドルウェアを書く例がわかりやすく、関数を引数に取る関数を書けばいいです。

関数を引数に取る関数と簡単に書きましたが Rust の場合これをうまく抽象化する必要がありそこで出てくるのが tower::Service トレイトです。このトレイトがどのようにデザインされているかや実際にミドルウェアをスクラッチで実装するにはどうすればいいかがこの Tower のガイドに詳しく書かれています。ただ今回の本筋ではないので次に進みます。

tower_http::request_id モジュール

Axum でミドルウェアをどのように実装するかを調べるために axum middleware とかで検索しているとドキュメントのページが見つかりました。そしてよく使われるミドルウェアのところに RequestIdLayer といういかにもそれっぽい名前のミドルウェアが見つかりました。この example axum の middleware example を見比べてみると、なんとなく以下のようなコードでうまくいきそうな気がしてきます。

let app = Router::new()
    .route("/", get(handler))
    .layer(
        ServiceBuilder::new()
            .layer(SetRequestIdLayer::new(
                x_request_id.clone(),
                MyMakeRequestId::default(),
        ))
    .layer(PropagateRequestIdLayer::new(x_request_id)))
    ;

ただこれはあくまでいけそうな気がするだけなのでもう少し詳しく各レイヤーの中身を見ていきます。

まず SetRequestIdLayer をみてみると new するときの第一引数は HeaderName で第二引数にトレイト境界構文が使われています。またそもそも x_request_id を使えばヘッダー名を x-request-id に決めうちするのと引き換えに第一引数が不要になることがわかります。ただ今回は何をやっているかわかりやすくするためにあえてヘッダー名を明示的に指定して new の方を使っていきます。

次に第二引数に出てきたトレイト境界から MakeRequestId をみてみます。ざっくり HTTP リクエストから RequestId のオプション型を返していることがわかります。RequestId を作るには引数が HeaderValue である必要があり、HeaderValue は str から作れるみたいですね。

そういえば今までリクエスト ID を割り振るとしか言っていませんでした。example のコピペだとつまらないのでリクエスト ID には uuid を使うことにします。uuid を生成して文字列に変換し、そこから HeaderValue を作って、それを基に Option<RequestId> を作って返せばいいでしょう。

またリクエストの x-request-id ヘッダーをレスポンスに伝播させるために PropagateRequestIdLayer が必要です。さらにこのようにレイヤーが複数ある場合はlayer を複数回呼ぶのではなく一度に ServiceBuilder を使って適用することが推奨されているので注意しておきましょう。

アクセスログを出力する

同じように各ハンドラー共通でアクセスログを出力できるようにします。とりあえず GET / みたいな感じで HTTP メソッドとその時のパスを出力できるようにします。せっかくなので別の方法でミドルウェアを書きましょう。ここに書いているように async/await 構文で処理を記述したい・クレートとして公開するつもりがなくて axum で使えるだけでいい場合には axum::middleware::from_fn を使うことができます。これはドキュメントを読むとイメージがつかみやすくて、async 関数からミドルウェアを作り出すものです。以下のように書けば HTTP リクエストからメソッドとパスを取り出して出力し、その後リクエストの結果を返す処理を実行できます。

async fn access_log_on_request<B>(
    req: Request<B>,
    next: Next<B>,
) -> Result<Response, StatusCode> {
    info!("{} {}", req.method(), req.uri());
    Ok(next.run(req).await)
}

最終的なコード

これらを踏まえて以下の手順でコードを書いていきます。

$ cargo --version
cargo 1.65.0 (4bc8f24d3 2022-10-20)
$ cargo new axum-sample-server
Created binary (application) `axum-sample-server` package

まず Cargo.toml を編集します。

Cargo.toml
[package]
name = "axum-sample-server"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
axum = "0.6.1"
http = "0.2.8"
http-body = "0.4.5"
hyper = "0.14.23"
serde = { version = "1.0.147", features = ["derive"] }
tokio = { version = "1.21.2", features = ["full"] }
tower = "0.4.13"
# for tower_http::request_id
tower-http = { version = "0.3.4", features = ["request-id"] }
tower-layer = "0.3.2"
tracing = "0.1.37"
tracing-subscriber = { version = "0.3.16", features = ["fmt","std","json"] }
uuid = "1.2.2"

次に main.rs を編集します。

main.rs
use axum::{
    http::StatusCode, middleware::Next, response::IntoResponse, response::Response,
    routing::get, Json, Router
};
use http::{header::HeaderName, header::HeaderValue, Request};
use tower::ServiceBuilder;
use tower_http::request_id::{
    MakeRequestId, PropagateRequestIdLayer, RequestId, SetRequestIdLayer,
};
use tracing::info;
use uuid::Uuid;
use serde::Serialize;
use std::net::SocketAddr;
use tracing_subscriber::fmt;

#[tokio::main]
async fn main() {
    run().await;
}

async fn run() {
    let event_format = fmt::format::json();
    tracing_subscriber::fmt().event_format(event_format).init();

    let x_request_id = HeaderName::from_static("x-request-id");
    let app = Router::new()
        .route("/", get(health_check))
        .layer(
            ServiceBuilder::new()
                .layer(SetRequestIdLayer::new(
                    x_request_id.clone(),
                    MyRequestId::new(),
                ))
                .layer(PropagateRequestIdLayer::new(x_request_id))
                .layer(axum::middleware::from_fn(
                    access_log_on_request,
                )),
        );

    let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
    tracing::debug!("listening on {}", addr);
    axum::Server::bind(&addr)
        .serve(app.into_make_service())
        .await
        .unwrap();
}

#[derive(Serialize)]
struct JsonMessage {
    message: String,
}

async fn health_check() -> impl IntoResponse {
    let m = JsonMessage {
        message: String::from("Healthy"),
    };
    Json(m)
}

#[derive(Clone)]
pub struct MyRequestId {}

impl MyRequestId {
    pub fn new() -> Self {
        MyRequestId {}
    }
}

impl MakeRequestId for MyRequestId {
    fn make_request_id<B>(&mut self, _request: &Request<B>) -> Option<RequestId> {
        // generate uuid at every request
        let uuid = Uuid::new_v4().to_string();
        let request_id = HeaderValue::from_str(&uuid).unwrap();

        Some(RequestId::new(request_id))
    }
}

async fn access_log_on_request<B>(req: Request<B>, next: Next<B>) -> Result<Response, StatusCode> {
    info!("{} {}", req.method(), req.uri());
    Ok(next.run(req).await)
}

最後にコンパイルしてサーバーを起動します。

$ cargo build
Finished dev [unoptimized + debuginfo] target(s) in 0.04s
$ target/debug/axum-sample-server

これで以下のように curl でリクエストを送ってみるとヘッダーにリクエスト ID が付与されていて、サーバーを実行している方のターミナルにはアクセスログが出力されます。

# curl でリクエストを送る
$ curl -i http://localhost:3000
HTTP/1.1 200 OK
content-type: application/json
content-length: 21
x-request-id: 5c5055a2-a75f-462a-9221-7bda7017a850
date: Sun, 04 Dec 2022 13:39:11 GMT

{"message":"Healthy"}

# サーバー起動している方のターミナル
{"timestamp":"2022-12-04T13:39:11.419510Z","level":"INFO","fields":{"message":"GET /"},"target":"axum_sample_server"}

わかっていないこと

上のコードだとルーターを初期化する時に ServiceBuilder で作成したミドルウェアを登録しています。本当はこれらの処理を分けたかったです。イメージ的には以下のような感じ:

router.rs
let middlewares = build_middlewares()
let app = Router::new()
        .route("/", get(health_check))
        .layer(middlewares);
middleware.rs
fn build_middlewares() -> ServiceBuilder {
    let x_request_id = HeaderName::from_static("x-request-id");

    ServiceBuilder::new()
        .layer(SetRequestIdLayer::new(
            x_request_id.clone(),
            MyRequestId::new(),
        ))
        .layer(PropagateRequestIdLayer::new(x_request_id))
}

こうするときに build_middlewares の返り値をどう定義すればいいかわかりませんでした。Router().layer の定義を読む限りだと Layer を返せば良さそうで `tower_layer::Layer の定義を見ると Layer と Service が紐づいてそう、でとりあえず以下のようにトレイト境界を書いてみたけどうまく動かず。。。とりあえず ServiceBuilder の第二引数だけ切り出して終わっています。ここの書き方がわかればスッキリするのだろうけど力不足でした。

fn build_middlewares<L, B, NewReqBody>() -> impl Layer<Route<B>>
where
    L: Layer<Route<B>> + Clone + Send + 'static,
    L::Service: Service<Request<NewReqBody>> + Clone + Send + 'static,
    <L::Service as Service<Request<NewReqBody>>>::Response: IntoResponse + 'static,
    <L::Service as Service<Request<NewReqBody>>>::Error: Into<Infallible> + 'static,
    <L::Service as Service<Request<NewReqBody>>>::Future: Send + 'static,
    NewReqBody: HttpBody + 'static,
    B: HttpBody + Send + 'static,

おわりに

今まで Echo のミドルウェアは使ったことがあって同じノリで axum も探してみました。これくらいならコピペ+α でなんとかなってよかったです。

Discussion

頑張って build_middleware() を定義するなら,

use axum::body::Body;
use axum::routing::Route;
use core::convert::Infallible;
use tower::Layer;
use tower::Service;

fn build_middleware() -> impl Layer<
    Route<Body>,
    Service = impl Service<
        Request<Body>,
        Response = impl IntoResponse + 'static,
        Error = impl Into<Infallible> + 'static,
        Future = impl Send + 'static,
    > + Clone
                  + Send
                  + 'static,
> + Clone
       + Send
       + 'static {
    let x_request_id = HeaderName::from_static("x-request-id");
    ServiceBuilder::new()
        .layer(SetRequestIdLayer::new(
            x_request_id.clone(),
            MyRequestId::new(),
        ))
        .layer(PropagateRequestIdLayer::new(x_request_id))
        .layer(axum::middleware::from_fn(access_log_on_request))
}

のようになると思います.Body に関しては,

  • Router::new() で型を指定しない場合,default で Router<hyper::body::Body>となる
  • axum の中で B = hyper::body::Body しか受け付けないものがあった気がする

ので,簡単のために B = NewReqBody = hyper::body::Body としてしまっています.

ただ私なら,ServiceBuilder を追加するコマンド add_service_builder(router: Router<Body>) -> Router<Body> を定義するか,自分で Service/Layer を impl するかしますね.Service/Layer の単純な wrapper を書くだけなら,そこまで実装コストはかからないです.(面倒ではあります…….)

Naughie さん
コメントありがとうございます.教えていただいた方法でコンパイルできました.
もともとルーターを初期化する処理とミドルウェアを追加する処理をわけたいだけなので,おっしゃるように add_service_builder 的なコマンドを定義したほうがよさそうですね.今回やりたいことだとそこまで実装コストがかからないとはいえ Service/Layer を wrap するのは面倒なので

ログインするとコメントできます