🐍

ファイルデータの読み込み時点でドメインの型制約を保証する

に公開

TL;DR

  1. ファイルから情報を読み込む際に、ドメインモデルによるバリデーションを行いたい。さらに、パス情報も取得・保持したいし、DataFrameへのエクスポートも行いたい。
  2. 既存のパッケージでは、読み込み後にバリデーションするか、またはDataFrame用に別途モデル定義が必要。
  3. そのため、自作しました。

動機

Pythonでpandasやpolars、あるいは単純なListに対してファイルからデータを読み込む際、Pydantic製のドメインモデルによるバリデーションを行いたいと考えていました。

例えば、以下のようにPolarsでcsvファイルを読み込む場合、

import polars
df: polars.DataFrame = polars.read_csv(source="path_to_read.csv")

PolarsのDataFrameはドメイン知識を持っていないため、
ドメインモデルと整合しないデータであっても読み込めてしまうリスクがあります。

たとえば、次のようなモデルがあったとします。

from pydantic import BaseModel

class JobApplicant(BaseModel):
    uuid: str
    name: str
    age: int

このモデルを意図してファイルを読み込んだとき、nameがstrではなくintであった場合などもエラーなく読み込めてしまい、結果として分析時に不具合を引き起こす可能性があるのです。

既存ライブラリとの違い

同様の課題に取り組んでいるパッケージとしてpanderaがあります。たしかに、penderaを使えば、バリデーションそのものは可能です。しかし、私の要件は**「外部からPydanticモデルを渡して再利用できる」**ことでした。

panderaの使用例:

import pandera as pa
from pandera.typing import Series

class Schema(pa.DataFrameModel):
    column1: int = pa.Field(le=10)
    column2: float = pa.Field(lt=-1.2)
    column3: str = pa.Field(str_startswith="value_")

    @pa.check("column3")
    def column_3_check(cls, series: Series[str]) -> Series[bool]:
        return series.str.split("_", expand=True).shape[1] == 2

Schema.validate(df)

この場合、新たにpandera用のモデルを定義し直す必要があり、Pydanticのドメインモデルをそのまま使い回すことができません。

また、

  1. 複数ファイルを読み込んでまとめたい
  2. どのファイル由来か追跡できるようにしたい
  3. PolarsやPandasにエクスポートしたい

といった要件もあり、既存ライブラリでは満たせないため、自作することにしました。

成果物

というわけで欲望を詰めこんだパッケージを自作しました。上記の内容は全て実現しています。

GitHub Repository

https://github.com/shunsock/fukinotou

⭐ フィードバックやスターをいただけると、今後の開発の励みになります ⭐

何ができるか

次のようなドメインモデルを定義して、

class User(BaseModel):
    id: int
    name: str
    age: int

読み込み時に、Pydanticモデルによるバリデーションを行うことができます。ここでは簡単なモデルを利用していますが、もっと複雑なバリデーションも可能です。Userモデルであれば、idやageをPositiveIntにしたり、nameの文字数に制限を掛けるといったバリデーションが考えられます。

try:
    users: CsvLoaded[User] = CsvLoader(User).load("./users.csv")
    print(users)
    print(users.value[1].value.name)
except LoadingException as e:
    print(f"Failed to load: {e}")

# ✅ domain model(User class)による読みこまれた値の保証
# path=PosixPath('users.csv') value=[CsvRow(path=PosixPath('users.csv'),  value=User(id=1, name='shunsock', age=24)), CsvRow(path=PosixPath('users.csv'), value=User(id=2, name='shunsuke', age=24))]
# shunsuke

さらに、PolarsやPandasなどのDataFrameにも簡単に変換可能です。

users_dataframe = users.to_polars(include_path_as_column=True)
print(users_dataframe)

# shape: (2, 4)
┌─────┬──────────┬─────┬───────────┐
│ id  ┆ name     ┆ age ┆ path      │
│ ------------       │
│ i64 ┆ str      ┆ i64 ┆ str       │
╞═════╪══════════╪═════╪═══════════╡
│ 1   ┆ shunsock ┆ 24  ┆ users.csv │
│ 2   ┆ shunsuke ┆ 24  ┆ users.csv │
└─────┴──────────┴─────┴───────────┘

もし異常なデータが含まれていた場合は、Validationエラーを検知して例外を発生させます。

try:
    users: CsvLoaded[User] = CsvLoader(User).load("./users__invalid.csv")
except LoadingException as e:
    print(f"Failed to load: {e}")

# 🚨 domain model(User class)による異常値の検知
# Failed to load: Validation Error: details Error reading file users__invalid.csv: Validation Error: details Error parsing row 3 in users__invalid.csv: 1 validation error for User
# age
#  Input should be a valid integer, unable to parse string as an integer [type=int_parsing, input_value='二十四', input_type=str]

実現方法

せっかくなので作り方も解説します。解説パートはPythonの中級者以上を対象としているため、補足を巻末に付けました。必要に応じてご参照ください。

解説するコードの全体像

動かしたい方もいると思うので先に全体像を貼っておきます。

