FastAPI で CURD する。
FastAPI の使い方について学習してみました。とてもシンプルなフレームワークで、公式のチュートリアル も分かりやすかったのですが、最小限の機能を提供しているがゆえに、RDBMS へのアクセスまでやろうとすると、ORM は SQLAlchemy、マイグレーションは Alembic と、その辺の前提知識がないと理解しにくい部分があったので、備忘としてメモをしておきます。
基本的にはチュートリアルの「 SQL (Relational) Databases 」の写経ですが、同じく、「Bigger Applications - Multiple Files 」を参考に、ファイル分割などをしています。
事前準備
PostgreSQL の構築
今回は PostgreSQL を使用します。docker compose で起動することにします。
version: '3'
services:
postgres:
image: postgres:15
volumes:
- ./postgres/data:/var/lib/postgresql/data
- ./postgres/initdb:/docker-entrypoint-initdb.d
ports:
- "${POSTGRES_PORT}:5432"
environment:
- TZ=Asia/Tokyo
- POSTGRES_DB=${POSTGRES_DB}
- POSTGRES_USER=${POSTGRES_USER}
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD}
volumes:
postgres:
driver: local
.env
に設定情報を記載して起動します。
docker compose up -d
Python の仮想環境の構築
Python の仮想環境を構築します。今回は pipenv
を使いましたが、pip
や Docker でもよいと思います。
pipenv --python 3.10
pipenv shell
パッケージのインストール
必要となるパッケージをインストールします。
pipenv install "fastapi[all]"
pipenv install sqlalchemy
pipenv install psycopg2-binary
pipenv install --dev alembic
FastAPI アプリの構築
最終的なディレクトリ構成を示すと、以下のようになります。
../practice-fastapi/backend
├── Pipfile
├── Pipfile.lock
├── alembic.ini
├── app
│ ├── __init__.py
│ ├── config.py
│ ├── crud
│ │ ├── __init__.py
│ │ ├── item.py
│ │ └── user.py
│ ├── database.py
│ ├── dependencies.py
│ ├── main.py
│ ├── models
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── item.py
│ │ └── user.py
│ ├── routers
│ │ ├── __init__.py
│ │ ├── item.py
│ │ └── user.py
│ └── schemas
│ ├── __init__.py
│ ├── item.py
│ └── user.py
└── migrations
├── README
├── env.py
├── script.py.mako
└── versions
最初に app
というディレクトリを作成し、その中に FastAPI アプリを作成していきます。環境変数に POSTGRES_HOST
や POSTGRES_PORT
が設定されていると、その値が設定されます。
.env
から読み込む方法もありますが、コンテナ化するときに、イメージに環境による差異を含まない方がうれしいので、環境変数にしました。
mkdir app
cd app
コンフィグレーション
PostgreSQL への接続情報など、設定情報を取り込みます。
from typing import Any, Dict, Optional
from pydantic import BaseSettings, PostgresDsn, validator
class Settings(BaseSettings):
POSTGRES_HOST: str
POSTGRES_PORT: str
POSTGRES_DB: str
POSTGRES_USER: str
POSTGRES_PASSWORD: str
SQLALCHEMY_DATABASE_URI: Optional[PostgresDsn] = None
@validator("SQLALCHEMY_DATABASE_URI", pre=True)
def assemble_db_connection(cls, v: Optional[str], values: Dict[str, Any]) -> Any:
if isinstance(v, str):
return v
return PostgresDsn.build(
scheme="postgresql",
user=values.get("POSTGRES_USER"),
password=values.get("POSTGRES_PASSWORD"),
host=values.get("POSTGRES_HOST"),
path=f"/{values.get('POSTGRES_DB') or ''}",
)
class Config:
case_sensitive = True
settings = Settings()
データベースとの連携
SQLAlchemy の設定を行います。
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from .config import settings
engine = create_engine(settings.SQLALCHEMY_DATABASE_URI)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
from typing import Generator
from app.database import SessionLocal
def get_db() -> Generator:
db = SessionLocal()
try:
yield db
finally:
db.close()
データベース・モデルの作成
データベースのモデルを定義します。これが ORM により、データベースのテーブルになります。新たにテーブルを追加する場合には、ここにファイルを追加していきます。
最初に Base モデルを作成し、それぞれのモデルは Base を継承します。
from sqlalchemy.ext.declarative import declarative_base
Base = declarative_base()
User や Item のカラムを作成していきます。
from sqlalchemy import Boolean, Column, Integer, String
from sqlalchemy.orm import relationship
from .base import Base
class User(Base):
__tablename__ = "users"
id = Column(Integer, primary_key=True, index=True)
email = Column(String, unique=True, index=True)
hashed_password = Column(String)
is_active = Column(Boolean, default=True)
items = relationship("Item", back_populates="owner")
from sqlalchemy import Column, ForeignKey, Integer, String
from sqlalchemy.orm import relationship
from .base import Base
class Item(Base):
__tablename__ = "items"
id = Column(Integer, primary_key=True, index=True)
title = Column(String, index=True)
description = Column(String, index=True)
owner_id = Column(Integer, ForeignKey("users.id"))
owner = relationship("User", back_populates="items")
from .base import Base # noqa: F401
from .item import Item # noqa: F401
from .user import User # noqa: F401
スキーマ (Pydantic model) の作成
Pydantic のモデル(スキーマ)として、データを作成したり読み込むとき使用使用される共通の属性を定義します。
例えば、User を作成するときは、API からは email
と password
だけが渡され、id
は自動採番されるので、 UserCreate
には email
と password
だけが含まれます。同様に、 User を読み込む場合には、セキュリティ上の理由から password
は含まれません。
先ほど作成したモデルとスキーマとの違いが分かりにくく混乱するのですが、モデルは SQLAlchemy 、スキーマは Pydantic のための定義となります。
from pydantic import BaseModel
from .item import Item
class UserBase(BaseModel):
email: str
class UserCreate(UserBase):
password: str
class User(UserBase):
id: int
is_active: bool
items: list[Item] = []
class Config:
orm_mode = True
from typing import Union
from pydantic import BaseModel
class ItemBase(BaseModel):
title: str
description: Union[str, None] = None
class ItemCreate(ItemBase):
pass
class Item(ItemBase):
id: int
owner_id: int
class Config:
orm_mode = True
from .item import Item, ItemCreate # noqa: F401
from .user import User, UserCreate # noqa: F401
CRUD 関数の作成
REST API が実行されたときに、データベース操作をするための関数群を作成していきます。
from typing import List
from sqlalchemy.orm import Session
from app import models, schemas
def get_user(db: Session, user_id: int) -> models.User:
return db.query(models.User).filter(models.User.id == user_id).first()
def get_user_by_email(db: Session, email: str) -> models.User:
return db.query(models.User).filter(models.User.email == email).first()
def get_users(db: Session, skip: int = 0, limit: int = 100) -> List[models.User]:
return db.query(models.User).offset(skip).limit(limit).all()
def create_user(db: Session, user: schemas.UserCreate) -> models.User:
fake_hashed_password = user.password + "notreallyhashed"
db_user = models.User(email=user.email, hashed_password=fake_hashed_password)
db.add(db_user)
db.commit()
db.refresh(db_user)
return db_user
from typing import List
from sqlalchemy.orm import Session
from app import models, schemas
def get_items(db: Session, skip: int = 0, limit: int = 100) -> List[models.Item]:
return db.query(models.Item).offset(skip).limit(limit).all()
def create_user_item(db: Session, item: schemas.ItemCreate, user_id: int) -> models.Item:
db_item = models.Item(**item.dict(), owner_id=user_id)
db.add(db_item)
db.commit()
db.refresh(db_item)
return db_item
from .item import * # noqa: F401, F403
from .user import * # noqa: F401, F403
ルーターの構築
APIRouter を使用してモジュールのパス操作を作成します。
from typing import Any
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from app import crud, dependencies, schemas
router = APIRouter()
@router.post("/", response_model=schemas.User)
def create_user(
user: schemas.UserCreate,
db: Session = Depends(dependencies.get_db)
) -> Any:
db_user = crud.get_user_by_email(db, email=user.email)
if db_user:
raise HTTPException(status_code=400, detail="Email already registered")
return crud.create_user(db=db, user=user)
@router.get("/", response_model=list[schemas.User])
def read_users(
skip: int = 0,
limit: int = 100,
db: Session = Depends(dependencies.get_db)
) -> Any:
users = crud.get_users(db, skip=skip, limit=limit)
return users
@router.get("/{user_id}", response_model=schemas.User)
def read_user(
user_id: int,
db: Session = Depends(dependencies.get_db)
) -> Any:
db_user = crud.get_user(db, user_id=user_id)
if db_user is None:
raise HTTPException(status_code=404, detail="User not found")
return db_user
@router.post("/{user_id}/items/", response_model=schemas.Item)
def create_item_for_user(
user_id: int,
item: schemas.ItemCreate,
db: Session = Depends(dependencies.get_db)
) -> Any:
return crud.create_user_item(db=db, item=item, user_id=user_id)
from typing import Any
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from app import crud, dependencies, schemas
router = APIRouter()
@router.get("/", response_model=list[schemas.Item])
def read_items(
skip: int = 0,
limit: int = 100,
db: Session = Depends(dependencies.get_db)
) -> Any:
items = crud.get_items(db, skip=skip, limit=limit)
return items
from fastapi import APIRouter
from . import item, user
api_router = APIRouter()
api_router.include_router(user.router, prefix="/users", tags=["users"])
api_router.include_router(item.router, prefix="/items", tags=["items"])
メイン処理
メイン処理を実装します。
from fastapi import FastAPI
from app import models
from app.database import engine
from app.routers import api_router
# models.Base.metadata.create_all(bind=engine)
app = FastAPI()
app.include_router(api_router, prefix="/api")
DBマイグレーション
Alembic と使ってデータベースのマイグレーションを行います。
最初に Alembic をインストールします。
pipenv install alembic --dev
Alembic の初期化を行います。
alembic init migrations
以下のようなディレクトリ・ファイルが作成されます。
├── alembic.ini
├── migrations
│ ├── README
│ ├── env.py
│ ├── script.py.mako
│ └── versions
alembic.ini
の sqlalchemy.url
を実際のDBの設定に書き換えます。
# the output encoding used when revision files
# are written from script.py.mako
# output_encoding = utf-8
#sqlalchemy.url = driver://user:pass@localhost/dbname
sqlalchemy.url = postgresql://postgres:changeme@127.0.0.1:5432/dbname
migrations/env.py
の target_metadata
に models/base.py
で定義した Base
クラスを指定します。
# add your model's MetaData object here
# for 'autogenerate' support
# from myapp import mymodel
# target_metadata = mymodel.Base.metadata
from app.models import Base # noqa
target_metadata = Base.metadata
マイグレーションファイルを自動生成します。
$ alembic revision --autogenerate -m "create tables"
INFO [alembic.runtime.migration] Context impl PostgresqlImpl.
INFO [alembic.runtime.migration] Will assume transactional DDL.
INFO [alembic.autogenerate.compare] Detected added table 'users'
INFO [alembic.autogenerate.compare] Detected added index 'ix_users_email' on '['email']'
INFO [alembic.autogenerate.compare] Detected added index 'ix_users_id' on '['id']'
INFO [alembic.autogenerate.compare] Detected added table 'items'
INFO [alembic.autogenerate.compare] Detected added index 'ix_items_description' on '['description']'
INFO [alembic.autogenerate.compare] Detected added index 'ix_items_id' on '['id']'
INFO [alembic.autogenerate.compare] Detected added index 'ix_items_title' on '['title']'
Generating /home/shasegawa/works/20230507/migrations/versions/ffa4a4a17ff7_create_tables.py ... done
migrations/versions
以下にマイグレーションファイルが作成されます。完全に差分が検出できるわけではないようなので、手動でのファイル修正が必要なこともあるようです。
"""create tables
Revision ID: ffa4a4a17ff7
Revises:
Create Date: 2023-05-07 16:57:01.644836
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'ffa4a4a17ff7'
down_revision = None
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('users',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('email', sa.String(), nullable=True),
sa.Column('hashed_password', sa.String(), nullable=True),
sa.Column('is_active', sa.Boolean(), nullable=True),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_users_email'), 'users', ['email'], unique=True)
op.create_index(op.f('ix_users_id'), 'users', ['id'], unique=False)
op.create_table('items',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('title', sa.String(), nullable=True),
sa.Column('description', sa.String(), nullable=True),
sa.Column('owner_id', sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(['owner_id'], ['users.id'], ),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_items_description'), 'items', ['description'], unique=False)
op.create_index(op.f('ix_items_id'), 'items', ['id'], unique=False)
op.create_index(op.f('ix_items_title'), 'items', ['title'], unique=False)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f('ix_items_title'), table_name='items')
op.drop_index(op.f('ix_items_id'), table_name='items')
op.drop_index(op.f('ix_items_description'), table_name='items')
op.drop_table('items')
op.drop_index(op.f('ix_users_id'), table_name='users')
op.drop_index(op.f('ix_users_email'), table_name='users')
op.drop_table('users')
# ### end Alembic commands ###
マイグレーションをデータベースに反映させます。
$ alembic upgrade head
INFO [alembic.runtime.migration] Context impl PostgresqlImpl.
INFO [alembic.runtime.migration] Will assume transactional DDL.
INFO [alembic.runtime.migration] Running upgrade -> ffa4a4a17ff7, create tables
データベースに接続して確認すると、3つのテーブルが作成されたことが確認できます。
alembic_version
に適用されたマイグレーションのハッシュ値が格納されて行きます。
アプリの起動
環境変数を設定し、uvicorn でアプリを起動します。
bash -c 'export $(cat ../.env | grep -v ^#) && uvicorn app.main:app --reload'
INFO: Will watch for changes in these directories: ['****']
INFO: Uvicorn running on http://127.0.0.1:8000 (Press CTRL+C to quit)
INFO: Started reloader process [28991] using WatchFiles
INFO: Started server process [28999]
INFO: Waiting for application startup.
INFO: Application startup complete.
http://127.0.0.1:8001/docs にアクセスし、REST API の実行を行います。
Discussion