👋

actix-identityのログイン状態のチェックをMiddlewareで行いたい

2021/02/25に公開

RustのWeb Frameworkのactix-webでユーザー認証の機能を入れようと考えると、最初に出てくるのがactix-identityだと思います。

そのactix-identityのログイン状態により、アクセス可否を判断するMiddlewareを作ろうと思いました。(コントローラ側で毎回判断するのは冗長ですし)

方針、参考サイト

最初に書いた通りactix-webのMiddlewareを作りたいと思います。
ただ、右も左もわからない初心者。とりあえず、公式サイトのサンプルを改造して作る事にします。

苦労したところ

Rustはトレイトとかライフタイムとか理解するのが大変な部分がありますよね(´・ω・`)
今回でういうと

impl<S, B> Service for SayHiMiddleware<S>
where
    S: Service<Request = ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
    S::Future: 'static,
    B: 'static,
{
    type Request = ServiceRequest;
    type Response = ServiceResponse<B>;
    type Error = Error;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;

この辺で、 Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>> とかが なぜこの型なのか と頭悩ませます。

この型で返そうと色々頑張ってみたのですが、どうにもうまくいかず、色々検索をしているとこんなものを見つけました。

impl<S, B> Service for CheckLoginMiddleware<S>
where
    S: Service<Request = ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
    S::Future: 'static,
{
    type Request = ServiceRequest;
    type Response = ServiceResponse<B>;
    type Error = Error;
    type Future = Either<S::Future, Ready<Result<Self::Response, Self::Error>>>;

え?なんで Either 返していいるの?なぜこれがOKなの?疑問だらけでしたが、とりあえずこちらのソースを参考に修正していきました。

想定通りの動作をするようになったのですが、なんであちこち 'satic になってるんでしょうか?
'static = プロセスある間メモリ解放しない。 みたいなイメージがあって、思わず500万回アクセスさせてみてメモリ使用量みたりしてました。(結果問題ないぽいのですが・・・)

そのあとこのページ で「 'static が常に「永遠に生存する」ことを意味している、というのはよくある誤解です。 」と書いてありました。
Rust難しい(´・ω・`) 'static の勉強もしないとですね・・・

ソースコード

最終的にこんな感じのソースコードになりました。
middleware::IdentityFilterService::new でMiddlewareの初期化をします。
引数として

  • ログイン済みかチェックするPATHを列挙したVec。チェック自体は前方一致で判断。
  • チェックの結果エラーの場合、リダイレクトするのか401のエラーを返すのか、IdentityFilterTypeのenum設定
    の二つの引数を取ります。

middleware.rs

use actix_identity::RequestIdentity;
use actix_web::{Error, HttpResponse, dev::{Service, ServiceRequest, ServiceResponse, Transform}, http, web};
use std::{task::{Context, Poll}};
use tera::Tera;
use futures::future::{Either, Ready, ok};

#[derive(Clone, Debug)]
pub struct IdentityFilterService {
    pub target_path_list: Vec<String>,
    pub filter_type: IdentityFilterType,
}

#[derive(Clone, Debug)]
pub enum IdentityFilterType {
    Redirect(String),
    Unauthorized(String),
}

impl IdentityFilterService {
    pub fn new(target_path_list: Vec<impl Into<String>>, filter_type: IdentityFilterType) -> Self {
        let target_path_list = target_path_list.into_iter().map(|i| i.into()).collect();

        IdentityFilterService {
            target_path_list,
            filter_type,
        }
    }
}

impl<S, B> Transform<S> for IdentityFilterService
where
    S: Service<Request = ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
    S::Future: 'static,
{
    type Request = ServiceRequest;
    type Response = ServiceResponse<B>;
    type Error = Error;
    type InitError = ();
    type Transform = IdentityFilterServiceMiddleware<S>;
    type Future = Ready<Result<Self::Transform, Self::InitError>>;

    fn new_transform(&self, service: S) -> Self::Future {
        ok(IdentityFilterServiceMiddleware {
            service,
            p: self.clone(),
        })
    }
}

pub struct IdentityFilterServiceMiddleware<S> {
    service: S,
    p: IdentityFilterService,
}

impl<S, B> Service for IdentityFilterServiceMiddleware<S>
where
    S: Service<Request = ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
    S::Future: 'static,
{
    type Request = ServiceRequest;
    type Response = ServiceResponse<B>;
    type Error = Error;
    // type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;
    type Future = Either<S::Future, Ready<Result<Self::Response, Self::Error>>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.service.poll_ready(cx)
    }

    fn call(&mut self, req: ServiceRequest) -> Self::Future {
        let ident = req.get_identity();

        /*
         * アクセス許可判断
         */

        // ログイン済みだったらOK
        let mut result = ident.is_some();
        if !result {
            // 非ログイン時、要ログインページのpathかチェック
            let path_check = self.p.target_path_list.iter().find(|x| {
                req.path().starts_with(*x)
            });
            if path_check.is_none() {
                // 非ログインでも、要ログインページでなければOK
                result = true;
            }
        }

        /*
         * ページ遷移
         */

        if result {
            // 通常ページ表示
            let fut = self.service.call(req);

            Either::Left(fut)
        } else {
            // アクセスNG - フィルターのタイプに結果を出しわける
            let result: ServiceResponse<B> = match self.p.filter_type.clone() {
                IdentityFilterType::Redirect(path) => {

                    req.into_response(
                        HttpResponse::Found()
                            .header(http::header::LOCATION, path)
                            .finish()
                            .into_body(),
                    )
                
                }
                IdentityFilterType::Unauthorized(template) => {
                    let tera = req.app_data::<web::Data<Tera>>().unwrap();
                    let ctx = tera::Context::new();
                    let html = tera.render(&template, &ctx).unwrap();
                    
                    req.into_response(
                        HttpResponse::Unauthorized()
                            .content_type("text/html")
                            .body(html)
                            .into_body(),
                    )
                }
            };

            Either::Right(ok(result))
        }
    }
}

main.rs

/*
いろいろ省略
*/

#[actix_web::main]
async fn main() -> std::io::Result<()> {
    /*
    いろいろ省略
    */

    let templates = Tera::new("templates/**/*").unwrap();

    HttpServer::new(move || {
        App::new()
            .wrap(middleware::IdentityFilterService::new(
                vec!["/member/"],
                middleware::IdentityFilterType::Unauthorized("unauthorized.html".to_owned()),
            ))
            .data(templates.clone())
            .wrap(IdentityService::new(
                SessionIdentiyPolicy::new().key("Identity"),
            ))
            .wrap(RedisSession::new(
                &CONFIG.session_server,
                CONFIG.secret_key.as_bytes(),
            ))

    /*
    いろいろ省略
    */

最後に

Rustでは型関係とライフタイムはほんと初心者にとってのハードルですね。
型違うのになんか代入できるのなんで?とか不思議に思うこともしばしば。

あと impl Into<String> でstrとStringのどちらでも引数として取れると知ってすごくうれしかったです。
"hogehoge".to_string() の .to_string()というのをよく書いていたので。。。

やればやるほど Rustなんもわからん て気持ちになります( ;∀;)

Discussion