Open8

axum0.6 -> axum0.7 migrationメモ

yunayuna

各種モジュールのversion up対応に合わせて、axumのバージョンが0.6から0.7にアップしました。
hyper 1.0
http 1.0
http-body 1.0
tower-http 0.5

https://github.com/tokio-rs/axum/releases/tag/axum-v0.7.0

これに伴うbreaking changeが多く発生しているので、
本番稼働中のプロダクトについて、migrationしながら対応した内容をメモしていきます。

各種モジュールを、以下のバージョン(2023/11/29 時点の最新)にして、migration開始

Cargo.toml

[dependencies]
axum = { version = "0.7.1", features = ["ws", "multipart", "macros"] }
axum-extra = "0.9.0"
axum-core = "0.4.0"
hyper = { version = "1.0.1", features = ["full"] }
tower = { version = "0.4.12", features = ["full"] }
tower-http = { version = "0.5.0", features = ["full"] }

yunayuna

breaking: axum no longer re-exports hyper::Body as that type is removed
in hyper 1.0. Instead axum has its own body type at axum::body::Body (#1751)

hyper::Bodyを使っていた箇所を修正

router.rs
/// `CatchPanicLayer::custom` handler: turns a caught panic into a
/// `500 Internal Server Error` JSON response shaped like
/// `{"error": {"kind": "panic", "details": <panic message>}}`.
///
/// axum 0.7 no longer re-exports `hyper::Body` (removed in hyper 1.0), so the body type is
/// `axum::body::Body`; the axum 0.6 version returned `Response<hyper::Body>`.
fn handle_panic(err: Box<dyn std::any::Any + Send + 'static>) -> axum::response::Response<axum::body::Body> {
    // `panic!` payloads are a `String` (formatted message) or a `&str` (literal message);
    // anything else has no printable form.
    let details: String = err
        .downcast_ref::<String>()
        .cloned()
        .or_else(|| err.downcast_ref::<&str>().map(|msg| msg.to_string()))
        .unwrap_or_else(|| "Unknown panic message".to_string());

    let payload = serde_json::to_string(&serde_json::json!({
        "error": {
            "kind": "panic",
            "details": details,
        }
    }))
    .unwrap();

    Response::builder()
        .status(StatusCode::INTERNAL_SERVER_ERROR)
        .header(header::CONTENT_TYPE, "application/json")
        .body(axum::body::Body::from(payload))
        .unwrap()
}
yunayuna

breaking: Removed re-exports of Empty and Full. Use
axum::body::Body::empty and axum::body::Body::from respectively (#1789)

breaking: The following types from http-body no longer implement IntoResponse:
Full, use Body::from instead
Empty, use Body::empty instead
BoxBody, use Body::new instead
UnsyncBoxBody, use Body::new instead
MapData, use Body::new instead
MapErr, use Body::new instead

http_body::Full がなくなったので（http-body 1.0 で Full / Empty などは http-body-util crate に移動した）、
hyperのコードを読んで、真似をしながら http_body_util::Fullに変換してみた。
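※この対応が最善かは未確認（暫定対応のつもり）。http-body-util は hyperium 公式のユーティリティ crate で、http-body 1.0 から Full / Empty / BodyExt などはこちらに移された。できれば http_body_util を使わずに実装したいが、axum::body::Body も http_body::Body（Data = Bytes）を実装しているので、ResponseBody を axum::body::Body にして Body::from(error_msg) を返す方法も考えられる（未検証）。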

router.rs
#[derive(Debug, Default, Clone, Copy)]
struct PanicHandler {}
impl ResponseForPanic for PanicHandler {
//    type ResponseBody = http_body::Full<bytes::Bytes>;
    type ResponseBody = http_body_util::Full<bytes::Bytes>;

    fn response_for_panic(
        &mut self,
        err: Box<dyn std::any::Any + Send + 'static>,
    ) -> hyper::Response<Self::ResponseBody> {
        error!("panic. type_of err: {:?}", type_of(&err));

        let mut error_msg = format!("{:?}", err);
        if let Some(s) = err.downcast_ref::<String>() {
            println!("Service panicked: {}", s);
            error_msg.push_str(format!("Service panicked: {}", s).as_str());
            tracing::error!("Service panicked: {}", s);
        } else if let Some(s) = err.downcast_ref::<&str>() {
            println!("Service panicked: {}", s);
            error_msg.push_str(format!("Service panicked: {}", s).as_str());
            tracing::error!("Service panicked: {}", s);
        } else {
            println!("Service panicked but `CatchPanic` was unable to downcast the panic info");
            error_msg.push_str(
                "Service panicked but `CatchPanic` was unable to downcast the panic info",
            );
            tracing::error!(
                "Service panicked but `CatchPanic` was unable to downcast the panic info"
            );
        };

        let mut res = Response::new(Full::from(error_msg));
        res.headers_mut().insert(
            header::CONTENT_TYPE,
            header::HeaderValue::from_static("text/plain; charset=utf-8"),
        );
        *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;

        res
    }
}
yunayuna

breaking: Removed the BoxBody type alias and its box_body
constructor. Use axum::body::Body::new instead (#1789)

file streamを取得する関数を修正

axum.rs
use axum::response::Response;
// use axum_core::body::boxed;

/// Wraps `body` in a download response: `Content-Type: application/octet-stream` plus a
/// `Content-Disposition: attachment` header carrying `file_name`.
///
/// The file name is percent-encoded and sent twice: as `filename*=UTF-8''…` (RFC 6266 /
/// RFC 5987, decoded by all current browsers) and as the legacy `filename="…"` fallback.
///
/// axum 0.7 removed `BoxBody`/`boxed`; the axum 0.6 version was
/// `gen_file_stream(body: hyper::Body, file_name: &str) -> Response<axum_core::body::BoxBody>`.
pub fn gen_file_stream(body: axum::body::Body, file_name: &str) -> Response<axum::body::Body> {
    let encoded_name = urlencoding::encode(file_name);
    let disposition = format!(
        "attachment; filename=\"{0}\"; filename*=UTF-8''{0}",
        encoded_name
    );

    Response::builder()
        .header(hyper::header::CONTENT_TYPE, "application/octet-stream")
        .header(hyper::header::CONTENT_DISPOSITION, disposition)
        // `body` already is an `axum::body::Body`; no `Body::new` re-wrapping needed.
        .body(body)
        // Percent-encoding leaves only ASCII in the header values, so this cannot fail.
        .expect("download response headers are always valid ASCII")
}

let body = hyper::Body::wrap_stream(stream)

let body = axum::body::Body::from_stream(stream);

image.rs
#[tracing::instrument]
pub async fn get_image_file(
    Extension(ref database): Extension<crate::infrastructure::db::Database>,
    Query(query): Query<serde_json::Value>,
) -> impl IntoResponse {
    let file_id = query.get("file").unwrap().as_str().unwrap();

    // let region_provider = RegionProviderChain::default_provider().or_else("ap-northeast-1");
    let config = aws_config::from_env().region("us-xxxxx").load().await;
    info!("config: {:?}", config);
    let aws_client = aws_sdk_s3::Client::new(&config);

    let data: AggregatedBytes = get_object(
        aws_client,
        "xxxxxxx",
        format!("images/{}", &file_id).as_str(),
    )
    .await
    .expect("s3の取得に失敗");

    // data.into_bytes()
    let reader = tokio::io::BufReader::new(std::io::Cursor::new(data.into_bytes()));
    let stream = tokio_util::io::ReaderStream::new(reader);
    //let body = hyper::Body::wrap_stream(stream);
    let body = axum::body::Body::from_stream(stream);

    service::axum::gen_file_stream(body, file_id)
}
yunayuna

breaking: Removed axum::Server as it was removed in hyper 1.0. Instead
use axum::serve(listener, service) or hyper/hyper-util for more configuration options (#1868)

axumページのusage exampleを参考にして書き直し
https://github.com/tokio-rs/axum/tree/axum-v0.7.0#usage-example

axum::Serverではなく、axum::serve::Serveを返す

router.rs
/// Builds the HTTP server: wraps the configured router in the middleware stack and binds a
/// TCP listener on `conf.address:conf.port`. Awaiting the returned `Serve` runs the server.
///
/// axum 0.7 removed `axum::Server` (hyper 1.0 no longer provides it). The axum 0.6 version
/// returned `axum::Server<AddrIncoming, IntoMakeService<Router>>`.
///
/// # Panics
/// Panics with the offending address if the listener cannot be bound.
pub async fn build_server() -> axum::serve::Serve<Router, Router> {
    let (conf, router) = configure();
    let database = crate::infrastructure::db::init_database().await;

    // Log the panic location before delegating to the previously installed hook.
    let default_hook = std::panic::take_hook();
    std::panic::set_hook(Box::new(move |panic_info| {
        error!("panic_info! at:{:?}", panic_info.location());
        default_hook(panic_info);
    }));

    // Compose the routes
    let app = router.layer(
        ServiceBuilder::new()
            .layer(CatchPanicLayer::custom(handle_panic))
            // Must wrap `timeout` so its `Elapsed` error is turned into a 408 response.
            .layer(HandleErrorLayer::new(|error: BoxError| async move {
                if error.is::<tower::timeout::error::Elapsed>() {
                    Ok(StatusCode::REQUEST_TIMEOUT)
                } else {
                    Err((
                        StatusCode::INTERNAL_SERVER_ERROR,
                        format!("Unhandled internal error: {}", error),
                    ))
                }
            }))
            .timeout(Duration::from_secs(10))
            .layer(TraceLayer::new_for_http())
            .layer(Extension(database))
            .layer(
                CompressionLayer::new().compress_when(
                    tower_http::compression::DefaultPredicate::new()
                        .and(NotForContentType::new("text/event-stream")),
                ),
            )
            .layer(DefaultBodyLimit::max(2 * 1000 * 1000 * 1000)) //max length: 2GB
            .layer(middleware::from_fn(controller::layer::tracing::tracing_machine_id))
            .into_inner(),
    );

    let address: SocketAddr = SocketAddr::new(conf.address, conf.port);

    // Keep the diagnostic that `axum::Server::try_bind(..).unwrap_or_else(..)` gave in
    // axum 0.6; a bare `unwrap()` would not say which address failed.
    let listener = tokio::net::TcpListener::bind(address)
        .await
        .unwrap_or_else(|e| panic!("Error binding to '{}' - {}", address, e));
    info!("listening on {}", address);

    axum::serve(listener, app)
}
yunayuna

breaking: The following types/traits are no longer generic over the request body
(i.e. the B type param has been removed) (#1751 and #1789):
FromRequestParts
FromRequest
HandlerService
HandlerWithoutStateExt
Handler
LayeredFuture
Layered
MethodRouter
Next
RequestExt
RouteFuture
Route
Router

権限チェック用のmiddleware 関数

Next<B> のようにNextにジェネリクスは付かなくなった。
合わせて、引数をhyper::Request<B>から、axum::extract::Request に変更。

auth.rs
//use hyper::Request;
use axum::extract::{FromRequest, Request};

/// Middleware that only lets requests from `Role::AllAdmin` users through.
///
/// Requires a `Database` request extension and an `Authorization: Bearer <jwt>` header.
/// On success, the verified claims and the user are inserted into the request extensions
/// (so handlers can extract them). The user and the `Database` extension are also passed
/// to `set_authorized_organization`. Any failure yields `401 Unauthorized`.
///
/// axum 0.7: `Next` and `Request` are no longer generic over the body type. The axum 0.6
/// signature was `admin_authenticate<B>(mut req: Request<B>, next: Next<B>)`.
pub async fn admin_authenticate(
    mut req: Request,
    next: Next,
) -> Result<Response, StatusCode> {
    println!(
        "admin_authenticate start. database: {:?}",
        req.extensions().get::<Database>()
    );
    let database = req
        .extensions()
        .get::<Database>()
        .ok_or(StatusCode::UNAUTHORIZED)?;
    println!("admin_authenticate end. database: {:?}", database);

    // extract the authorization header
    let auth_header = req
        .headers()
        .get(axum::http::header::AUTHORIZATION)
        .and_then(|header| header.to_str().ok());

    // Never log the header value itself: it carries the bearer token.
    println!(
        "auth header at {:?}: present={}",
        req.uri(),
        auth_header.is_some()
    );

    // `strip_prefix` removes only the leading scheme; `replace` would also delete any later
    // occurrence of "Bearer " inside the value.
    let token = match auth_header.and_then(|header| header.strip_prefix("Bearer ")) {
        Some(token) => token.to_owned(),
        None => {
            warn!("正しい形式の認証ヘッダが取得できませんでした。");
            return Err(StatusCode::UNAUTHORIZED);
        }
    };

    match verify_jwt(&token, &KEYS.decoding) {
        Ok(claims) => {
            info!("auth claims: {:?}", claims);

            match repository::free::user::get_user_with_error(database, &claims).await {
                Ok(user) if user.role == Role::AllAdmin => {
                    let extensions = req.extensions_mut();
                    let database = extensions
                        .get_mut::<Database>()
                        .ok_or(StatusCode::UNAUTHORIZED)?;
                    set_authorized_organization(&user, database);
                    // insert the current user and claims into the request extensions so
                    // the handler can extract them (moved in last, so no clone is needed)
                    extensions.insert(claims);
                    extensions.insert(user);
                }
                Ok(user) => {
                    error!("管理者権限が存在しません。{:?}", user);
                    return Err(StatusCode::UNAUTHORIZED);
                }
                Err(e) => {
                    error!("認証失敗しました。{:?}", e);
                    return Err(StatusCode::UNAUTHORIZED);
                }
            }
            Ok(next.run(req).await)
        }
        Err(e) => {
            error!("認証失敗しました。{:?}", e);
            Err(StatusCode::UNAUTHORIZED)
        }
    }
}
yunayuna

（当初）reqwest がまだ hyper 1.0 に対応していなかった（対応状況は以下の issue を参照）。
https://github.com/seanmonstar/reqwest/issues/2039

対応しました!(20240321)
https://seanmonstar.com/blog/reqwest-v012/

以下は古い情報ですが記録のため残しておきます

reqwest と連携している箇所は若干面倒。

例:
reqwest（0.11）が返す StatusCode や HeaderMap は http 0.2 の型で、
axum 0.7（http 1.0）の axum::response::IntoResponse を実装していないので、
http 1.0 の hyper::StatusCode、hyper::HeaderMap に詰め直す。

reqwestがhyper1.0対応完了したら戻す。

/// Proxies the query string to the S3 bucket endpoint and relays the upstream status,
/// headers and body.
///
/// Written for reqwest 0.11, whose `StatusCode`/`HeaderMap` are http 0.2 types. They are
/// therefore converted into axum 0.7's http 1.0 types before being returned. Client-build
/// and body-read failures are propagated through `ApiResult` instead of panicking.
pub async fn sample_s3_request_by_reqwest(
    Extension(ref _database): Extension<crate::infrastructure::db::Database>,
    Query(query): Query<serde_json::Value>,
) -> ApiResult<impl IntoResponse> {
    // Proxy the request straight to S3.
    // NOTE(review): reqwest recommends reusing a single `Client` (it pools connections);
    // one is built per request here.
    let client = reqwest::Client::builder()
        .use_rustls_tls()
        .build()?;

    let bucket_name = "xxxxxxxx";
    let region = "us-xxxxxxx";
    let res = client
        .get(format!(
            "https://{}.s3.{}.amazonaws.com/",
            bucket_name, region
        ))
        .query(&query)
        .timeout(Duration::from_secs(45))
        .send()
        .await?;

    // Temporary http 0.2 -> http 1.0 conversion until reqwest supports hyper 1.0 (start)
    let status_code = conv_old_reqwest_to_axum_statuscode(res.status());
    let header_map: hyper::HeaderMap = conv_old_reqwest_to_axum_header_map(res.headers().clone());
    // Temporary http 0.2 -> http 1.0 conversion until reqwest supports hyper 1.0 (end)

    // `bytes()` consumes `res`, so status and headers are read above.
    let body = res.bytes().await?;
    Ok((status_code, header_map, body))
}

// Temporary helpers that convert reqwest 0.11 (http 0.2) types into axum 0.7 (http 1.0) types.

/// Converts a reqwest (http 0.2) `StatusCode` into axum's http 1.0 `StatusCode` by
/// round-tripping the numeric code. http's `StatusCode` only ever holds 100..=999, which
/// `from_u16` accepts, so the `unwrap` cannot fire.
pub fn conv_old_reqwest_to_axum_statuscode(old_status_code: reqwest::StatusCode) -> axum::http::StatusCode {
    let numeric_code: u16 = old_status_code.as_u16();
    hyper::StatusCode::from_u16(numeric_code).unwrap()
}
/// Rebuilds a reqwest (http 0.2) `HeaderMap` as axum's http 1.0 `HeaderMap`.
///
/// Every (name, value) pair is re-parsed from its raw bytes and appended, so repeated
/// headers keep all their values. Pairs whose name or value fails to re-parse are skipped.
pub fn conv_old_reqwest_to_axum_header_map(old_header_map: reqwest::header::HeaderMap) -> axum::http::HeaderMap {
    let mut converted = hyper::HeaderMap::new();
    for (old_name, old_value) in old_header_map.iter() {
        let name = hyper::header::HeaderName::from_bytes(old_name.as_str().as_bytes());
        let value = axum::http::HeaderValue::from_bytes(old_value.as_bytes());
        if let (Ok(name), Ok(value)) = (name, value) {
            converted.append(name, value);
        }
    }
    converted
}