🏎️
FastAPIでCRUD処理
Schema定義
userデータにおけるCRUD処理に必要なschemaを定義します。
コードはこちら
from sqlmodel import Relationship, SQLModel, Field, AutoString
from typing import Optional, TYPE_CHECKING
from pydantic import EmailStr
import uuid
from datetime import datetime
from ..schema.press import PressReleaseReadWithUser
if TYPE_CHECKING:
from .press import PressRelease
class BaseUser(SQLModel):
    """Common user fields, inherited by User, UserOut, UserRead and UserUpdate."""
    # Unique login name.
    # NOTE(review): declared nullable=False but defaults to None, so non-table
    # schemas inheriting it may be constructed without a username — confirm intent.
    username: str = Field(nullable = False, default = None, unique = True)
    # Unique, indexed e-mail address, validated as EmailStr and stored via AutoString.
    email: EmailStr =Field(unique = True, index = True, sa_type=AutoString)
    # True unless explicitly set otherwise.
    is_active: bool = Field(default = True)
    # False unless explicitly set otherwise.
    is_superuser: bool = Field(default = False)
class User(BaseUser, table=True):
    """Database table model for application users."""
    id: Optional[int] = Field(default=None, primary_key=True)
    # Public identifier used in the /user/{user_uuid} routes and in JWT payloads.
    # The column is typed str, so the factory returns str(uuid4()) rather than a
    # uuid.UUID object (which is neither a str nor JSON-serializable).
    uuid: str = Field(default_factory=lambda: str(uuid.uuid4()), nullable=False)
    hashed_password: str = Field(nullable=False)
    # default_factory, not default=datetime.now(): a call in the default is
    # evaluated once at import time and would stamp every row with that moment.
    created_at: datetime = Field(default_factory=datetime.now, nullable=False)
    updated_at: datetime = Field(
        default_factory=datetime.now,
        nullable=False,
        sa_column_kwargs={'onupdate': datetime.now},
    )
    # Last refresh token stored for this user; None when none has been stored.
    refresh_token: Optional[str] = Field(default=None, nullable=True)
    press_releases: list["PressRelease"] | None = Relationship(back_populates="user", cascade_delete=True)
class UserOut(BaseUser):
    """Response schema exposing the stored user columns (password hash excluded)."""
    id: int
    uuid: str
    created_at: datetime
    updated_at: datetime
    # NOTE(review): this returns the stored refresh token to API clients (this
    # schema is the response_model of /user/all_user and /user/edit_user) —
    # confirm that exposing it is intended.
    refresh_token: str
class UserLogin(SQLModel):
    """Login request body: e-mail address plus plain-text password."""
    # Plain str here, unlike UserCreate.email which is validated as EmailStr.
    email: str
    password: str
class UserCreate(SQLModel):
    """Sign-up request body; the /signup route hashes ``password`` before storing it."""
    username: str
    email: EmailStr
    password: str
class UserRead(BaseUser):
    """Single-user response schema that also embeds the user's press releases."""
    id: int
    uuid: str
    created_at: datetime
    updated_at: datetime
    # Empty list when the user has no press releases.
    press_releases: list["PressReleaseReadWithUser"] = []
class UserUpdate(BaseUser):
    """PATCH body for a user; every field is optional.

    update_user applies only the fields the client actually sent
    (``model_dump(exclude_unset=True)``), so omitted fields keep their stored values.
    """
    username: Optional[str] = None
    email: Optional[EmailStr] = None
    # Explicit ``= None``: in pydantic v2 an Optional annotation without a default
    # is still a *required* field, which would force every PATCH to send it.
    is_active: Optional[bool] = None
    is_superuser: Optional[bool] = None
