🏎️

FastAPIでCRUD処理

2024/12/13に公開

Schema定義

userデータにおけるCRUD処理に必要なschemaを定義します。

コードはこちら
from sqlmodel import Relationship, SQLModel, Field, AutoString
from typing import Optional, TYPE_CHECKING
from pydantic import EmailStr
import uuid
from datetime import datetime

from ..schema.press import PressReleaseReadWithUser

if TYPE_CHECKING:
    from .press import PressRelease


class BaseUser(SQLModel):
    """Fields shared by the ``User`` table and its read/update schemas."""

    # Required and unique. No ``default=None``: it contradicted
    # ``nullable=False`` and only let a missing username fail later at INSERT.
    username: str = Field(nullable=False, unique=True)
    # Unique and indexed; stored as a string column via ``AutoString``.
    email: EmailStr = Field(unique=True, index=True, sa_type=AutoString)
    is_active: bool = Field(default=True)
    is_superuser: bool = Field(default=False)


class User(BaseUser, table=True):
    """Database table for application users (adds storage-only fields to BaseUser)."""

    id: Optional[int] = Field(default=None, primary_key=True)
    # Public identifier used in URLs and JWT claims. Generated as a *string*:
    # table models are not validated, so ``uuid.uuid4`` alone would store a
    # UUID object in this VARCHAR column. Inside the lambda, ``uuid`` is looked
    # up in module scope at call time, i.e. it is the ``uuid`` module, not this
    # field. Unique + indexed because users are looked up by this column.
    uuid: str = Field(
        default_factory=lambda: str(uuid.uuid4()),
        nullable=False,
        unique=True,
        index=True,
    )
    hashed_password: str = Field(nullable=False)
    # default_factory so each row gets its own timestamp; a plain
    # ``default=datetime.now()`` is evaluated once, at import time.
    created_at: datetime = Field(default_factory=datetime.now, nullable=False)
    updated_at: datetime = Field(default_factory=datetime.now, nullable=False, sa_column_kwargs={'onupdate': datetime.now})
    # NULL until a refresh token is stored (e.g. bulk-imported users never set it).
    refresh_token: Optional[str] = Field(default=None, nullable=True)

    press_releases: list["PressRelease"] | None = Relationship(back_populates="user", cascade_delete=True)

class UserOut(BaseUser):
    """Response schema with a user's stored fields (the password hash is not included)."""

    id: int
    uuid: str
    created_at: datetime
    updated_at: datetime
    # Optional: User.refresh_token is nullable (bulk-imported users have none),
    # and a required ``str`` made response validation fail for such rows.
    # NOTE(review): this sends refresh tokens to API clients, including through
    # the user-list route, which has no auth dependency — consider removing it.
    refresh_token: Optional[str] = None

class UserLogin(SQLModel):
    """Login request body: the account's email plus its plain-text password."""
    email: str  # plain str, not EmailStr: no email-format validation is applied here
    password: str  # plain text as submitted; compared against the stored hash

class UserCreate(SQLModel):
    """Sign-up request body.

    ``password`` arrives in plain text; the signup endpoint hashes it before
    creating the ``User`` row.
    """
    username: str
    email: EmailStr  # validated as an email address
    password: str

class UserRead(BaseUser):
    """Detailed read schema for a single user, including the user's press releases."""
    id: int
    uuid: str
    created_at: datetime
    updated_at: datetime
    # Presumably filled from the User.press_releases relationship when a User
    # row is serialized with this schema. A literal [] default is safe in
    # pydantic models (mutable defaults are copied per instance).
    press_releases: list["PressReleaseReadWithUser"] = []

