🏢

AxumでのExtensionの使い方 - ミドルウェアとデータを共有する -

2024/05/27に公開

動作環境

本記事は以下の環境で動作確認を行なっています。
Rust, Axum共に最新の機能を使っているわけではないので、これよりもある程度古いバージョンでも動作すると思われます。

  • Rust 1.78.0
  • Axum 0.7.5

Extension

axumには Extension という型があります。

ドキュメントに多くは書かれていないのですが、それには以下のようにあります。

Extractor and response for extensions.
https://docs.rs/axum/0.7.5/axum/struct.Extension.html

extensionsに対するextractorとresponseらしいです。
ここで言うextensionsとは、 http crate のExtensionsのことです。
ドキュメントには以下のようにあります。

A type map of protocol extensions.
Extensions can be used by Request and Response to store extra data derived from the underlying protocol.
https://docs.rs/http/1.1.0/http/struct.Extensions.html

要するに「型情報をキーとしたMap(内部的には HashMap)」「リクエストとレスポンスに付加的な情報を保持できる」というものです。

axumの Extension は、ルーター/ミドルウェア/ハンドラーの間でデータを共有するために使います。
ミドルウェアとハンドラーの間でデータを共有したくなるケースとして最も多いものの一つに、認証機能(ログイン機能)があります。
「ハンドラー群に認証をかけたいが、ハンドラー側ではユーザ情報が必要だったり不要だったりする」という場合、 Extension を使うことで簡潔に実装することができます。
今回はBearer認証を Extension を使って実装することで、その使い方を解説します。

実装

コード全体は本記事最下部にあります。

ユーザプール

今回は簡易的なユーザプールとして HashMap を使います。
keyはトークン文字列で、valueはユーザ名を持つ User です。
サンプルのユーザプールを作る関数も定義しておきます。

pub type Token = String;
pub type UserMap = HashMap<Token, User>;

#[derive(Debug, Clone, PartialEq, Eq)]
pub struct User {
    pub username: String,
}

pub fn build_user_map() -> UserMap {
    let mut user_map = HashMap::new();
    user_map.insert(
        "aaa".to_string(),
        User {
            username: "Andy".to_string(),
        },
    );
    user_map.insert(
        "bbb".to_string(),
        User {
            username: "Bella".to_string(),
        },
    );
    user_map.insert(
        "ccc".to_string(),
        User {
            username: "Callie".to_string(),
        },
    );
    user_map.insert(
        "ddd".to_string(),
        User {
            username: "Daren".to_string(),
        },
    );
    user_map
}

UserFromRequestParts を実装する

UserFromRequestParts を実装することで、ハンドラーの引数で User を受け取ることができるようになります。

ここでは認証/認可は行われません。

処理としては、後述するミドルウェアで仕込んだ User を取り出しています。
もし、リクエスト中に User が無い場合、ミドルウェアが有効になっていないなどのバグのため、 expect を使っています。

#[async_trait]
impl<S> FromRequestParts<S> for User
where
    S: Send + Sync,
{
    type Rejection = StatusCode;

    async fn from_request_parts(parts: &mut Parts, _: &S) -> Result<Self, Self::Rejection> {
        let user = parts
            .extensions
            .get::<Self>()
            .expect("User not found. Did you add auth_middleware?");
        Ok(user.clone())
    }
}

ミドルウェアを実装する

このミドルウェアによって、認証が行われます。
処理としては以下の流れです。

  1. リクエストヘッダーからBearerトークンを取り出す
  2. ユーザプールから対応するユーザを探す
  3. ユーザが見つかればリクエストの ExtensionUser を追加する

3で追加した User が前述の FromRequestParts の実装で取り出されることになります。

pub async fn auth_middleware(
    State(user_map): State<Arc<UserMap>>,
    mut request: Request,
    next: Next,
) -> axum::response::Result<Response> {
    let bearer = request
        .extract_parts::<TypedHeader<Authorization<Bearer>>>()
        .await
        .map_err(|_| StatusCode::BAD_REQUEST)?;
    let token = bearer.token();

    let user = user_map.get(token).ok_or(StatusCode::UNAUTHORIZED)?;
    request.extensions_mut().insert(user.clone());

    Ok(next.run(request).await)
}

ミドルウェアを使う

axum::middleware::from_fn_with_state を使ってミドルウェアを組み込みます。
それ以外はaxumの一般的な使用方法です。

3つのハンドラーがあり、それぞれ以下の想定です。

  • public: 認証が不要
  • private: 認証が必要、ユーザ情報不要
  • your_name: 認証が必要、ユーザ情報必要

private ハンドラーの引数には User がありませんが、ミドルウェアによって認証がかけられています。

#[tokio::main]
async fn main() {
    let user_map = build_user_map();
    let app = build_app(Arc::new(user_map));
    let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
    axum::serve(listener, app).await.unwrap();
}

#[rustfmt::skip]
fn build_app(user_map: Arc<UserMap>) -> Router {
    let public_router = Router::new()
        .route("/public", get(public));

    let private_router = Router::new()
        .route("/private", get(private))
        .route("/your-name", get(your_name))
        .route_layer(from_fn_with_state(user_map.clone(), auth_middleware));

    Router::new()
        .nest("/", public_router)
        .nest("/", private_router)
        .with_state(user_map)
}

async fn public() -> &'static str {
    "This is public."
}