【ログイン処理】
- ログイン処理(UserLogin)は、emailとpasswordで行います
- サインアップ処理(UserCreate)は、email、passwordに加えて、usernameを必要としています
【データの呼び出し】
- BaseUserを継承し、登録後に設定されるid、uuid、created_atなどの項目を取得します
【更新処理】
- 更新する値は、primary_keyにしているidと個別のidとして割り当てているuuidを除く項目にしています
CRUD処理に必要なロジック
CRUD処理のロジックをまとめています。
まずは、Userに関するロジックです。
コードはこちら
import os
from fastapi import HTTPException, status
from typing import List
# schema
from ..schema.user import User, UserOut
# sqlmodel
from sqlmodel import select
# uuid
import uuid as uuid_pkg
# settings
from ..core.config import settings
# pandas
import pandas as pd
def check_user(session, user):
    """Return an existing user that already uses ``user.username`` OR ``user.email``.

    Args:
        session: database session.
        user: object with ``username`` and ``email`` attributes (e.g. UserCreate).

    Returns:
        The first matching User, or None when both values are still free.
    """
    # OR, not AND: username and email are each unique columns, so reusing either
    # one alone must be caught here rather than failing later at the DB constraint.
    return session.exec(
        select(User).where(
            (User.username == user.username) | (User.email == user.email)
        )
    ).first()
def get_hashed_password(password):
    """Hash a plain-text password with the application's ``settings.pwd_context``.

    NOTE(review): pwd_context is presumably a passlib CryptContext — confirm in core.config.
    """
    pwd_context = settings.pwd_context
    hashed = pwd_context.hash(password)
    return hashed
def verfy_password(password, hashed_password):
    """Check a plain-text password against a stored hash via ``settings.pwd_context``.

    Returns whatever ``settings.pwd_context.verify`` returns (a match flag).
    """
    context = settings.pwd_context
    return context.verify(password, hashed_password)
def get_all_user(session):
    """Return every User row.

    Args:
        session: database session.

    Returns:
        A list of User instances (empty when the table is empty).
    """
    return session.exec(select(User)).all()
def update_user(session, uuid, user):
    """Apply a partial update to the user identified by ``uuid``.

    Only the fields explicitly set on ``user`` (``exclude_unset=True``) are copied
    onto the stored row, which is then committed and refreshed.

    Raises:
        HTTPException: 404 if no user has this uuid.
    """
    target = session.exec(select(User).where(User.uuid == uuid)).first()
    if not target:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="User not found.",
        )

    changes = user.model_dump(exclude_unset=True)
    for field_name in changes:
        setattr(target, field_name, changes[field_name])

    session.add(target)
    session.commit()
    session.refresh(target)
    return target
def delete_user(session, uuid):
    """Delete the user identified by ``uuid`` and commit.

    Returns:
        A confirmation message dict.

    Raises:
        HTTPException: 404 if no user has this uuid.
    """
    doomed = get_user_by_uuid(session, uuid)
    if doomed:
        session.delete(doomed)
        session.commit()
        return {"message": "User was deleted."}
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail="User not found.",
    )
def get_user_by_uuid(session, uuid: str):
    """Return the User whose ``uuid`` column equals ``uuid``, or None."""
    # session.exec(select(...)) is SQLModel's query API (as used in update_user);
    # Session.query() is the legacy SQLAlchemy interface.
    return session.exec(select(User).where(User.uuid == uuid)).first()
def get_user_by_email(session, email: str):
    """Return the User whose ``email`` column equals ``email``, or None."""
    # Same select/exec style as get_user_by_uuid and update_user.
    return session.exec(select(User).where(User.email == email)).first()
def _load_user_frame(file_path: str) -> pd.DataFrame:
    """Read a user list from a CSV or XLSX file, choosing the parser by extension."""
    # Case-insensitive so "USERS.CSV" / "users.XLSX" are accepted as well.
    lowered = file_path.lower()
    if lowered.endswith(".csv"):
        return pd.read_csv(file_path)
    if lowered.endswith(".xlsx"):
        return pd.read_excel(file_path)
    raise ValueError(
        "Invalid file format. Please provide a CSV or EXCEL file."
    )


def _parse_optional_bool(value):
    """Convert a spreadsheet cell to bool, or None when the cell is blank/absent.

    Strings are matched against true/false words rather than passed to bool(),
    because bool("False") would be True.
    """
    if value is None or pd.isna(value):
        return None
    if isinstance(value, str):
        text = value.strip().lower()
        if text in ("true", "1", "yes"):
            return True
        if text in ("false", "0", "no"):
            return False
        raise ValueError(f"Invalid boolean value: {value!r}")
    return bool(value)