class UserUpdate(BaseUser):
    """Partial-update schema: every field is optional.

    Only the fields a client actually sends are applied
    (``model_dump(exclude_unset=True)`` in the CRUD layer). Each field needs an
    explicit ``None`` default: under pydantic v2 a bare ``Optional[bool]`` is
    still *required*, so clients would otherwise have to send ``is_active``
    and ``is_superuser`` on every update.

    NOTE(review): ``is_superuser`` is client-settable through this schema;
    confirm the route that accepts it restricts who may change it.
    """

    username: Optional[str] = None
    email: Optional[EmailStr] = None
    is_active: Optional[bool] = None
    is_superuser: Optional[bool] = None

【ログイン・サインアップ処理】

  • ログイン処理(UserLogin)は、emailとpasswordで行います
  • サインアップ処理(UserCreate)は、email、passwordに加えて、usernameを必要としています

【データの呼び出し】

  • BaseUserを継承し、登録後に設定されるid、uuid、created_atなどの項目を取得します

【更新処理】

  • 更新する値は、primary_keyにしているidと個別のidとして割り当てているuuidを除く項目にしています

CRUD処理に必要なロジック

CRUD処理のロジックをまとめています。
まずは、Userに関するロジックです。

コードはこちら
import os
from fastapi import HTTPException, status
from typing import List

# schema
from ..schema.user import User, UserOut

# sqlmodel
from sqlmodel import select

# uuid
import uuid as uuid_pkg

# settings
from ..core.config import settings

# pandas
import pandas as pd

def check_user(session, user):
    """Return an existing User whose username OR email matches ``user``, else None.

    Used before sign-up: either value being taken is a conflict, because both
    columns are unique. Requiring both to match (AND) let a reused email
    through to an IntegrityError at INSERT time.

    Args:
        session: DB session.
        user: Object with ``username`` and ``email`` attributes (e.g. ``UserCreate``).
    """
    return session.exec(
        select(User).where(
            (User.username == user.username) | (User.email == user.email)
        )
    ).first()

def get_hashed_password(password):
    """Hash a plain-text password with ``settings.pwd_context`` and return the hash.

    ``pwd_context`` is presumably a passlib ``CryptContext`` — confirm in core.config.
    """
    context = settings.pwd_context
    return context.hash(password)

def verfy_password(password, hashed_password):
    """Check a plain-text password against a stored hash via ``settings.pwd_context``.

    Returns whatever ``pwd_context.verify`` returns (truthy when they match).
    The misspelled name is kept because other modules import it.
    """
    context = settings.pwd_context
    return context.verify(password, hashed_password)

def get_all_user(session):
    """Return every User row as a list.

    The previous body built the query but returned nothing (always None),
    and never materialized the lazy result object.
    """
    return session.exec(select(User)).all()

def update_user(session, uuid, user):
    """Apply a partial update to the User identified by ``uuid`` and return it.

    Args:
        session: DB session.
        uuid: Public uuid of the user to update.
        user: Update schema instance (e.g. ``UserUpdate``); only the fields the
            client actually sent are applied.

    Returns:
        The committed and refreshed User row.

    Raises:
        HTTPException: 404 if no user has that uuid.
    """
    db_user = get_user_by_uuid(session, uuid)

    if not db_user:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="User not found."
        )

    # exclude_unset: PATCH semantics, so fields the client omitted keep their value.
    # exclude_none: an explicit null means "no change". UserUpdate's fields map
    # to User columns that are not nullable, so writing None would only fail
    # later with an IntegrityError (a 500) at commit.
    user_data = user.model_dump(exclude_unset=True, exclude_none=True)

    for key, value in user_data.items():
        setattr(db_user, key, value)

    session.add(db_user)
    session.commit()
    session.refresh(db_user)
    return db_user

def delete_user(session, uuid):
    """Delete the User identified by ``uuid`` and return a confirmation message.

    ``User.press_releases`` is declared with ``cascade_delete=True``, so the
    user's press releases should be removed along with it.

    Raises:
        HTTPException: 404 if no user has that uuid.
    """
    target = get_user_by_uuid(session, uuid)
    if not target:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found.")

    session.delete(target)
    session.commit()
    return {"message": "User was deleted."}

