AxumでのExtensionの使い方 - ミドルウェアとデータを共有する -
動作環境
本記事は以下の環境で動作確認を行なっています。
Rust, Axum共に最新の機能を使っているわけではないので、これよりもある程度古いバージョンでも動作すると思われます。
- Rust 1.78.0
- Axum 0.7.5
Extension
axumには Extension という型があります。
ドキュメントに多くは書かれていないのですが、それには以下のようにあります。
Extractor and response for extensions.
https://docs.rs/axum/0.7.5/axum/struct.Extension.html
extensionsに対するextractorとresponseらしいです。
ここで言うextensionsとは、 http crate のExtensionsのことです。
ドキュメントには以下のようにあります。
A type map of protocol extensions.
Extensions can be used by Request and Response to store extra data derived from the underlying protocol.
https://docs.rs/http/1.1.0/http/struct.Extensions.html
要するに「型情報をキーとしたMap(内部的には HashMap
)」「リクエストとレスポンスに付加的な情報を保持できる」というものです。
axumの Extension
は、ルーター/ミドルウェア/ハンドラーの間でデータを共有するために使います。
ミドルウェアとハンドラーの間でデータを共有したくなるケースとして最も多いものの一つに、認証機能(ログイン機能)があります。
「ハンドラー群に認証をかけたいが、ハンドラー側ではユーザ情報が必要だったり不要だったりする」という場合、 Extension
を使うことで簡潔に実装することができます。
今回はBearer認証を Extension
を使って実装することで、その使い方を解説します。
実装
コード全体は本記事最下部にあります。
ユーザプール
今回は簡易的なユーザプールとして HashMap
を使います。
keyはトークン文字列で、valueはユーザ名を持つ User
です。
サンプルのユーザプールを作る関数も定義しておきます。
pub type Token = String;
pub type UserMap = HashMap<Token, User>;
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct User {
pub username: String,
}
pub fn build_user_map() -> UserMap {
let mut user_map = HashMap::new();
user_map.insert(
"aaa".to_string(),
User {
username: "Andy".to_string(),
},
);
user_map.insert(
"bbb".to_string(),
User {
username: "Bella".to_string(),
},
);
user_map.insert(
"ccc".to_string(),
User {
username: "Callie".to_string(),
},
);
user_map.insert(
"ddd".to_string(),
User {
username: "Daren".to_string(),
},
);
user_map
}
User
に FromRequestParts
を実装する
User
に FromRequestParts
を実装することで、ハンドラーの引数で User
を受け取ることができるようになります。
ここでは認証/認可は行われません。
処理としては、後述するミドルウェアで仕込んだ User
を取り出しています。
もし、リクエスト中に User
が無い場合、ミドルウェアが有効になっていないなどのバグのため、 expect
を使っています。
#[async_trait]
impl<S> FromRequestParts<S> for User
where
S: Send + Sync,
{
type Rejection = StatusCode;
async fn from_request_parts(parts: &mut Parts, _: &S) -> Result<Self, Self::Rejection> {
let user = parts
.extensions
.get::<Self>()
.expect("User not found. Did you add auth_middleware?");
Ok(user.clone())
}
}
ミドルウェアを実装する
このミドルウェアによって、認証が行われます。
処理としては以下の流れです。
- リクエストヘッダーからBearerトークンを取り出す
- ユーザプールから対応するユーザを探す
- ユーザが見つかればリクエストの
Extension
にUser
を追加する
3で追加した User
が前述の FromRequestParts
の実装で取り出されることになります。
pub async fn auth_middleware(
State(user_map): State<Arc<UserMap>>,
mut request: Request,
next: Next,
) -> axum::response::Result<Response> {
let bearer = request
.extract_parts::<TypedHeader<Authorization<Bearer>>>()
.await
.map_err(|_| StatusCode::BAD_REQUEST)?;
let token = bearer.token();
let user = user_map.get(token).ok_or(StatusCode::UNAUTHORIZED)?;
request.extensions_mut().insert(user.clone());
Ok(next.run(request).await)
}
ミドルウェアを使う
axum::middleware::from_fn_with_state
を使ってミドルウェアを組み込みます。
それ以外はaxumの一般的な使用方法です。
3つのハンドラーがあり、それぞれ以下の想定です。
-
public
: 認証が不要 -
private
: 認証が必要、ユーザ情報不要 -
your_name
: 認証が必要、ユーザ情報必要
private
ハンドラーの引数には User
がありませんが、ミドルウェアによって認証がかけられています。
#[tokio::main]
async fn main() {
let user_map = build_user_map();
let app = build_app(Arc::new(user_map));
let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
axum::serve(listener, app).await.unwrap();
}
#[rustfmt::skip]
fn build_app(user_map: Arc<UserMap>) -> Router {
let public_router = Router::new()
.route("/public", get(public));
let private_router = Router::new()
.route("/private", get(private))
.route("/your-name", get(your_name))
.route_layer(from_fn_with_state(user_map.clone(), auth_middleware));
Router::new()
.nest("/", public_router)
.nest("/", private_router)
.with_state(user_map)
}
async fn public() -> &'static str {
"This is public."
}
async fn private() -> &'static str {
"This is private."
}
async fn your_name(user: User) -> String {
format!("Your name is {}.", user.username)
}
動作確認
/public
は認証せずともアクセスが可能です。
$ curl http://localhost:3000/public
This is public.
/private
は認証が必要です。
$ curl -I http://localhost:3000/private
HTTP/1.1 400 Bad Request
content-length: 0
date: Mon, 27 May 2024 09:51:24 GMT
$ curl -H 'Authorization: Bearer aaa' http://localhost:3000/private
This is private.
/your-name
は認証が必要です。
$ curl -I http://localhost:3000/your-name
HTTP/1.1 400 Bad Request
content-length: 0
date: Mon, 27 May 2024 09:53:13 GMT
$ curl -H 'Authorization: Bearer bbb' http://localhost:3000/your-name
Your name is Bella.
補足
Extensionの可視性
Extension
は特定のリクエスト・レスポンスに固有のため、他のリクエスト・レスポンスには共有されません。
そのため、ミドルウェアで追加した User
が別のリクエストから見えることはありません。
Just locally in the request. Different requests don't share the same extensions.
https://github.com/tokio-rs/axum/discussions/2312#discussioncomment-7558458
Extensionの実装
Extension
は型情報をキーとする HashMap
になっており、興味深い実装です。
ミドルウェアの適用方法
今回はミドルウェアを組み込む際に from_fn_with_state
を使いましたが、他にも from_fn
や from_extractor
などの関数があります。
また、ミドルウェアの Router::layer
は順序が重要なため、一度 axum::middleware
のドキュメントを読むことをお勧めします。
まとめ
今回はBearer認証を実装することで Extension
の使い方を解説しました。
Extension
を使った手法は実際に、 axum-login crate でも使われています。
Extension
の利用ケースとしては、認証/認可の他にも「リクエストIDの生成・利用」「ロギングを行うミドルウェアとの情報共有」など、ミドルウェアとの連携が必要なものが考えられます。
汎用的な仕組みのため、axumを使いこなす上で重要な機能だと言えるでしょう。
コード全体
依存ライブラリ
axum = { version = "0.7.5" }
axum-extra = { version = "0.9.3", features = ["typed-header"] }
http = "1.1.0"
tokio = { version = "1", features = ["full"] }
auth.rs
use std::{collections::HashMap, sync::Arc};
use axum::{
async_trait,
extract::{FromRequestParts, Request, State},
middleware::Next,
response::Response,
RequestExt as _,
};
use axum_extra::{
headers::{authorization::Bearer, Authorization},
TypedHeader,
};
use http::{request::Parts, StatusCode};
pub type Token = String;
pub type UserMap = HashMap<Token, User>;
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct User {
pub username: String,
}
pub fn build_user_map() -> UserMap {
let mut user_map = HashMap::new();
user_map.insert(
"aaa".to_string(),
User {
username: "Andy".to_string(),
},
);
user_map.insert(
"bbb".to_string(),
User {
username: "Bella".to_string(),
},
);
user_map.insert(
"ccc".to_string(),
User {
username: "Callie".to_string(),
},
);
user_map.insert(
"ddd".to_string(),
User {
username: "Daren".to_string(),
},
);
user_map
}
#[async_trait]
impl<S> FromRequestParts<S> for User
where
S: Send + Sync,
{
type Rejection = StatusCode;
async fn from_request_parts(parts: &mut Parts, _: &S) -> Result<Self, Self::Rejection> {
let user = parts
.extensions
.get::<Self>()
.expect("User not found. Did you add auth_middleware?");
Ok(user.clone())
}
}
pub async fn auth_middleware(
State(user_map): State<Arc<UserMap>>,
mut request: Request,
next: Next,
) -> axum::response::Result<Response> {
let bearer = request
.extract_parts::<TypedHeader<Authorization<Bearer>>>()
.await
.map_err(|_| StatusCode::BAD_REQUEST)?;
let token = bearer.token();
let user = user_map.get(token).ok_or(StatusCode::UNAUTHORIZED)?;
request.extensions_mut().insert(user.clone());
Ok(next.run(request).await)
}
main.rs
mod auth;
use std::sync::Arc;
use axum::{middleware::from_fn_with_state, routing::get, Router};
use auth::{auth_middleware, build_user_map, User, UserMap};
#[tokio::main]
async fn main() {
let user_map = build_user_map();
let app = build_app(Arc::new(user_map));
let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
axum::serve(listener, app).await.unwrap();
}
#[rustfmt::skip]
fn build_app(user_map: Arc<UserMap>) -> Router {
let public_router = Router::new()
.route("/public", get(public));
let private_router = Router::new()
.route("/private", get(private))
.route("/your-name", get(your_name))
.route_layer(from_fn_with_state(user_map.clone(), auth_middleware));
Router::new()
.nest("/", public_router)
.nest("/", private_router)
.with_state(user_map)
}
async fn public() -> &'static str {
"This is public."
}
async fn private() -> &'static str {
"This is private."
}
async fn your_name(user: User) -> String {
format!("Your name is {}.", user.username)
}
Discussion