def users_register(session, file_path: str):
    """Create or update users in bulk from a CSV/XLSX file.

    Columns: username, email, password (required); is_active, is_superuser
    (optional). A row whose username or email matches an existing user updates
    that user; otherwise a new user is created. A blank is_active/is_superuser
    cell leaves an existing user's value untouched, and falls back to
    True/False for a new user.

    Args:
        session: database session used for lookups, inserts and the final commit.
        file_path: path to a ``.csv`` or ``.xlsx`` file (extension is case-insensitive).

    Returns:
        A list of ``{"username": ..., "error": ...}`` dicts, one per rejected row;
        empty when every row was accepted.

    Raises:
        ValueError: if the extension is unsupported or a required column is missing.
    """
    df = _load_user_frame(file_path)

    # Validate the header once instead of failing with KeyError on the first row.
    missing = [col for col in ("username", "email", "password") if col not in df.columns]
    if missing:
        raise ValueError(f"Missing required column(s): {', '.join(missing)}")

    failed_users = []
    for _, row in df.iterrows():
        username = row["username"]
        email = row["email"]
        password = row["password"]
        if pd.isna(username) or pd.isna(email) or pd.isna(password):
            failed_users.append({
                # NaN is not valid JSON, so a blank username is reported as None.
                "username": None if pd.isna(username) else username,
                "error": "Username, email, password are required fields."
            })
            continue
        try:
            # row.get returns None when the optional column is absent altogether.
            is_active = _parse_optional_bool(row.get("is_active"))
            is_superuser = _parse_optional_bool(row.get("is_superuser"))
            hashed_password = get_hashed_password(password)
            existing_user = session.exec(
                select(User).where((User.username == username) | (User.email == email))
            ).first()
            if existing_user:
                # username/email/password are guaranteed non-blank by the check above.
                existing_user.username = username
                existing_user.email = email
                existing_user.hashed_password = hashed_password
                if is_active is not None:
                    existing_user.is_active = is_active
                if is_superuser is not None:
                    existing_user.is_superuser = is_superuser
            else:
                session.add(User(
                    uuid=str(uuid_pkg.uuid4()),
                    username=username,
                    email=email,
                    hashed_password=hashed_password,
                    # Blank cells fall back to the BaseUser field defaults.
                    is_active=True if is_active is None else is_active,
                    is_superuser=False if is_superuser is None else is_superuser,
                ))
        except Exception as e:
            # Best-effort import: record the failure and continue with the next row.
            failed_users.append({
                "username": username,
                "error": str(e)
            })
    # NOTE(review): all rows are committed together; a constraint violation raised
    # here (or during an autoflush inside the loop) is not attributed to a row.
    session.commit()
    return failed_users
次に、認証に関するロジックです。
コードはこちら
import os
from fastapi import Depends, HTTPException, status
from datetime import datetime, timedelta, timezone
# auth
from jose import jwt, JWTError
# sqlmodel
from sqlmodel import Session
# db
from ..db.db import get_session
# settings
from ..core.config import settings
# functions
from .user import verfy_password, get_user_by_uuid, get_user_by_email, get_all_user
# schemas
from ..schema.user import TokenData
# .env
from dotenv import load_dotenv
load_dotenv()
def authenticate(session, email, password):
    """Return the user identified by ``email`` if ``password`` matches its hash.

    Raises:
        HTTPException: 401 when no user has this email or the password does not match.
    """
    user = get_user_by_email(session, email)
    # An unknown email is treated exactly like a wrong password: it must not
    # crash on None, and the response does not reveal which emails exist.
    if user is None or not verfy_password(password, user.hashed_password):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Can't authenticate by your password."
        )
    return user