def get_user_by_uuid(session, uuid: str):
    """Return the User whose ``uuid`` column equals ``uuid``, or None if none exists.

    Uses ``session.exec(select(...))`` like the rest of this module;
    SQLModel's ``Session.query`` is deprecated.
    """
    return session.exec(select(User).where(User.uuid == uuid)).first()

def get_user_by_email(session, email: str):
    """Return the User registered with ``email``, or None if none exists.

    Uses ``session.exec(select(...))`` like the rest of this module;
    SQLModel's ``Session.query`` is deprecated.
    """
    return session.exec(select(User).where(User.email == email)).first()


def _read_user_table(file_path: str) -> pd.DataFrame:
    """Load ``file_path`` into a DataFrame, picking the reader by (case-insensitive) extension."""
    lowered = file_path.lower()
    if lowered.endswith(".csv"):
        return pd.read_csv(file_path)
    if lowered.endswith(".xlsx"):
        return pd.read_excel(file_path)
    raise ValueError(
        "Invalid file format. Please provide a CSV or EXCEL file."
    )


def _cell_value(row, column):
    """Return ``row[column]``, or None when the column is absent or the cell is empty (NaN)."""
    value = row.get(column)
    if value is None or pd.isna(value):
        return None
    return value


def _as_bool(value, default: bool) -> bool:
    """Coerce a spreadsheet cell to a Python bool; None (blank cell) yields ``default``.

    Strings are parsed explicitly because ``bool("False")`` is True.
    Raises ValueError for unrecognized strings.
    """
    if value is None:
        return default
    if isinstance(value, str):
        text = value.strip().lower()
        if text in ("true", "1", "yes"):
            return True
        if text in ("false", "0", "no"):
            return False
        raise ValueError(f"Invalid boolean value: {value!r}")
    # Also turns numpy.bool_ / numeric cells into a plain bool for the DB driver.
    return bool(value)


def _upsert_user(session, username, email, password, is_active, is_superuser):
    """Update the user matching ``username`` OR ``email``, or stage a new one (no commit).

    ``is_active`` / ``is_superuser`` of None mean "cell left blank": an
    existing user keeps its current value; a new user gets True / False.
    """
    hashed_password = get_hashed_password(password)

    existing_user = session.exec(
        select(User).where((User.username == username) | (User.email == email))
    ).first()

    if existing_user:
        existing_user.username = username
        existing_user.email = email
        existing_user.hashed_password = hashed_password
        existing_user.is_active = _as_bool(is_active, existing_user.is_active)
        existing_user.is_superuser = _as_bool(is_superuser, existing_user.is_superuser)
        session.add(existing_user)
    else:
        session.add(User(
            uuid=str(uuid_pkg.uuid4()),
            username=username,
            email=email,
            hashed_password=hashed_password,
            is_active=_as_bool(is_active, True),
            is_superuser=_as_bool(is_superuser, False),
        ))


def users_register(session, file_path: str):
    """Create or update users in bulk from a CSV (``.csv``) or Excel (``.xlsx``) file.

    Expected columns: ``username``, ``email`` and ``password`` (required on
    every row), plus optional ``is_active`` and ``is_superuser``. A row whose
    username or email matches an existing user updates that user; otherwise a
    new user is created.

    Each row is committed on its own. A failing row is rolled back and
    reported, and rows that already succeeded are kept. (With one final
    commit, a single flush error left the session unusable for every
    following row.)

    Args:
        session: DB session.
        file_path: Path to the uploaded file; its extension selects the reader.

    Returns:
        list[dict]: one ``{"username": ..., "error": ...}`` entry per failed
        row; empty when every row succeeded.

    Raises:
        ValueError: if the extension is neither .csv nor .xlsx.
    """
    failed_users = []
    df = _read_user_table(file_path)

    for _, row in df.iterrows():
        username = _cell_value(row, "username")
        email = _cell_value(row, "email")
        password = _cell_value(row, "password")

        if username is None or email is None or password is None:
            # None rather than NaN: NaN is not valid JSON and broke the response.
            failed_users.append({
                "username": None if username is None else str(username),
                "error": "Username, email, password are required fields."
            })
            continue

        # pandas may infer numeric dtypes (e.g. an all-digit password column);
        # the hasher and the string columns need str. NOTE: a numeric column
        # that also contains blanks is read as float (1234 -> "1234.0").
        username, email, password = str(username), str(email), str(password)

        try:
            _upsert_user(
                session,
                username,
                email,
                password,
                _cell_value(row, "is_active"),
                _cell_value(row, "is_superuser"),
            )
            session.commit()
        except Exception as e:  # per-row boundary: undo this row, report it, continue
            session.rollback()
            failed_users.append({
                "username": username,
                "error": str(e)
            })

    return failed_users

