DataFrame を Validation する pandera 入門
はじめに
Python を用いてデータ分析を行うにあたりよく使われるライブラリとして pandas があります。
pandas は大変使い勝手の良いライブラリですが、多くの場合データを丸ごと pd.DataFrame
型で保持するため「どのような列を持っているのか」、「各列がどのような型か」、「各列の値にどのような値が入りうるのか」等がソースコードを一見しただけでは分からないことが多いです。
結果として処理がブラックボックス化してしまい、デバッグコストの増加やコードの可読性低下といった問題を生じさせることがあります。
この問題への解決策の一つとして、本記事ではデータフレームのバリデーション機能を提供するライブラリである pandera を紹介します。
pandera とは
データ処理パイプラインの可読性とロバストさを高めるために dataframe に対してデータ検証を行う機能を提供するライブラリです。
主に以下の機能を提供します。(上記ドキュメントより引用。一部抜粋。)
- スキーマを定義することで、さまざまなデータフレームの型を検証できる
- データフレームのカラムや値をチェックできる
- pydantic のようなクラスベースの API でスキーマモデルを定義できる
- pydantic, fastapi, mypy と言った Python ツールと統合できる
個人的には pydantic のようにクラスベース API でモデルを定義できる点がありがたいと感じています。
インストール
pip でインストールが可能です。
pip install pandera
使い方
DataFrameSchema によるバリデーション
公式チュートリアルより抜粋します。
import pandas as pd
import pandera as pa
# バリデーション用のデータ
df = pd.DataFrame({
"column1": [1, 4, 0, 10, 9],
"column2": [-1.3, -1.4, -2.9, -10.1, -20.4],
"column3": ["value_1", "value_2", "value_3", "value_2", "value_1"],
})
# スキーマ定義
schema = pa.DataFrameSchema({
"column1": pa.Column(int, checks=pa.Check.le(10)),
"column2": pa.Column(float, checks=pa.Check.lt(-1.2)),
"column3": pa.Column(str, checks=[
pa.Check.str_startswith("value_"),
# series の入力を受け取り boolean か boolean 型の series を返すカスタムチェックメソッドを定義
pa.Check(lambda s: s.str.split("_", expand=True).shape[1] == 2)
]),
})
validated_df = schema(df)
print(validated_df)
column1 column2 column3
0 1 -1.3 value_1
1 4 -1.4 value_2
2 0 -2.9 value_3
3 10 -10.1 value_2
4 9 -20.4 value_1
事前にスキーマを定義しておき、データフレームをスキーマに入力するとバリデーションされたデータフレームが出力されます。
SchemaModel によるバリデーション
続いて SchemaModel を利用する使い方を見てみます。こちらもチュートリアルより抜粋します。
from pandera.typing import Series
class Schema(pa.SchemaModel):
column1: Series[int] = pa.Field(le=10)
column2: Series[float] = pa.Field(lt=-1.2)
column3: Series[str] = pa.Field(str_startswith="value_")
@pa.check("column3")
def column_3_check(cls, series: Series[str]) -> Series[bool]:
"""Check that column3 values have two elements after being split with '_'"""
return series.str.split("_", expand=True).shape[1] == 2
Schema.validate(df)
カスタムチェックメソッドが lambda ではなくデコレータ付きメソッドで実装されていますが、書き方に大きな違いはないことが分かります。
なお pandera.Field
が取る主要な引数には以下があります。
パラメータ | 説明 |
---|---|
nullable | 列に null を許容するか |
unique | 列にユニーク制約を課すか |
coerce | 型を強制するか |
ignore_na | 型チェックの際に null を無視するか |
eq | 指定した値と等しいか |
ge | 指定した値より大きいか |
gt | 指定した値以上か |
le | 指定した値より小さいか |
lt | 指定した値以下か |
ne | 要素を持たないか |
in_range | 指定した最小値、最大値の範囲内か |
isin | 指定したリストの範囲内か |
str_contains | 指定した文字列を含むか |
str_startswith | 指定した文字列から始まるか |
str_endswith | 指定した文字列で終わるか |
str_length | 文字長が指定した最小値、最大値の範囲内か |
実践的な使い方
簡単に使い方を理解したところで Titanic のデータセットを読み込み、加工する処理を試してみます。
データの読み込み
まずはスキーマ定義をしない場合を考えます。
import pandas as pd
def load_data(filepath: str) -> pd.DataFrame:
df = pd.read_csv(filepath)
return df
df = load_data("./train.csv")
ファイルからデータを読み込んでいるので当然と言えば当然ですが、データフレームの中身がどうなっているのかは分かりません。
続いてスキーマ定義をした場合を考えます。
from typing import Optional
import pandas as pd
import pandera as pa
from pandera.typing import Series, DataFrame
class TitanicSchema(pa.SchemaModel):
PassengerId: Series[int] = pa.Field(nullable=False, unique=True)
Survived: Optional[Series[int]] = pa.Field(nullable=True, isin=(0, 1))
Pclass: Series[int] = pa.Field(nullable=False, isin=(1, 2, 3))
Name: Series[str] = pa.Field(nullable=False)
Sex: Series[str] = pa.Field(nullable=False, isin=("male", "female"))
Age: Series[float] = pa.Field(nullable=True, in_range={"min_value": 0, "max_value": 100})
SibSp: Series[int] = pa.Field(nullable=False, ge=0, le=10)
Parch: Series[int] = pa.Field(nullable=False, ge=0, le=10)
Ticket: Series[str] = pa.Field(nullable=False)
Fare: Series[float] = pa.Field(nullable=True, ge=0)
Cabin: Series[str] = pa.Field(nullable=True)
Embarked: Series[str] = pa.Field(nullable=True, str_length=1, isin=("S", "C", "Q"))
class Config:
strict = True
def load_dataset(filepath:str) -> DataFrame[TitanicSchema]:
df = pd.read_csv(filepath)
df = TitanicSchema.validate(df)
return df
df = load_data("./train.csv")
データセットが持つカラムとそれぞれのカラムの情報がスキーマとして定義されたため、データフレームの中身がある程度分かるようになりました。
また、読み込み直後にバリデーションされているため、スキーマとして定義された内容を満たしたデータフレームであることが保証されています。
補足
上記の例では TitanicSchema
クラスのメンバ変数に Config
クラスを定義して strict=True
を設定しています。
このように pa.SchemaModel
は Config
クラスを登録することでメタ情報を定義することができます。
デフォルトではSchemaModel
に登録したカラムが存在しない場合はバリデーションでエラーが出ますが、登録されていないカラムを持っていてもエラーが出ません。
登録されていないカラムを持っていた場合にもエラーを出すために、strict=True
を設定しています。
データの加工
続いてデータを加工してみます。加工のプロセスは下記の notebook を参考にさせていただきました。
本記事においては加工処理そのものは重要ではないので流し読みいただいて問題ありません。
先ほどと同じように、まずはスキーマ定義しない場合を考えます。
import numpy as np
import pandas as pd
def transform(df: pd.DataFrame) -> pd.DataFrame:
df["Sex"] = df["Sex"].str.match("male").map(int)
df["Title"] = df["Name"].str.extract(" ([A-Za-z]+)\.", expand=False)
df["Title"] = df["Title"].replace(
["Lady", "Countess", "Capt", "Col", "Don", "Dr", "Major", "Rev", "Sir", "Jonkheer", "Dona"], "Rare"
)
df["Title"] = df["Title"].replace("Mlle", "Miss")
df["Title"] = df["Title"].replace("Ms", "Miss")
df["Title"] = df["Title"].replace("Mme", "Mrs")
df["Title"] = df["Title"].map({"Mr": 1, "Miss": 2, "Mrs": 3, "Master": 4, "Rare": 5})
guess_ages = np.zeros((2, 3))
for i in range(2):
for j in range(3):
guess_df = df[(df["Sex"] == i) & (df["Pclass"] == j + 1)]["Age"].dropna()
age_guess = guess_df.median()
guess_ages[i, j] = int(age_guess / 0.5 + 0.5) * 0.5
for i in range(2):
for j in range(3):
df.loc[(df["Age"].isnull()) & (df["Sex"] == i) & (df["Pclass"] == j + 1), "Age"] = guess_ages[i, j]
df["Age"] = df["Age"].astype(int)
df.loc[df["Age"] <= 16, "Age"] = 0
df.loc[(df["Age"] > 16) & (df["Age"] <= 32), "Age"] = 1
df.loc[(df["Age"] > 32) & (df["Age"] <= 48), "Age"] = 2
df.loc[(df["Age"] > 48) & (df["Age"] <= 64), "Age"] = 3
df.loc[df["Age"] > 64, "Age"] = 4
df["FamilySize"] = df["SibSp"] + df["Parch"] + 1
df["IsAlone"] = df["FamilySize"].map(lambda x: 1 if x == 1 else 0)
df = df.drop(["Ticket", "Cabin", "PassengerId", "Name"], axis=1)
return df
df = load_data("./train.csv")
df = transform(df)
加工は欠損補完、ビニング、型変換、列同士の演算、不要列の削除などさまざまな処理が含まれます。
これらの処理を経由した結果、最終的に得られるデータフレームがどのような状態になっているのかがすぐには分かり辛いと思います。
続いてスキーマ定義をした場合を考えます。
class TransformedTitanicSchema(pa.SchemaModel):
Survived: Optional[Series[int]] = pa.Field(nullable=True, isin=(0, 1))
Pclass: Series[int] = pa.Field(nullable=False, isin=(1, 2, 3))
Sex: Series[int] = pa.Field(nullable=False, isin=(0, 1))
Age: Series[int] = pa.Field(nullable=False, isin=(0, 1, 2, 3, 4))
SibSp: Series[int] = pa.Field(nullable=False, ge=0, le=10)
Parch: Series[int] = pa.Field(nullable=False, ge=0, le=10)
Fare: Series[float] = pa.Field(nullable=True, ge=0)
Embarked: Series[str] = pa.Field(nullable=True, str_length=1, isin=("S", "C", "Q"))
Title: Series[int] = pa.Field(nullable=False, isin=(1, 2, 3, 4, 5))
FamilySize: Series[int] = pa.Field(nullable=False, ge=0, le=15)
IsAlone: Series[int] = pa.Field(nullable=False, isin=(0, 1))
class Config:
strict = True
def transform(df: DataFrame[TitanicSchema]) -> DataFrame[TransformedTitanicSchema]:
df["Sex"] = df["Sex"].str.match("male").map(int)
df["Title"] = df["Name"].str.extract(" ([A-Za-z]+)\.", expand=False)
df["Title"] = df["Title"].replace(
["Lady", "Countess", "Capt", "Col", "Don", "Dr", "Major", "Rev", "Sir", "Jonkheer", "Dona"], "Rare"
)
df["Title"] = df["Title"].replace("Mlle", "Miss")
df["Title"] = df["Title"].replace("Ms", "Miss")
df["Title"] = df["Title"].replace("Mme", "Mrs")
df["Title"] = df["Title"].map({"Mr": 1, "Miss": 2, "Mrs": 3, "Master": 4, "Rare": 5})
guess_ages = np.zeros((2, 3))
for i in range(2):
for j in range(3):
guess_df = df[(df["Sex"] == i) & (df["Pclass"] == j + 1)]["Age"].dropna()
age_guess = guess_df.median()
guess_ages[i, j] = int(age_guess / 0.5 + 0.5) * 0.5
for i in range(2):
for j in range(3):
df.loc[(df["Age"].isnull()) & (df["Sex"] == i) & (df["Pclass"] == j + 1), "Age"] = guess_ages[i, j]
df["Age"] = df["Age"].astype(int)
df.loc[df["Age"] <= 16, "Age"] = 0
df.loc[(df["Age"] > 16) & (df["Age"] <= 32), "Age"] = 1
df.loc[(df["Age"] > 32) & (df["Age"] <= 48), "Age"] = 2
df.loc[(df["Age"] > 48) & (df["Age"] <= 64), "Age"] = 3
df.loc[df["Age"] > 64, "Age"] = 4
df["FamilySize"] = df["SibSp"] + df["Parch"] + 1
df["IsAlone"] = df["FamilySize"].map(lambda x: 1 if x == 1 else 0)
df = df.drop(["Ticket", "Cabin", "PassengerId", "Name"], axis=1)
df = TransformedTitanicSchema.validate(df)
return df
df = load_data("./train.csv")
df = transform(df)
加工処理は変わっていませんが、処理の中身を完全に読み解かなくてもある程度どのような値が入るかが分かるようになりました。
また、加工処理終了直後にデータのバリデーションをおこなっているので、加工によって想定外の値が混入しないことが保証されるようになりました。
読み込みと加工
まとめると、下記の通りになります。
なお、上記のソースコードはメソッドの最後に SchemaModel.validate
を実行して型の確認をしていましたが、これを毎回書くのは面倒です。
タイプヒントのついたメソッドにデコレータ @pa.check_types
を付与することで自動的にバリデーションを実施してくれるようになります。
import numpy as np
import pandas as pd
import pandera as pa
from typing import Optional
from pandera.typing import Series, DataFrame
class TitanicSchema(pa.SchemaModel):
PassengerId: Series[int] = pa.Field(nullable=False, unique=True)
Survived: Optional[Series[int]] = pa.Field(nullable=True, isin=(0, 1))
Pclass: Series[int] = pa.Field(nullable=False, isin=(1, 2, 3))
Name: Series[str] = pa.Field(nullable=False)
Sex: Series[str] = pa.Field(nullable=False, isin=("male", "female"))
Age: Series[float] = pa.Field(nullable=True, in_range={"min_value": 0, "max_value": 100})
SibSp: Series[int] = pa.Field(nullable=False, ge=0, le=10)
Parch: Series[int] = pa.Field(nullable=False, ge=0, le=10)
Ticket: Series[str] = pa.Field(nullable=False)
Fare: Series[float] = pa.Field(nullable=True, ge=0)
Cabin: Series[str] = pa.Field(nullable=True)
Embarked: Series[str] = pa.Field(nullable=True, str_length=1, isin=("S", "C", "Q"))
class Config:
strict = True
class TransformedTitanicSchema(pa.SchemaModel):
Survived: Optional[Series[int]] = pa.Field(nullable=True, isin=(0, 1))
Pclass: Series[int] = pa.Field(nullable=False, isin=(1, 2, 3))
Sex: Series[int] = pa.Field(nullable=False, isin=(0, 1))
Age: Series[int] = pa.Field(nullable=True, isin=(0, 1, 2, 3, 4))
SibSp: Series[int] = pa.Field(nullable=False, ge=0, le=10)
Parch: Series[int] = pa.Field(nullable=False, ge=0, le=10)
Fare: Series[float] = pa.Field(nullable=True, ge=0)
Embarked: Series[str] = pa.Field(nullable=True, str_length=1, isin=("S", "C", "Q"))
Title: Series[int] = pa.Field(nullable=False, isin=(1, 2, 3, 4, 5))
FamilySize: Series[int] = pa.Field(nullable=False, ge=0, le=15)
IsAlone: Series[int] = pa.Field(nullable=False, isin=(0, 1))
class Config:
strict = True
@pa.check_types
def load_dataset(filepath: str) -> DataFrame[TitanicSchema]:
df = pd.read_csv(filepath)
return df
@pa.check_types
def transform(df: DataFrame[TitanicSchema]) -> DataFrame[TransformedTitanicSchema]:
df["Sex"] = df["Sex"].str.match("male").map(int)
df["Title"] = df["Name"].str.extract(" ([A-Za-z]+)\.", expand=False)
df["Title"] = df["Title"].replace(
["Lady", "Countess", "Capt", "Col", "Don", "Dr", "Major", "Rev", "Sir", "Jonkheer", "Dona"], "Rare"
)
df["Title"] = df["Title"].replace("Mlle", "Miss")
df["Title"] = df["Title"].replace("Ms", "Miss")
df["Title"] = df["Title"].replace("Mme", "Mrs")
df["Title"] = df["Title"].map({"Mr": 1, "Miss": 2, "Mrs": 3, "Master": 4, "Rare": 5})
guess_ages = np.zeros((2, 3))
for i in range(2):
for j in range(3):
guess_df = df[(df["Sex"] == i) & (df["Pclass"] == j + 1)]["Age"].dropna()
age_guess = guess_df.median()
guess_ages[i, j] = int(age_guess / 0.5 + 0.5) * 0.5
for i in range(2):
for j in range(3):
df.loc[(df["Age"].isnull()) & (df["Sex"] == i) & (df["Pclass"] == j + 1), "Age"] = guess_ages[i, j]
df["Age"] = df["Age"].astype(int)
df.loc[df["Age"] <= 16, "Age"] = 0
df.loc[(df["Age"] > 16) & (df["Age"] <= 32), "Age"] = 1
df.loc[(df["Age"] > 32) & (df["Age"] <= 48), "Age"] = 2
df.loc[(df["Age"] > 48) & (df["Age"] <= 64), "Age"] = 3
df.loc[df["Age"] > 64, "Age"] = 4
df["FamilySize"] = df["SibSp"] + df["Parch"] + 1
df["IsAlone"] = df["FamilySize"].map(lambda x: 1 if x == 1 else 0)
df = df.drop(["Ticket", "Cabin", "PassengerId", "Name"], axis=1)
return df
def main():
df = load_dataset("./train.csv")
df = transform(df)
if __name__ == "__main__":
main()
終わりに
本記事ではデータフレームのバリデーション機能を提供するライブラリである pandera を紹介しました。
単純にソースコードの分量だけを見ると倍近くに増えており、コストなくバリデーションや型チェックができるわけではないので、短期間で破棄される前提のソースコードに対しては使う価値はないかもしれません。複数人で長期的にメンテナンスしていく必要のあるソースコードに対しては非常に効果的であるように感じました。
機械学習モデルがプロダクション環境で動くことが比較的当たり前になってきたことから dataclass や pydantic 等のスキーマ定義、タイプヒントを開発に導入している企業が増えている印象です。もしデータフレームにタイプヒントや型チェックが付けられなくて困っているようであれば、 pandera の利用を検討してみるのも一つの手かと思います。
その際に本記事がお役に立てば光栄です。最後までお読みいただきありがとうございました。
参考資料
リポジトリ
ドキュメント
日本語の紹介記事
Discussion