def create_tokens(uuid: str):
    """Issue a signed access/refresh JWT pair for the user ``uuid``.

    Lifetimes come from the ACCESS_TOKEN_EXPIRE_DAYS / REFRESH_TOKEN_EXPIRE_DAYS
    environment variables (whole days); both tokens are signed with SECRET_KEY
    using ALGORITHM.

    Returns:
        ``{"access_token": ..., "refresh_token": ..., "token_type": "bearer"}``.
    """
    def build_claims(kind, lifetime_env):
        expires_at = datetime.now(timezone.utc) + timedelta(days=int(os.environ[lifetime_env]))
        return {"token_type": kind, "exp": expires_at, "uuid": uuid}

    claim_sets = {
        "access_token": build_claims("access_token", "ACCESS_TOKEN_EXPIRE_DAYS"),
        "refresh_token": build_claims("refresh_token", "REFRESH_TOKEN_EXPIRE_DAYS"),
    }
    tokens = {
        name: jwt.encode(claims, os.environ["SECRET_KEY"], algorithm=os.environ["ALGORITHM"])
        for name, claims in claim_sets.items()
    }
    tokens["token_type"] = "bearer"
    return tokens
def get_current_user_from_token(session, token: str, token_type: str):
    """Decode ``token`` and return the user it identifies.

    Args:
        session: database session used to load the user.
        token: encoded JWT presented by the client.
        token_type: expected ``token_type`` claim ("access_token" or "refresh_token").

    Returns:
        The User whose uuid is carried in the token's ``uuid`` claim.

    Raises:
        HTTPException: 401 if jwt.decode rejects the token (JWTError), the
            ``token_type`` claim does not match, the ``uuid`` claim is missing,
            no user has that uuid, or — for refresh tokens — the token differs
            from the one stored on the user.
    """
    try:
        payload = jwt.decode(
            token,
            os.environ["SECRET_KEY"],
            algorithms=[os.environ["ALGORITHM"]]
        )
    except JWTError:
        # A bad token is an authentication failure, not a server error.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token."
        )
    if payload.get("token_type") != token_type:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="token_type does not match."
        )
    user_uuid = payload.get("uuid")
    current_user = get_user_by_uuid(session, user_uuid) if user_uuid else None
    if current_user is None:
        # Never hand None to route handlers as the "current user".
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Unauthorized."
        )
    if token_type == "refresh_token" and current_user.refresh_token != token:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="refresh_token does not match."
        )
    return current_user
def get_current_user(token: str = Depends(settings.oauth2_schema), session: Session = Depends(get_session)):
    """FastAPI dependency returning the user authenticated by the bearer access token.

    Propagates the HTTPException raised by get_current_user_from_token.
    """
    # Single call: the token is decoded and the user looked up once per request.
    return get_current_user_from_token(session, token, "access_token")
def get_user_from_refresh_token(token: str = Depends(settings.oauth2_schema), session: Session = Depends(get_session)):
    """FastAPI dependency resolving the user from a bearer *refresh* token.

    Delegates to get_current_user_from_token, which also checks that the token
    matches the refresh token stored on the user.
    """
    refreshed_user = get_current_user_from_token(session, token, token_type="refresh_token")
    return refreshed_user
def verify_token(token: str):
    """Validate a JWT and wrap its ``uuid`` claim in TokenData.

    Raises:
        HTTPException: 401 "Invalid token." if jwt.decode raises JWTError, or
            401 "Unauthorized." if the payload carries no ``uuid``.
    """
    try:
        claims = jwt.decode(
            token,
            os.environ["SECRET_KEY"],
            algorithms=[os.environ["ALGORITHM"]]
        )
    except JWTError:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail = "Invalid token."
        )

    subject_uuid: str = claims.get("uuid")
    if subject_uuid is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail = "Unauthorized."
        )
    return TokenData(uuid = subject_uuid)