次に、認証に関するロジックです。

コードはこちら
import os
from fastapi import Depends, HTTPException, status
from datetime import datetime, timedelta, timezone

# auth
from jose import jwt, JWTError

# sqlmodel
from sqlmodel import Session

# db
from ..db.db import get_session

# settings
from ..core.config import settings

# functions
from .user import verfy_password, get_user_by_uuid, get_user_by_email, get_all_user

# schemas
from ..schema.user import TokenData

# .env
from dotenv import load_dotenv
load_dotenv()


def authenticate(session, email, password):
    """Return the User registered with ``email`` if ``password`` matches its hash.

    Raises:
        HTTPException: 401 when no user has that email or the password does
            not verify. Both cases share one message so the response does not
            reveal which emails are registered. (An unknown email previously
            crashed with AttributeError, i.e. a 500.)

    NOTE(review): ``is_active`` is not checked, so deactivated users can still
    authenticate — confirm whether that is intended.
    """
    user = get_user_by_email(session, email)

    if user is None or not verfy_password(password, user.hashed_password):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Can't authenticate by your password."
        )
    return user

def create_tokens(uuid: str):
    """Issue a signed access/refresh JWT pair for the user identified by ``uuid``.

    Token lifetimes come from the ACCESS_TOKEN_EXPIRE_DAYS and
    REFRESH_TOKEN_EXPIRE_DAYS environment variables (whole days). Signing uses
    SECRET_KEY and ALGORITHM. A missing variable raises KeyError.

    Returns:
        dict with ``access_token``, ``refresh_token`` and ``token_type`` ("bearer").
    """
    def build_claims(kind: str, lifetime_env: str) -> dict:
        # Claims: token kind, absolute UTC expiry, and the subject's uuid.
        lifetime = timedelta(days=int(os.environ[lifetime_env]))
        return {
            "token_type": kind,
            "exp": datetime.now(timezone.utc) + lifetime,
            "uuid": uuid,
        }

    claims_by_name = {
        "access_token": build_claims("access_token", "ACCESS_TOKEN_EXPIRE_DAYS"),
        "refresh_token": build_claims("refresh_token", "REFRESH_TOKEN_EXPIRE_DAYS"),
    }

    issued = {
        name: jwt.encode(claims, os.environ["SECRET_KEY"], algorithm=os.environ["ALGORITHM"])
        for name, claims in claims_by_name.items()
    }
    issued["token_type"] = "bearer"
    return issued

