Pydanticで非プリミティブを使う
非プリミティブな型はPydanticでは使えない
こんにちは極論モンスターのYosematです。
Pydanticはbuilt-inで値のValidationやSerialization(stringやjsonに変換して保存したりする機能)がついていて便利ですが非プリミティブな型はPydanticModelでは使えません。これは型によってValidation/Serializationのロジックは異なるからです。
Pythonユーザーのよくあるユースケースは「numpyをPydanticで使う」だと思います。
import numpy as np
from pydantic import BaseModel
class MyModel(BaseModel): # PydanticSchemaGenerationError: Unable to generate pydantic-core schema for <class 'numpy.ndarray'>.
a: int
b: np.ndarray
instance = MyModel(a=3, b=np.array([1, 2, 3]))
※ちなみにnumpyの配列の型は正しくはndarray。np.arrayではないので注意。
CustomValidatorでPydantic-Friendlyにする
PEP-593で追加されたtyping.Annotatedは既存の型をデコレートするために存在します。これを使って普通の型をPydanticで使える型に変更することができます。
Pydanticで使うための最もシンプルな方法はPlainValidator
を使ってValidationロジックを追加してやることです。
from typing import Annotated, Any
import numpy as np
from pydantic import BaseModel, PlainValidator, ValidationInfo
def validate(v: Any, info: ValidationInfo) -> np.ndarray:
if isinstance(v, np.ndarray):
ans = v
elif isinstance(v, (list, tuple)):
ans = np.array(v)
else:
raise TypeError(
f"Expected numpy.ndarray, list or tuple of float, got {type(v)}"
)
if ans.ndim != 2:
raise ValueError(f"Expected 2D array, got {ans.ndim}D array")
return ans
MyArray = Annotated[
np.ndarray,
PlainValidator(validate),
]
class MyModel(BaseModel):
a: int
b: MyArray
instance = MyModel(a=3, b=np.array([[1, 2, 3]]))
PlainValidator
の引数のvalidator
関数は入力された値が適切な型であるかを検知して、適切ならば目的の型へと変換する役割をもっています。
validatorの戻り値はboolではなく変換後の値なので注意。
この例では二次元の配列へと変換できるnumpy.ndarray, tuple, listのみを受け付けています。
自在にロジックがかけるので機械学習でよくある配列のshapeに関するValidationなどが書けるのも魅力的です。
Custom Serializerで完全にSerializeする
Annotated
とPlainSerializer
を使ったCustom Validationによってnumpy配列をPydantic-Friendlyな型へと変換することができましたが、これだけではSerializationができません。つまりjsonやdictへの変換ができません。
instance = MyModel(a=3, b=np.array([[1, 2, 3]]))
json = instance.model_dump_json() # Unable to serialize unknown type: <class 'numpy.ndarray'>
そこでCustomSerializerを実装しましょう。やり方はCustomValidatorと全く同じです。
def serialize(v: np.ndarray, info: SerializationInfo) -> list[list[float]]:
return v.tolist()
MyArray = Annotated[
np.ndarray,
PlainValidator(validate),
PlainSerializer(serialize),
]
instance = MyModel(a=3, b=np.array([[1, 2, 3]]))
json = instance.model_dump_json()
print(json) # {"a":3,"b":[[1.0,2.0,3.0]]}
@field_validatorや@model_validatorとの違い
実はAnnotatedを使う他にもfield_validatorやmodel_validatorというデコレータベースの手法もあります。Pydantic v1ではこちらが主流だったと思います。なぜかはわかりませんが、こちらの手法の場合model_config
パラメータとしてConfigDict(arbitrary_types_allowed=True)
を渡す必要があります。
from typing import Any
import numpy as np
from pydantic import BaseModel, ConfigDict, field_serializer, field_validator
class MyModel(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
a: int
b: np.ndarray
@field_validator("b")
@classmethod
def construct_2d_array(cls, v: Any) -> np.ndarray:
if isinstance(v, np.ndarray):
ans = v
elif isinstance(v, (list, tuple)):
ans = np.array(v)
else:
raise TypeError(
f"Expected numpy.ndarray, list or tuple of float, got {type(v)}"
)
if ans.ndim != 2:
raise ValueError(f"Expected 2D array, got {ans.ndim}D array")
return ans
@field_serializer("b")
def serialize_2d_array(self, v: np.ndarray) -> list[list[float]]:
return v.tolist()
instance = MyModel(a=3, b=np.array([[1, 2, 3]]))
json = instance.model_dump_json()
print(json) # {"a":3,"b":[[1.0,2.0,3.0]]}
しかし、私はこちらよりもAnnotated
を使うケースが有用の場合が多いと思います。なぜならValidationやSerializationロジックは特定のモデルやフィールドではなく型に紐づいていると思うからです。そしてモデルの実装とValidation/Serializationロジックが分離されるのでモデルの実装がすごくすっきりします。
たとえば先ほどのソースコードからValidation/Serializationロジックを他ファイルへと分割すればMyModelクラスの定義はこんなにすっきりです!
from pydantic import BaseModel
from mypkg import MyArray
class MyModel(BaseModel):
a: int
b: MyArray
またここにかいたように、一度定義したAnnotatedな型MyArray
はソースコードのどこからでもimportして使えます。@field_validator
はあくまでモデルに紐づいているので、他のモデル定義で同じロジックを使うことはできません。
まとめ
私がPydanticをぐっと好きになった理由の1つはこのAnnotatedによって型にValidation/Serializationロジックを組み込めるようになったところです。直感的でとても素敵ですよね!これらの機能はPydanticのWebアプリケーションとしての応用だけでなく、機械学習領域への応用をも可能にしています。Pydantic v2になりTypeHintを使って安全なコードを書く流れはこれまでのスコープのさらに外へと広がっていくはずです。
みなさんも一緒にType-SafeなPythonの開発を楽しんでいきましょう!
コード全体
from typing import Annotated, Any
import numpy as np
from pydantic import (
BaseModel,
PlainSerializer,
PlainValidator,
)
def validate(v: Any) -> np.ndarray:
if isinstance(v, np.ndarray):
ans = v
elif isinstance(v, (list, tuple)):
ans = np.array(v)
else:
raise TypeError(
f"Expected numpy.ndarray, list or tuple of float, got {type(v)}"
)
if ans.ndim != 2:
raise ValueError(f"Expected 2D array, got {ans.ndim}D array")
return ans
def serialize(v: np.ndarray) -> list[list[float]]:
return v.tolist()
MyArray = Annotated[
np.ndarray,
PlainValidator(validate),
PlainSerializer(serialize),
]
class MyModel(BaseModel):
a: int
b: MyArray
instance = MyModel(a=3, b=np.array([[1, 2, 3]]))
json = instance.model_dump_json()
print(json) # {"a":3,"b":[[1.0,2.0,3.0]]}
Discussion