👻

Pydanticで非プリミティブを使う

2023/09/26に公開

非プリミティブな型はPydanticでは使えない

こんにちは極論モンスターのYosematです。

Pydanticはbuilt-inで値のValidationやSerialization(stringやjsonに変換して保存したりする機能)がついていて便利ですが非プリミティブな型はPydanticModelでは使えません。これは型によってValidation/Serializationのロジックは異なるからです。

Pythonユーザーのよくあるユースケースは「numpyをPydanticで使う」だと思います。

import numpy as np
from pydantic import BaseModel

class MyModel(BaseModel):  # PydanticSchemaGenerationError: Unable to generate pydantic-core schema for <class 'numpy.ndarray'>.
    a: int
    b: np.ndarray

instance = MyModel(a=3, b=np.array([1, 2, 3]))

※ちなみにnumpyの配列の型は正しくはndarray。np.arrayではないので注意。

CustomValidatorでPydantic-Friendlyにする

PEP-593で追加されたtyping.Annotatedは既存の型をデコレートするために存在します。これを使って普通の型をPydanticで使える型に変更することができます。

Pydanticで使うための最もシンプルな方法はPlainValidatorを使ってValidationロジックを追加してやることです。

from typing import Annotated, Any

import numpy as np
from pydantic import BaseModel, PlainValidator, ValidationInfo

def validate(v: Any, info: ValidationInfo) -> np.ndarray:
    if isinstance(v, np.ndarray):
        ans = v
    elif isinstance(v, (list, tuple)):
        ans = np.array(v)
    else:
        raise TypeError(
            f"Expected numpy.ndarray, list or tuple of float, got {type(v)}"
        )
    if ans.ndim != 2:
        raise ValueError(f"Expected 2D array, got {ans.ndim}D array")
    return ans

MyArray = Annotated[
    np.ndarray,
    PlainValidator(validate),
]

class MyModel(BaseModel):
    a: int
    b: MyArray

instance = MyModel(a=3, b=np.array([[1, 2, 3]]))

PlainValidatorの引数のvalidator関数は入力された値が適切な型であるかを検知して、適切ならば目的の型へと変換する役割をもっています。
validatorの戻り値はboolではなく変換後の値なので注意

この例では二次元の配列へと変換できるnumpy.ndarray, tuple, listのみを受け付けています。
自在にロジックがかけるので機械学習でよくある配列のshapeに関するValidationなどが書けるのも魅力的です。

Custom Serializerで完全にSerializeする

AnnotatedPlainSerializerを使ったCustom Validationによってnumpy配列をPydantic-Friendlyな型へと変換することができましたが、これだけではSerializationができません。つまりjsonやdictへの変換ができません。

instance = MyModel(a=3, b=np.array([[1, 2, 3]]))
json = instance.model_dump_json()  # Unable to serialize unknown type: <class 'numpy.ndarray'>

そこでCustomSerializerを実装しましょう。やり方はCustomValidatorと全く同じです。

def serialize(v: np.ndarray, info: SerializationInfo) -> list[list[float]]:
    return v.tolist()

MyArray = Annotated[
    np.ndarray,
    PlainValidator(validate),
    PlainSerializer(serialize),
]

instance = MyModel(a=3, b=np.array([[1, 2, 3]]))
json = instance.model_dump_json()
print(json)  # {"a":3,"b":[[1.0,2.0,3.0]]}

@field_validatorや@model_validatorとの違い

実はAnnotatedを使う他にもfield_validatorやmodel_validatorというデコレータベースの手法もあります。Pydantic v1ではこちらが主流だったと思います。なぜかはわかりませんが、こちらの手法の場合model_configパラメータとしてConfigDict(arbitrary_types_allowed=True)を渡す必要があります。

from typing import Any

import numpy as np
from pydantic import BaseModel, ConfigDict, field_serializer, field_validator


class MyModel(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)
    a: int
    b: np.ndarray

    @field_validator("b")
    @classmethod
    def construct_2d_array(cls, v: Any) -> np.ndarray:
        if isinstance(v, np.ndarray):
            ans = v
        elif isinstance(v, (list, tuple)):
            ans = np.array(v)
        else:
            raise TypeError(
                f"Expected numpy.ndarray, list or tuple of float, got {type(v)}"
            )
        if ans.ndim != 2:
            raise ValueError(f"Expected 2D array, got {ans.ndim}D array")
        return ans

    @field_serializer("b")
    def serialize_2d_array(self, v: np.ndarray) -> list[list[float]]:
        return v.tolist()


instance = MyModel(a=3, b=np.array([[1, 2, 3]]))
json = instance.model_dump_json()
print(json)  # {"a":3,"b":[[1.0,2.0,3.0]]}

しかし、私はこちらよりもAnnotatedを使うケースが有用の場合が多いと思います。なぜならValidationやSerializationロジックは特定のモデルやフィールドではなく型に紐づいていると思うからです。そしてモデルの実装とValidation/Serializationロジックが分離されるのでモデルの実装がすごくすっきりします。

たとえば先ほどのソースコードからValidation/Serializationロジックを他ファイルへと分割すればMyModelクラスの定義はこんなにすっきりです!

from pydantic import BaseModel
from mypkg import MyArray

class MyModel(BaseModel):
    a: int
    b: MyArray

またここにかいたように、一度定義したAnnotatedな型MyArrayはソースコードのどこからでもimportして使えます。@field_validatorはあくまでモデルに紐づいているので、他のモデル定義で同じロジックを使うことはできません。

まとめ

私がPydanticをぐっと好きになった理由の1つはこのAnnotatedによって型にValidation/Serializationロジックを組み込めるようになったところです。直感的でとても素敵ですよね!これらの機能はPydanticのWebアプリケーションとしての応用だけでなく、機械学習領域への応用をも可能にしています。Pydantic v2になりTypeHintを使って安全なコードを書く流れはこれまでのスコープのさらに外へと広がっていくはずです。

みなさんも一緒にType-SafeなPythonの開発を楽しんでいきましょう!

コード全体

from typing import Annotated, Any

import numpy as np
from pydantic import (
    BaseModel,
    PlainSerializer,
    PlainValidator,
)


def validate(v: Any) -> np.ndarray:
    if isinstance(v, np.ndarray):
        ans = v
    elif isinstance(v, (list, tuple)):
        ans = np.array(v)
    else:
        raise TypeError(
            f"Expected numpy.ndarray, list or tuple of float, got {type(v)}"
        )
    if ans.ndim != 2:
        raise ValueError(f"Expected 2D array, got {ans.ndim}D array")
    return ans


def serialize(v: np.ndarray) -> list[list[float]]:
    return v.tolist()


MyArray = Annotated[
    np.ndarray,
    PlainValidator(validate),
    PlainSerializer(serialize),
]


class MyModel(BaseModel):
    a: int
    b: MyArray


instance = MyModel(a=3, b=np.array([[1, 2, 3]]))
json = instance.model_dump_json()
print(json)  # {"a":3,"b":[[1.0,2.0,3.0]]}

Discussion