def get_current_user_from_token(session, token: str, token_type: str):
    """Decode ``token`` and return the User it was issued for.

    Args:
        session: DB session used to load the user.
        token: Encoded JWT as produced by ``create_tokens``.
        token_type: Expected value of the ``token_type`` claim
            (``"access_token"`` or ``"refresh_token"``).

    Returns:
        The User whose ``uuid`` is in the token.

    Raises:
        HTTPException: 401 in any of these cases:
            - the token cannot be decoded (bad signature, expired, malformed);
            - it carries the wrong ``token_type``;
            - it has no ``uuid`` claim, or names a user that no longer exists;
            - for refresh tokens, it differs from the token stored on the user.
    """
    try:
        payload = jwt.decode(
            token,
            os.environ["SECRET_KEY"],
            algorithms=[os.environ["ALGORITHM"]]
        )
    except JWTError as err:
        # Expired or tampered tokens raise JWTError subclasses; report them as
        # 401 instead of letting them surface as a 500.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token."
        ) from err

    if payload.get("token_type") != token_type:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="token_type does not match."
        )

    user_uuid = payload.get("uuid")
    current_user = get_user_by_uuid(session, user_uuid) if user_uuid else None
    if current_user is None:
        # Deleted user, or a token without a uuid claim.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="User not found."
        )

    # Refresh tokens are only valid while they match the one stored on the user.
    if token_type == "refresh_token" and current_user.refresh_token != token:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="refresh_token does not match."
        )
    return current_user

def get_current_user(token: str = Depends(settings.oauth2_schema), session: Session = Depends(get_session)):
    """FastAPI dependency: resolve the bearer *access* token to the current User.

    Errors propagate from ``get_current_user_from_token``. The former debug
    ``print`` is removed: it decoded the token and queried the DB twice, and
    wrote the whole User (including ``hashed_password``) to stdout.
    """
    return get_current_user_from_token(session, token, "access_token")

def get_user_from_refresh_token(token: str = Depends(settings.oauth2_schema), session: Session = Depends(get_session)):
    """FastAPI dependency: resolve a bearer *refresh* token to its User.

    The token is read through the same OAuth2 scheme as access tokens; the
    ``"refresh_token"`` type check (and the comparison against the stored
    token) happens in ``get_current_user_from_token``.
    """
    expected_type = "refresh_token"
    return get_current_user_from_token(session, token, expected_type)

def verify_token(token: str):
    """Decode ``token`` and return a ``TokenData`` carrying its ``uuid`` claim.

    NOTE(review): the ``token_type`` claim is not checked, so a refresh token
    is accepted here as well — confirm that callers expect this.

    Raises:
        HTTPException: 401 "Invalid token." if the JWT cannot be decoded, or
            401 "Unauthorized." if it has no ``uuid`` claim.
    """
    try:
        claims = jwt.decode(
            token,
            os.environ["SECRET_KEY"],
            algorithms=[os.environ["ALGORITHM"]]
        )
    except JWTError:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail = "Invalid token."
        )

    subject_uuid: str = claims.get("uuid")
    if subject_uuid is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail = "Unauthorized."
        )
    return TokenData(uuid=subject_uuid)

CRUD処理のAPI

これまで作成したSchemaやロジックを組み合わせてAPIを構築します。
基本的なGET、POST、PATCH（部分更新）、DELETEの各メソッドを作成しています。

コードはこちら
from fastapi import APIRouter, Depends, HTTPException, status, UploadFile, File
from sqlmodel import Session, select
from typing import List

import shutil

from v1.crud.auth import create_tokens

# functions
from ..crud.user import (
    check_user,
    update_user,
    delete_user,
    get_hashed_password,
    users_register,
    get_user_by_uuid
)

# schemas
from ..schema.user import (
    User,
    UserCreate,
    UserOut,
    UserRead,
    UserUpdate,
)
from ..schema.user import Token, User

# db
from ..db.db import get_session

# Every route in this module is mounted under "/user" (i.e. "<base URL>/user/...")
# and grouped under the "User" tag in the generated OpenAPI docs.
router = APIRouter(
    tags=["User"],
    prefix="/user",
)