CRUD処理のAPI
これまで作成したSchemaやロジックを組み合わせてAPIを構築します。
基本的なGET、POST、PATCH、DELETEの各メソッドを作成しています。
コードはこちら
from fastapi import APIRouter, Depends, HTTPException, status, UploadFile, File
from sqlmodel import Session, select
from typing import List
import shutil
from v1.crud.auth import create_tokens
# functions
from ..crud.user import (
check_user,
update_user,
delete_user,
get_hashed_password,
users_register,
get_user_by_uuid
)
# schemas
from ..schema.user import (
User,
UserCreate,
UserOut,
UserRead,
UserUpdate,
)
from ..schema.user import Token, User
# db
from ..db.db import get_session
# Every route below is mounted under the "/user" prefix
# (e.g. "<base URL>/user/signup") and tagged "User".
router = APIRouter(prefix="/user", tags=["User"])
@router.post("/signup", response_model=Token)
def create_user(userIn: UserCreate, session: Session = Depends(get_session)):
    """Register a new user and return a fresh access/refresh token pair.

    Raises:
        HTTPException: 409 when check_user finds a user already holding the
            username or email.
    """
    if check_user(session, userIn):
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail = "Username or email has been used. Please change to other username or email."
        )
    new_user = User(
        username=userIn.username,
        email = userIn.email,
        is_active=True,
        is_superuser=False,
        hashed_password=get_hashed_password(userIn.password),
        refresh_token=""
    )
    session.add(new_user)
    session.commit()
    session.refresh(new_user)
    # str(): the model's uuid default may yield a uuid.UUID object, which cannot
    # be serialized into the JWT payload.
    token = create_tokens(str(new_user.uuid))
    # Store the issued refresh token: get_current_user_from_token compares a
    # presented refresh token with this column, so leaving "" makes refresh fail.
    new_user.refresh_token = token["refresh_token"]
    session.add(new_user)
    session.commit()
    return token
@router.post("/upload-users")
def upload_users(
    session: Session = Depends(get_session),
    file: UploadFile = File(...)
):
    """Bulk-register users from an uploaded CSV/XLSX file via users_register.

    Returns:
        A success message, or a message plus the per-row ``failed_users`` list.

    Raises:
        HTTPException: 400 when users_register rejects the file with ValueError
            (e.g. unsupported format).
    """
    import os
    import tempfile

    # Only the extension of the client-supplied name is kept (users_register picks
    # the parser from it); a unique temp path avoids collisions between concurrent
    # uploads and keeps client path components away from the filesystem.
    suffix = os.path.splitext(file.filename or "")[1].lower()
    fd, file_location = tempfile.mkstemp(suffix=suffix)
    try:
        with os.fdopen(fd, "wb") as file_object:
            shutil.copyfileobj(file.file, file_object)
        result = users_register(session, file_location)
    except ValueError as exc:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(exc)
        ) from exc
    finally:
        # The upload is only needed while users_register reads it.
        os.remove(file_location)

    if result:
        return {
            "message": "Some users failed to register. Please check each item.",
            "failed_users": result
        }
    return {"message": "Users registered successfully."}
@router.get("/all_user", response_model=List[UserOut])
def read_all_user(session: Session = Depends(get_session)):
    """Return every stored user, serialized through UserOut.

    NOTE(review): this route declares no authentication dependency and UserOut
    includes refresh_token — confirm both are intended.
    """
    statement = select(User)
    results = session.exec(statement)
    return results.all()
@router.get("/{user_uuid}", response_model=UserRead)
def get_user(user_uuid: str, session: Session = Depends(get_session)):
    """Fetch one user (with press releases, via UserRead) by its uuid.

    Raises:
        HTTPException: 404 if no user has this uuid.
    """
    found = get_user_by_uuid(session, user_uuid)
    if found:
        return found
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail="User not found."
    )
@router.patch("/edit_user/{user_uuid}", response_model=UserOut)
def edit_user(
    *,
    session: Session = Depends(get_session),
    user: UserUpdate,
    user_uuid: str,
):
    """Partially update the user identified by ``user_uuid`` (see update_user)."""
    updated = update_user(session, user_uuid, user)
    return updated
@router.delete("/delete_user/{user_uuid}")
def user_delete(
    *,
    session: Session = Depends(get_session),
    user_uuid: str
):
    """Delete the user identified by ``user_uuid``.

    delete_user already looks the user up and raises HTTPException 404 when it
    does not exist, so no separate pre-check (and extra query) is done here.
    """
    return delete_user(session, user_uuid)
以上がFastAPIでのCRUD処理になります。
Discussion