async fn private() -> &'static str {
    "This is private."
}

async fn your_name(user: User) -> String {
    format!("Your name is {}.", user.username)
}

動作確認

/public は認証せずともアクセスが可能です。

$ curl http://localhost:3000/public
This is public.

/private は認証が必要です。

$ curl -I http://localhost:3000/private
HTTP/1.1 400 Bad Request
content-length: 0
date: Mon, 27 May 2024 09:51:24 GMT
$ curl -H 'Authorization: Bearer aaa' http://localhost:3000/private
This is private.

/your-name は認証が必要です。

$ curl -I http://localhost:3000/your-name
HTTP/1.1 400 Bad Request
content-length: 0
date: Mon, 27 May 2024 09:53:13 GMT
$ curl -H 'Authorization: Bearer bbb' http://localhost:3000/your-name
Your name is Bella.

補足

Extensionの可視性

Extension は特定のリクエスト・レスポンスに固有のため、他のリクエスト・レスポンスには共有されません。
そのため、ミドルウェアで追加した User が別のリクエストから見えることはありません。

Just locally in the request. Different requests don't share the same extensions.
https://github.com/tokio-rs/axum/discussions/2312#discussioncomment-7558458

Extensionの実装

Extension は型情報をキーとする HashMap になっており、興味深い実装です。

ミドルウェアの適用方法

今回はミドルウェアを組み込む際に from_fn_with_state を使いましたが、他にも from_fnfrom_extractor などの関数があります。
また、ミドルウェアの Router::layer は順序が重要なため、一度 axum::middleware のドキュメントを読むことをお勧めします。

https://docs.rs/axum/latest/axum/middleware/index.html

まとめ

今回はBearer認証を実装することで Extension の使い方を解説しました。
Extension を使った手法は実際に、 axum-login crate でも使われています。

Extension の利用ケースとしては、認証/認可の他にも「リクエストIDの生成・利用」「ロギングを行うミドルウェアとの情報共有」など、ミドルウェアとの連携が必要なものが考えられます。
汎用的な仕組みのため、axumを使いこなす上で重要な機能だと言えるでしょう。

コード全体

依存ライブラリ

axum = { version = "0.7.5" }
axum-extra = { version = "0.9.3", features = ["typed-header"] }
http = "1.1.0"
tokio = { version = "1", features = ["full"] }

auth.rs

use std::{collections::HashMap, sync::Arc};

use axum::{
    async_trait,
    extract::{FromRequestParts, Request, State},
    middleware::Next,
    response::Response,
    RequestExt as _,
};
use axum_extra::{
    headers::{authorization::Bearer, Authorization},
    TypedHeader,
};
use http::{request::Parts, StatusCode};

pub type Token = String;
pub type UserMap = HashMap<Token, User>;

#[derive(Debug, Clone, PartialEq, Eq)]
pub struct User {
    pub username: String,
}

pub fn build_user_map() -> UserMap {
    let mut user_map = HashMap::new();
    user_map.insert(
        "aaa".to_string(),
        User {
            username: "Andy".to_string(),
        },
    );
    user_map.insert(
        "bbb".to_string(),
        User {
            username: "Bella".to_string(),
        },
    );
    user_map.insert(
        "ccc".to_string(),
        User {
            username: "Callie".to_string(),
        },
    );
    user_map.insert(
        "ddd".to_string(),
        User {
            username: "Daren".to_string(),
        },
    );
    user_map
}

#[async_trait]
impl<S> FromRequestParts<S> for User
where
    S: Send + Sync,
{
    type Rejection = StatusCode;

    async fn from_request_parts(parts: &mut Parts, _: &S) -> Result<Self, Self::Rejection> {
        let user = parts
            .extensions
            .get::<Self>()
            .expect("User not found. Did you add auth_middleware?");
        Ok(user.clone())
    }
}

pub async fn auth_middleware(
    State(user_map): State<Arc<UserMap>>,
    mut request: Request,
    next: Next,
) -> axum::response::Result<Response> {
    let bearer = request
        .extract_parts::<TypedHeader<Authorization<Bearer>>>()
        .await
        .map_err(|_| StatusCode::BAD_REQUEST)?;
    let token = bearer.token();

    let user = user_map.get(token).ok_or(StatusCode::UNAUTHORIZED)?;
    request.extensions_mut().insert(user.clone());

    Ok(next.run(request).await)
}

main.rs

mod auth;

use std::sync::Arc;

use axum::{middleware::from_fn_with_state, routing::get, Router};

use auth::{auth_middleware, build_user_map, User, UserMap};

#[tokio::main]
async fn main() {
    let user_map = build_user_map();
    let app = build_app(Arc::new(user_map));
    let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
    axum::serve(listener, app).await.unwrap();
}

#[rustfmt::skip]
fn build_app(user_map: Arc<UserMap>) -> Router {
    let public_router = Router::new()
        .route("/public", get(public));

    let private_router = Router::new()
        .route("/private", get(private))
        .route("/your-name", get(your_name))
        .route_layer(from_fn_with_state(user_map.clone(), auth_middleware));

    Router::new()
        .nest("/", public_router)
        .nest("/", private_router)
        .with_state(user_map)
}

async fn public() -> &'static str {
    "This is public."
}

async fn private() -> &'static str {
    "This is private."
}

async fn your_name(user: User) -> String {
    format!("Your name is {}.", user.username)
}

Discussion