【Python】FlaskサーバでJWT認証を実装する

2023/02/28に公開

概要

本記事では、JWT(JSON Web Token)を使用した認証を、flaskで実装する方法を紹介する。本記事のゴールとしては、localhostで立ち上げたflaskアプリケーション (http://127.0.0.1:5000/headers) に対して、Bearer Tokenを使用しGETを行う。そうすると、flaskアプリケーションでBearer Tokenが認証されて、200のステータスコードで "Access Granted" と出力される。

要件仕様

上記アプリケーションを実装する上で、三つの機能が求められる。

  • get_token_auth_header関数
    • この関数は、Authorizationヘッダーからアクセストークンを取得する。
    • Authorizationヘッダーがない場合、またはヘッダーが不正な場合は、AuthErrorを出力する。
    • Authorizationヘッダーは、下記のようにflaskアプリケーションに対し送信される。
  • verify_decode_jwt関数
    • この関数は、JWTの検証と複合化を行う。
    • Auth0からトークンを主とし、それが有効であることを確認するため、使用される。
  • requires_authデコレータ
    • このデコレータは、verify_decode_jwt関数を使用して、トークンを検証する
    • 検証に失敗した場合、401エラーを返す
    • また、最後にheadersエンドポイントを定義する。このエンドポイントは、認証が必要なエンドポイントで、アクセストークンが検証された後に、アクセスを許可する。
    • アクセスが許可されると、"Access Granted" と表示される。
    • エンドポイントとは "/headers" のことで、ここにアクセスする際にはJWT認証を成功する必要がある。

ハンズオン

では、実際に要件仕様に沿ってコードを実装する。

get_token_auth_header関数

こちらは request.headers.get('Authorization', None) により、認証情報を取得している。認証情報が取得できなければ、エラーを返す仕様である。

def get_token_auth_header():
    """Obtains the Access Token from the Authorization Header
    """
    auth = request.headers.get('Authorization', None)
    if not auth:
        raise AuthError({
            'code': 'authorization_header_missing',
            'description': 'Authorization header is expected.'
        }, 401)

認証情報取得後、それがbearer認証であるか確認している。bearer xxxxxxという形でJWTを受け取るため、parts[0].lower() != 'bearer': というようにトークンをチェックしている。また、上記の形でトークンを受け取るため、len(parts)が1もしくは2以上となることはあり得ない。そのため、上記場合はエラーを返す使用である。

    parts = auth.split()
    if parts[0].lower() != 'bearer':
        raise AuthError({
            'code': 'invalid_header',
            'description': 'Authorization header must start with "Bearer".'
        }, 401)
        
    elif len(parts) == 1:
        print("2")
        raise AuthError({
            'code': 'invalid_header',
            'description': 'Token not found.'
        }, 401)

    elif len(parts) > 2:
        print("3")
        raise AuthError({
            'code': 'invalid_header',
            'description': 'Authorization header must be bearer token.'
        }, 401)

最終的に、bearer xxxxxx のトークン部分が取得できれば良いので、token = parts[1] としてそのtokenをreturnしている。

    token = parts[1]
    return token

verify_decode_jwt関数

ここでは、JWTの署名を検証して、有効なトークンであることを確認して、トークンのペイロードを返す関数を定義している。JWTはヘッダー、ペイロード、署名で構成されており、ペイロードにはユーザに関わる情報が格納されている。

ユーザはまず、公開鍵を取得するために、AUTH0_DOMAINで定義されたドメインから、JSON Web Key Setを取得する。JWK Setとは、JWTトークンを検証するために使用される公開鍵と秘密鍵のセットである。JWT Setには、以下情報が含まれる。

  • keys: JSON Web Keyの配列で、鍵の種類(RSA、ECDSA、HMACなど)、鍵のID(kid)、公開鍵または秘密鍵の値、鍵の使用目的
    jsonurl = urlopen(f'https://{AUTH0_DOMAIN}/.well-known/jwks.json')
    jwks = json.loads(jsonurl.read())

そして、トークンヘッダー情報を解析して、JWTが有効な署名アルゴリズムを使用して署名しているか確認する。次に、トークンのヘッダーに含まれる鍵のID(kid)がJWKSの中にあることを確認して、対応するRSA公開鍵を見つける。

    unverified_header = jwt.get_unverified_header(token)
    rsa_key = {}
    if 'kid' not in unverified_header:
        raise AuthError({
            'code': 'invalid_header',
            'description': 'Authorization malformed.'
        }, 401)
	
    for key in jwks['keys']:
        if key['kid'] == unverified_header['kid']:
            rsa_key = {
                'kty': key['kty'],
                'kid': key['kid'],
                'use': key['use'],
                'n': key['n'],
                'e': key['e']

RSA 公開鍵を使用してトークンを検証し、以下を確認する。

  • トークンの優子期限が切れていないこと
  • 目的の受信者と発行者が正しいこと
  • もし上記が無効であれば、AuthErrorを発生させる
    if rsa_key:
        try:
            payload = jwt.decode(
                token,
                rsa_key,
                algorithms=ALGORITHMS,
                audience=API_AUDIENCE,
                issuer='https://' + AUTH0_DOMAIN + '/'
            )

            return payload
	    
        except jwt.ExpiredSignatureError:
            raise AuthError({
                'code': 'token_expired',
                'description': 'Token expired.'
            }, 401)

        except jwt.JWTClaimsError:
            raise AuthError({
                'code': 'invalid_claims',
                'description': 'Incorrect claims. Please, check the audience and issuer.'
            }, 401)
        except Exception:
            raise AuthError({
                'code': 'invalid_header',
                'description': 'Unable to parse authentication token.'
            }, 400)
    raise AuthError({
                'code': 'invalid_header',
                'description': 'Unable to find the appropriate key.'
            }, 400)

requires_authデコレータ

get_token_auth_header()関数でtokenを取得して、そのtokenをverify_decode_jwt関数で複合化しreturnしている。

def requires_auth(f):
    @wraps(f)
    def wrapper(*args, **kwargs):
        token = get_token_auth_header()
        try:
            payload = verify_decode_jwt(token)
        except:
            abort(401)
        return f(payload, *args, **kwargs)

    return wrapper

headersエンドポイント

最後に、headersエンドポイントを定義している。このエンドポイントは、認証が必要なエンドポイントのため、アクセストークンが認証された後にアクセスが許可される。許可されると、"Access Granted" というメッセージが返される。

@app.route('/headers')
@requires_auth
def headers(payload):
    print(payload)
    return 'Access Granted'

コード全体

from flask import Flask, request, abort
import json
from functools import wraps
from jose import jwt
from urllib.request import urlopen
import ssl
ssl._create_default_https_context = ssl._create_unverified_context


app = Flask(__name__)

AUTH0_DOMAIN = 'dev-sq7mzrpnl8jmmik1.jp.auth0.com'
ALGORITHMS = ['RS256']
API_AUDIENCE = 'image'


class AuthError(Exception):
    def __init__(self, error, status_code):
        self.error = error
        self.status_code = status_code


def get_token_auth_header():
    """Obtains the Access Token from the Authorization Header
    """
    auth = request.headers.get('Authorization', None)
    if not auth:
        raise AuthError({
            'code': 'authorization_header_missing',
            'description': 'Authorization header is expected.'
        }, 401)

    parts = auth.split()
    if parts[0].lower() != 'bearer':
        print("1")
        raise AuthError({
            'code': 'invalid_header',
            'description': 'Authorization header must start with "Bearer".'
        }, 401)
        
    elif len(parts) == 1:
        print("2")
        raise AuthError({
            'code': 'invalid_header',
            'description': 'Token not found.'
        }, 401)

    elif len(parts) > 2:
        print("3")
        raise AuthError({
            'code': 'invalid_header',
            'description': 'Authorization header must be bearer token.'
        }, 401)

    token = parts[1]
    return token


def verify_decode_jwt(token):
    jsonurl = urlopen(f'https://{AUTH0_DOMAIN}/.well-known/jwks.json')
    jwks = json.loads(jsonurl.read())
    unverified_header = jwt.get_unverified_header(token)
    rsa_key = {}
    if 'kid' not in unverified_header:
        raise AuthError({
            'code': 'invalid_header',
            'description': 'Authorization malformed.'
        }, 401)

    for key in jwks['keys']:
        if key['kid'] == unverified_header['kid']:
            rsa_key = {
                'kty': key['kty'],
                'kid': key['kid'],
                'use': key['use'],
                'n': key['n'],
                'e': key['e']
            }
    if rsa_key:
        try:
            payload = jwt.decode(
                token,
                rsa_key,
                algorithms=ALGORITHMS,
                audience=API_AUDIENCE,
                issuer='https://' + AUTH0_DOMAIN + '/'
            )

            return payload

        except jwt.ExpiredSignatureError:
            raise AuthError({
                'code': 'token_expired',
                'description': 'Token expired.'
            }, 401)

        except jwt.JWTClaimsError:
            raise AuthError({
                'code': 'invalid_claims',
                'description': 'Incorrect claims. Please, check the audience and issuer.'
            }, 401)
        except Exception:
            raise AuthError({
                'code': 'invalid_header',
                'description': 'Unable to parse authentication token.'
            }, 400)
    raise AuthError({
                'code': 'invalid_header',
                'description': 'Unable to find the appropriate key.'
            }, 400)


def requires_auth(f):
    @wraps(f)
    def wrapper(*args, **kwargs):
        token = get_token_auth_header()
        print(verify_decode_jwt(token))
        try:
            payload = verify_decode_jwt(token)
        except:
            abort(401)
        return f(payload, *args, **kwargs)

    return wrapper

@app.route('/headers')
@requires_auth
def headers(payload):
    print(payload)
    return 'Access Granted'

参考

https://www.udacity.com/course/full-stack-web-developer-nanodegree--nd0044
https://www.engineer-memo.net/20180716-4614

Discussion