csv_loader.py
import csv
from pathlib import Path
from typing import Dict, List, Type, TypeVar, Generic, Iterator

from pydantic import BaseModel, ValidationError

from .exception.loading_exception import LoadingException
from .abstraction.dataframe_exportable import DataframeExportable

T = TypeVar("T", bound=BaseModel)


class CsvRow(BaseModel, Generic[T]):
    path: Path
    value: T


class CsvLoaded(
    BaseModel,
    Generic[T],
    DataframeExportable[CsvRow[T]],
):
    path: Path
    value: List[CsvRow[T]]


class CsvLoader(Generic[T]):
    def __init__(self, model: Type[T]) -> None:
        self.model = model

    def load(self, path: str | Path, encoding: str = "utf-8") -> CsvLoaded[T]:
        p = Path(path)
        if not p.is_file():
            raise LoadingException(f"Input path is invalid: {p}")

        try:
            with p.open("r", encoding=encoding) as f:
                reader = csv.reader(f)
                headers = self._read_csv_headers(reader, p)
                csv_rows = self._validate_csv_row(reader, headers, p)
                return CsvLoaded(path=p, value=csv_rows)
        except Exception as e:
            raise LoadingException(
                original_exception=e, error_message=f"Error reading file {p}: {e}"
            )

    @staticmethod
    def _read_csv_headers(reader: Iterator[List[str]], path: Path) -> List[str]:
        try:
            headers: List[str] = next(reader)
            return headers
        except StopIteration:
            raise LoadingException(
                original_exception=None, error_message=f"No headers found in {path}"
            )

    def _validate_csv_row(
        self, reader: Iterator[List[str]], headers: List[str], path: Path
    ) -> List[CsvRow[T]]:
        csv_rows: List[CsvRow[T]] = []
        # Validation foreach rows
        for row_number, row_data in enumerate(reader, start=2):
            row_data_typed: List[str] = row_data
            # Skip empty lines
            if not any(cell.strip() for cell in row_data_typed):
                continue

            # Validation
            row_dict: Dict[str, str] = {}
            for i, header in enumerate(headers):
                if i < len(row_data_typed):
                    row_dict[header] = row_data_typed[i]

            try:
                csv_rows.append(
                    CsvRow(path=path, value=self.model.model_validate(row_dict))
                )
            except ValidationError as e:
                raise LoadingException(
                    original_exception=e,
                    error_message=f"Error parsing row {row_number} in {path}: {e}",
                )
        return csv_rows
dataframe_exportable.py
from pathlib import Path

from typing import Generic, TypeVar, List, Dict, Any, Protocol
from pydantic import BaseModel

import polars
import pandas

V = TypeVar("V", bound=BaseModel)


class Row(Protocol[V]):
    value: V
    path: Path


T = TypeVar("T", bound=Row[Any])


class DataframeExportable(Generic[T]):
    path: Path
    value: List[T]

    def _to_dicts(self, use_path: bool) -> List[Dict[str, Any]]:
        if not use_path:
            return [v.value.model_dump() for v in self.value]
        return [{**v.value.model_dump(), "path": str(v.path)} for v in self.value]

    def to_polars(self, include_path_as_column: bool = False) -> polars.DataFrame:
        if not self.value:
            return polars.DataFrame()

        df = polars.DataFrame(self._to_dicts(include_path_as_column))

        return df

    def to_pandas(self, include_path_as_column: bool = False) -> pandas.DataFrame:
        if not self.value:
            return pandas.DataFrame()

        df = pandas.DataFrame(self._to_dicts(include_path_as_column))

        return df

概要

ここでは、csvファイルを読みこみと同時に型検査する型を作ります。
といっても型検査をする実態はpydantic.BaseModelのサブクラスなので、我々は

  1. データを読みこむ。
  2. それぞれの行の値をpydantic.BaseModelのオブジェクトにキャスト (ここでValidationが走る)、オブジェクトとして保存。
  3. キャストしたデータをvalue, 読みこみ先のパスをpathとして持つクラスを作れば十分です。

行をPydanticのモデルとして表現する

それぞれの行の値をpydantic.BaseModelのオブジェクトにキャストします。ここでは、value attributeにキャストした値を格納することにします。

T = TypeVar("T", bound=BaseModel)


class CsvRow(BaseModel, Generic[T]):
    path: Path
    value: T # UserオブジェクトなどのPydanticModelが入る

さきほどでてきた、TypeVarがあります。型パラメータは 上限 (bound) を設定可能です。上限を設定するとそのサブクラスのみが型引数の対象となります。我々はpydantic.BaseModelにValidationをまかせるので、pydantic.BaseModelのサブクラスであるという条件を入れています。

Csv全体を扱うモデルを作る

次に全体を管理するオブジェクトを作成します。今回はCsvというモデルは作成せず、単にRowのあつまりと表現しました。

class CsvLoaded(
    BaseModel,
    Generic[T],
    DataframeExportable[CsvRow[T]],
):
    path: Path
    value: List[CsvRow[T]]

ロードするメソッドの実装