@router.post("/signup", response_model=Token)
def create_user(userIn: UserCreate, session: Session = Depends(get_session)):
    """Sign up: create a regular, active user and return a fresh token pair.

    Raises:
        HTTPException: 409 if ``check_user`` finds an existing user for the
            submitted username/email.
    """
    if check_user(session, userIn):
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail = "Username or email has been used. Please change to other username or email."
        )

    new_user = User(
        username=userIn.username,
        email=userIn.email,
        is_active=True,
        is_superuser=False,
        hashed_password=get_hashed_password(userIn.password),
        refresh_token="",
    )

    session.add(new_user)
    session.commit()
    session.refresh(new_user)

    token = create_tokens(new_user.uuid)

    # Persist the issued refresh token. get_current_user_from_token rejects any
    # refresh token that differs from User.refresh_token, so otherwise the
    # refresh token returned by sign-up could never be used.
    new_user.refresh_token = token["refresh_token"]
    session.add(new_user)
    session.commit()

    return token

@router.post("/upload-users")
def upload_users(
    session: Session = Depends(get_session),
    file: UploadFile = File(...)
):
    """Bulk-register users from an uploaded CSV/Excel file.

    The upload is copied to a private temporary file, handed to
    ``users_register``, and the file is always deleted afterwards.

    Returns:
        A message dict; when some rows failed it also carries ``failed_users``.

    Raises:
        HTTPException: 400 when ``users_register`` rejects the file with a
            ValueError (e.g. an unsupported extension).
    """
    import os
    import tempfile
    from pathlib import Path

    # Never build a path from the client-supplied filename: "temp_../../x"
    # would write outside the working directory, and concurrent uploads with
    # the same name collided. Keep only the extension, which users_register
    # uses to choose the CSV or Excel reader.
    suffix = Path(file.filename or "").suffix.lower()
    fd, file_location = tempfile.mkstemp(prefix="temp_", suffix=suffix)
    try:
        with os.fdopen(fd, "wb") as file_object:
            shutil.copyfileobj(file.file, file_object)
        try:
            result = users_register(session, file_location)
        except ValueError as err:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=str(err)
            ) from err
    finally:
        os.remove(file_location)

    if result:
        return {
            "message": "Some users failed to register. Please check each item.",
            "failed_users": result
        }
    return {"message": "Users registered successfully."}

@router.get("/all_user", response_model=List[UserOut])
def read_all_user(session: Session = Depends(get_session)):
    """Return every user, serialized through ``UserOut``.

    Declared before ``/{user_uuid}``: FastAPI matches routes in declaration
    order, so "/user/all_user" is not captured by that path parameter.
    """
    statement = select(User)
    return session.exec(statement).all()

@router.get("/{user_uuid}", response_model=UserRead)
def get_user(user_uuid: str, session: Session = Depends(get_session)):
    """Return one user, including its press releases (``UserRead``), by uuid.

    Raises:
        HTTPException: 404 if no user has that uuid.
    """
    found = get_user_by_uuid(session, user_uuid)
    if found:
        return found
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail="User not found."
    )

@router.patch("/edit_user/{user_uuid}", response_model=UserOut)
def edit_user(
    *,
    session: Session = Depends(get_session),
    user: UserUpdate,
    user_uuid: str
):
    """Partially update the user identified by ``user_uuid`` (PATCH semantics).

    Lookup, 404 handling, applying the sent fields and the commit all happen
    in ``update_user``.

    NOTE(review): this route has no authentication dependency, and
    ``UserUpdate`` includes ``is_superuser``, so any client can currently
    edit any user or grant admin rights. Add an auth/permission check.
    """
    updated_user = update_user(session, user_uuid, user)
    return updated_user

@router.delete("/delete_user/{user_uuid}")
def user_delete(
    *,
    session: Session = Depends(get_session),
    user_uuid: str
):
    """Delete the user identified by ``user_uuid``.

    ``delete_user`` already looks the user up and raises 404 "User not found.",
    so the separate existence check here (and its extra query) is dropped.

    NOTE(review): no authentication dependency, so any client can delete any user.
    """
    return delete_user(session, user_uuid)

参考（FastAPI公式ドキュメント・セキュリティ）: https://fastapi.tiangolo.com/ja/tutorial/security/first-steps/

以上がFastAPIでのCRUD処理になります。

Discussion