Pydanticで始めるPythonのバリデーションとシリアライゼーション
はじめに
Pydanticを使用することで、Pythonコードでのデータバリデーションとデータシリアライゼーションを簡単かつ効率的に行うことができます。
この記事では、Pydanticの基本的な使い方から、より高度なバリデーションとシリアライゼーションまで幅広く紹介します。また、簡易的なものですが他のバリデーションライブラリとの速度比較も行っています。
Pydanticとは
Pydanticは、Pythonのバリデーションライブラリです。以下のような特徴を持ちます。
- 型アノテーションをつけるだけでバリデーションとシリアライゼーションを実現できる
- 独自のバリデーションやシリアライゼーションを柔軟に定義することができる
- Pydantic V2はコアロジックがRustで実装されていて高速に動作する
dataclasses+jsonと比較
dataclasses+jsonを使ったコードとPydanticを使ったコードを比較してみましょう。
まず、dataclassesとjsonを使用して、バリデーションとJSON文字列への変換を実装します。
以下のコードでは、次のようなバリデーションを実装しています。
-
name
が4~16文字のstr型であるか -
age
が18~99のint型であるか -
date
がdate型であるか
import json
from dataclasses import asdict, dataclass
from datetime import date
from typing import Any
class DateEncoder(json.JSONEncoder):
def default(self, obj: Any) -> Any:
if isinstance(obj, date):
return obj.strftime("%Y-%m-%d") # date型の場合は%Y-%m-%d形式の文字列に変換する
return super().default(obj)
@dataclass
class User:
name: str
age: int
birthday: date
def __post_init__(self):
if not isinstance(self.name, str):
raise TypeError("name must be str")
if len(self.name) < 4 or len(self.name) > 16:
raise ValueError("name must be between 4 and 16 chars")
if not isinstance(self.age, int):
raise TypeError("age must be int")
if self.age < 18 or self.age > 99:
raise ValueError("age must be between 18 and 99")
if not isinstance(self.birthday, date):
raise TypeError("birthday must be date")
def to_json(self) -> str:
return json.dumps(asdict(self), cls=DateEncoder, indent=2)
if __name__ == "__main__":
user = User(name="John", age=18, birthday=date(2000, 1, 1))
print(user.to_json())
to_jsonの結果
{
"name": "John",
"age": 18,
"birthday": "2000-01-01"
}
次にPydanticを使ったコードです。dataclasses+jsonのコードと同じことをしていますが、シンプルな実装になっているのがわかると思います。
from datetime import date
from pydantic import BaseModel, Field
class User(BaseModel):
name: str = Field(..., min_length=4, max_length=16) # 4~16文字のstr型
age: int = Field(..., ge=18, le=99) # 18~99のint型
birthday: date
if __name__ == "__main__":
user = User(name="John", age=18, birthday=date(2000, 1, 1))
print(user.model_dump_json(indent=2))
バリデーション
基本的なバリデーション
基本的なバリデーションには、型アノテーションと Field
を使用します。
型アノテーションによって型がチェックされ、Field
によって数値の範囲や文字列の長さなど簡単な制限をつけることができます。
from datetime import date
from pydantic import BaseModel, Field
class User(BaseModel):
name: str = Field(..., min_length=4, max_length=16) # 4~16文字のstr型
age: int = Field(..., ge=18, le=99) # 18~99のint型
birthday: date
カスタムバリデーション
フィールドごとのカスタムバリデーションを定義するには field_validator()
を使います。
対象となるフィールド名を field_validator()
に渡し、クラスメソッドとしてバリデーションロジックを定義します。
以下のコードでは、name
がアルファベットもしくは数字のみで構成されている文字列であるかをvalidate_alphanumeric()
メソッドで確認しています。
from datetime import date
from pydantic import BaseModel, Field, field_validator
class User(BaseModel):
name: str = Field(..., min_length=4, max_length=16)
age: int = Field(..., ge=18, le=99)
birthday: date
@field_validator("name")
@classmethod
def validate_alphanumeric(cls, v: str) -> str:
"""アルファベットもしくは数字のみで構成された文字列であるかチェックする"""
if not v.isalnum():
raise ValueError("must be alphanumeric")
return v
モデル全体へのカスタムバリデーションを定義するには model_validator()
を使います。
model_validator()
は以下の2つのモードを選ぶことができます。
- before: インスタンス生成前に入力のdictをバリデーション
- after: インスタンス生成後にインスタンスをバリデーション
以下のコードでは、入力のdictの中にトークンが存在するかと2つのパスワードが一致しているかを確認しています。
from typing import Self
from pydantic import BaseModel, ValidationError, model_validator
class User(BaseModel):
name: str
password1: str
password2: str
@model_validator(mode="before")
@classmethod
def validate_secret(cls, d: dict) -> dict:
"""入力のdictの中にトークンが存在するかをチェックする"""
if "token" not in d:
raise ValueError("token is required")
return d
@model_validator(mode="after")
def validate_passwords(self) -> Self:
"""2つのパスワードが一致しているかをチェックする"""
if self.password1 != self.password2:
raise ValueError("passwords do not match")
return self
シリアライゼーション
PythonオブジェクトをJSON文字列などの他のデータ形式に変換することをシリアライゼーションと言います。Pydanticはシンプルにシリアライゼーションができるようになっています。
基本的なシリアライゼーション
BaseModel
から継承した model_dump_json()
メソッドを呼ぶことで、オブジェクトのシリアライゼーションを行うことができます。ネストしたモデルは再帰的に変換されます。
以下のコードでは、Task
インスタンスを複数持つ User
インスタンスをJSON文字列に変換しています。
from datetime import date
from pydantic import BaseModel
class Task(BaseModel):
name: str
due_date: date
class User(BaseModel):
name: str
tasks: list[Task]
if __name__ == "__main__":
user = User(
name="John",
tasks=[
Task(name="task1", due_date=date(2023, 10, 26)),
Task(name="task2", due_date=date(2023, 10, 27)),
],
)
print(user.model_dump_json(indent=2))
model_dump_jsonの結果
{
"name": "John",
"tasks": [
{
"name": "task1",
"due_date": "2023-10-26"
},
{
"name": "task2",
"due_date": "2023-10-27"
}
]
}
カスタムシリアライゼーション
フィールドごとのカスタムシリアライゼーションを定義するには field_serializer()
を使います。対象となるフィールド名を field_serializer()
に渡し、シリアライゼーションロジックを定義します。
以下のコードでは、serialize_date()
メソッドにより due_date
を "October 26 2023"
のような形式に変換しています。
from datetime import date
from pydantic import BaseModel, field_serializer
class Task(BaseModel):
name: str
due_date: date
@field_serializer("due_date")
def serialize_date(self, d: date) -> str:
return d.strftime("%B %d %Y")
if __name__ == "__main__":
task = Task(name="task1", due_date=date(2023, 10, 26))
print(task.model_dump_json(indent=2))
model_dump_jsonの結果
{
"name": "task1",
"due_date": "October 26 2023"
}
モデル全体のシリアライゼーションを定義するには model_serializer()
を使います。
以下のコードでは、due_date
を year, month, day
にわけています。
from datetime import date
from typing import Any
from pydantic import BaseModel, model_serializer
class Task(BaseModel):
name: str
due_date: date
@model_serializer
def serialize_date(self) -> dict[str, Any]:
return {
"name": self.name,
"year": self.due_date.year,
"month": self.due_date.month,
"day": self.due_date.day,
}
if __name__ == "__main__":
task = Task(name="task1", due_date=date(2023, 10, 26))
print(task.model_dump_json(indent=2))
model_dump_jsonの結果
{
"name": "task1",
"year": 2023,
"month": 10,
"day": 26
}
非プリミティブな型のバリデーションとシリアライゼーション
Pydanticは、非プリミティブな型を直接型アノテーションとして使用することはできません。以下のコードは PydanticSchemaGenerationError
が発生します。
import numpy as np
from pydantic import BaseModel
class Image(BaseModel):
x: np.ndarray # NG
if __name__ == "__main__":
image = Image(x=np.array([[1, 2], [3, 4]])) # PydanticSchemaGenerationError
非プリミティブな型をPydanticで扱えるようにするために、Annotated
と PlainValidator
を使用してカスタム型を定義します。また、PlainSerializer
を使用して numpy.ndarray
のシリアライゼーションロジックを定義します。
以下のコードは型が numpy.ndarray
になっているかバリデーションし、numpy.ndarray
をリストにシリアライゼーションしています。
from typing import Any
import numpy as np
from pydantic import BaseModel, PlainSerializer, PlainValidator
from typing_extensions import Annotated
def validate_ndarray(x: np.ndarray) -> np.ndarray:
"""numpy.ndarray型かチェックする"""
if not isinstance(x, np.ndarray):
raise TypeError("numpy.ndarray required")
return x
def serialize_ndarray(x: np.ndarray) -> list[Any]:
"""リストに変換する"""
return x.tolist()
# バリデーションとシリアライゼーションの方法を実装したカスタム型を定義
NdArray = Annotated[
np.ndarray,
PlainValidator(validate_ndarray),
PlainSerializer(serialize_ndarray),
]
class Image(BaseModel):
x: NdArray # カスタム型を型アノテーションとして使う
if __name__ == "__main__":
image = Image(x=np.arange(4))
print(image.model_dump_json(indent=2))
model_dump_jsonの結果
{
"x": [
0,
1,
2,
3
]
}
型変換
Pydanticは型アノテーションにより型をチェックし、型が一致しない場合でも暗黙的に型変換を行います。例えば UUID
や datetime
を期待するフィールドにその形式の文字列を渡すと変換してくれます。型変換ができない場合はバリデーションエラーが発生します。
以下のコードでは、UUID, datetime
を期待するフィールドに文字列を渡していますが、適切に変換されます。
from datetime import datetime
from uuid import UUID
from pydantic import BaseModel
class Event(BaseModel):
id: UUID
dt: datetime
if __name__ == "__main__":
# モデル定義の型に変換される
event = Event(
id="19a63218-cca8-489e-89c6-b283a9ac4118", # UUID形式の文字列
dt="2021-08-01T00:00:00+09:00", # datetime形式の文字列
)
print(event)
printした結果
id=UUID('19a63218-cca8-489e-89c6-b283a9ac4118') dt=datetime.datetime(2021, 8, 1, 0, 0, tzinfo=TzInfo(+09:00))
暗黙的な型変換を許可したくない場合はstrictモードを使うことができます。strictモードにすると型変換は行われず、型の不一致によるバリデーションエラーが発生します。モデル全体をstrictモードにするには ConfigDict
を使います。
from datetime import datetime
from uuid import UUID
from pydantic import BaseModel, ConfigDict
class Event(BaseModel):
model_config = ConfigDict(strict=True) # モデル全体でstrcitモードにする
id: UUID
dt: datetime
if __name__ == "__main__":
# 型が一致しないのでValidationErrorが発生する
event = Event(
id="19a63218-cca8-489e-89c6-b283a9ac4118",
dt="2021-08-01T00:00:00+09:00",
)
エラー内容
pydantic_core._pydantic_core.ValidationError: 2 validation errors for Event
id
Input should be an instance of UUID [type=is_instance_of, input_value='19a63218-cca8-489e-89c6-b283a9ac4118', input_type=str]
For further information visit https://errors.pydantic.dev/2.4/v/is_instance_of
dt
Input should be a valid datetime [type=datetime_type, input_value='2021-08-01T00:00:00+09:00', input_type=str]
For further information visit https://errors.pydantic.dev/2.4/v/datetime_type
また、Pydanticの StrictInt
のような型を使用するか、Field
を使用することでフィールドごとにstrictモードにすることができます。
from pydantic import BaseModel, Field, StrictInt
class User(BaseModel):
name: str = Field(..., strict=True) # strictなstr
age: StrictInt # strictなint
パフォーマンス比較
Pydantic, marshmallow, Cerberusのバリデーション速度を比較してみましょう。
以下のバリデーションで検証します。
-
Task
モデル-
name
は4~16文字のstr型であるか -
due_date
はdate型であるか
-
-
User
モデル-
name
は4~16文字のstr型であるか -
age
は18~99のint型であるか -
tasks
はTask
モデルのリストであるか
-
検証環境
- プロセッサ: 2.4GHz 8コア Intel Core i9
- メモリ: 32GB DDR4 2667MHz
- オペレーティングシステム: macOS Ventura 13.5.2
以下のテーブルに1000回の合計実行時間をまとめました。3つのライブラリの中だとPydancitが最も速いのがわかります。
Library | Time (1000 loops) |
---|---|
Pydantic 2.4.2 | 0.0043 sec |
marshmallow 3.20.1 | 0.1646 sec |
Cerberus 1.3.5 | 0.5930 sec |
Pydanticのバリデーション速度計測コード
import timeit
from datetime import date
from pydantic import BaseModel, Field
class Task(BaseModel):
name: str = Field(..., min_length=4, max_length=16)
due_date: date
class User(BaseModel):
name: str = Field(..., min_length=4, max_length=16)
age: int = Field(..., ge=18, le=99)
tasks: list[Task]
if __name__ == "__main__":
user_dict = {
"name": "John",
"age": 20,
"tasks": [
{"name": "Task 1", "due_date": "2021-01-01"},
{"name": "Task 2", "due_date": "2021-02-02"},
],
}
print(f"Pydantic: {timeit.timeit(lambda: User.model_validate(user_dict), number=1000)}")
marshmallowのバリデーション速度計測コード
import timeit
from marshmallow import Schema, fields, validate
class TaskSchema(Schema):
name = fields.Str(validate=validate.Length(min=4, max=16))
due_date = fields.Date()
class UserSchema(Schema):
name = fields.Str(validate=validate.Length(min=4, max=16))
age = fields.Int(validate=validate.Range(min=18, max=99))
tasks = fields.List(fields.Nested(TaskSchema))
if __name__ == "__main__":
user_dict = {
"name": "John",
"age": 20,
"tasks": [
{"name": "Task 1", "due_date": "2021-01-01"},
{"name": "Task 2", "due_date": "2021-02-02"},
],
}
print(f"marshmallow: {timeit.timeit(lambda: UserSchema().load(user_dict), number=1000)}")
Cerberusのバリデーション速度計測コード
import timeit
from cerberus import Validator
validator = Validator(
{
"name": {"type": "string", "minlength": 4, "maxlength": 16},
"age": {"type": "integer", "min": 18, "max": 99},
"tasks": {
"type": "list",
"schema": {
"type": "dict",
"schema": {"name": {"type": "string", "minlength": 4, "maxlength": 16}, "due_date": {"type": "date"}},
},
},
}
)
if __name__ == "__main__":
user_dict = {
"name": "John",
"age": 20,
"tasks": [
{"name": "Task 1", "due_date": "2021-01-01"},
{"name": "Task 2", "due_date": "2021-02-02"},
],
}
print(f"Cerberus: {timeit.timeit(lambda: validator.validate(user_dict), number=1000)}")
リファレンス
Discussion