parseするときにメソッドを呼び出しをする関係上Tで渡す型の実態が必要です。初期化時に渡すことに気をつけてください。

class CsvLoader(Generic[T]):
    def __init__(self, model: Type[T]) -> None:
        self.model = model

    def load(self, path: str | Path, encoding: str = "utf-8") -> CsvLoaded[T]:
        p = Path(path)
        if not p.is_file():
            raise LoadingException(f"Input path is invalid: {p}")

        try:
            with p.open("r", encoding=encoding) as f:
                reader = csv.reader(f)
                headers = self._read_csv_headers(reader, p) # 実装がんばる
                csv_rows = self._validate_csv_row(reader, headers, p) # 実装がんばる
                return CsvLoaded(path=p, value=csv_rows)
        except Exception as e:
            raise LoadingException(
                original_exception=e, error_message=f"Error reading file {p}: {e}"
            )

実装がんばるのところを実装したものが次のコードになります。

プライベートメソッドの実装
    @staticmethod
    def _read_csv_headers(reader: Iterator[List[str]], path: Path) -> List[str]:
        try:
            headers: List[str] = next(reader)
            return headers
        except StopIteration:
            raise LoadingException(
                original_exception=None, error_message=f"No headers found in {path}"
            )

    def _validate_csv_row(
        self, reader: Iterator[List[str]], headers: List[str], path: Path
    ) -> List[CsvRow[T]]:
        csv_rows: List[CsvRow[T]] = []
        # Validation foreach rows
        for row_number, row_data in enumerate(reader, start=2):
            row_data_typed: List[str] = row_data
            # Skip empty lines
            if not any(cell.strip() for cell in row_data_typed):
                continue

            # Validation
            row_dict: Dict[str, str] = {}
            for i, header in enumerate(headers):
                if i < len(row_data_typed):
                    row_dict[header] = row_data_typed[i]

            try:
                csv_rows.append(
                    CsvRow(path=path, value=self.model.model_validate(row_dict))
                )
            except ValidationError as e:
                raise LoadingException(
                    original_exception=e,
                    error_message=f"Error parsing row {row_number} in {path}: {e}",
                )
        return csv_rows

参考: https://peps.python.org/pep-0484/

今後の展望

仕事で利用するために作ったので開発を続ける予定です。

次にやることとしては、Cloud Storageやs3からデータを読みこみを考えています。このパッケージの開発を通じて、すぐに安心して使えるデータ読みこみ機能を提供していく所存です。

補足

TypeVar, Generics

Pythonを使うとlist[str]は自然に使う型です。この型は、intstr など型を引数として受けとります。次のコードは str を引数として受けとる事例です。

from typing import List
names: List[str] = ["John", "Mary"]
names.push(1) # error!!

型引数を持つ型のことを「ジェネリック型(generic types)」と呼びます。例えば、任意の型引数Tを持つStackを定義するには次のようにします。

from typing import TypeVar, Generic

T = TypeVar('T')

class Stack(Generic[T]):
    def __init__(self) -> None:
        # Create an empty list with items of type T
        self.items: list[T] = []

    def push(self, item: T) -> None:
        self.items.append(item)

    def pop(self) -> T:
        return self.items.pop()

    def empty(self) -> bool:
        return not self.items

このように定義すれば、次のような型検査が可能になります。

# Construct an empty Stack[int] instance
stack = Stack[int]()
stack.push(2)
stack.pop() + 1
stack.push('x')  # error: Argument 1 to "push" of "Stack" has incompatible type "str"; expected "int"

引用: https://typing.python.org/en/latest/reference/generics.html

Protocol

ProtocolはPythonにおけるDuck Typing(構造的部分型付け)の機能です。methodやattributeに対する制約を記述可能です。Duck Typingについて知りたい型は次のWikiを参照してください。

https://ja.wikipedia.org/wiki/ダック・タイピング

Protocolは暗黙的にも明示的にも利用可能です。例えば、次のコードは暗黙的な利用事例です。

def close_all(things: Iterable[SupportsClose]) -> None:
    for t in things:
        t.close()

f = open('foo.txt')
r = Resource()
close_all([f, r])  # OK!
close_all([1])     # Error: 'int' has no 'close' method

一方で、次の事例では、明示的にProtocolのClassを継承しています。

class PColor(Protocol):
    @abstractmethod
    def draw(self) -> str:
        ...
    def complex_method(self) -> int:
        # some complex code here
        ...

class NiceColor(PColor):
    def draw(self) -> str:
        return "deep blue"

class BadColor(PColor):
    def draw(self) -> str:
        return super().draw()  # Error, no default implementation

class ImplicitColor:   # Note no 'PColor' base here
    def draw(self) -> str:
        return "probably gray"
    def complex_method(self) -> int:
        # class needs to implement this
        ...

nice: NiceColor
another: ImplicitColor

def represent(c: PColor) -> None:
    print(c.draw(), c.complex_method())

represent(nice) # OK
represent(another) # Also OK

引用: https://peps.python.org/pep-0544/

Discussion