🔺

RootModelのすすめ【Pydantic】

2024/12/25に公開

はじめに

こんにちは!Pydanticしてますか?
タイプヒント・バリデーション・シリアライズととにかく便利なPydanticですが、RootModelがかなり便利だったので紹介したいと思います!

https://docs.pydantic.dev/latest/api/root_model/

RootModelを使うと何ができるの

自前のクラスでリストや辞書をラップしたもの(コレクションオブジェクト)を直感的に作ることができます!

BaseModelの場合

from pydantic import BaseModel

class MyListInt(BaseModel):
    list_int: list[int]

リストにキーを付けていれるイメージ。

import json

data = json.loads("{ \"list_int\": [1, 2, 3] }") # ここが直感的でない
my_list_int = MyListInt(**data)
print(my_list_int)
# list_int=[1, 2, 3]

RootModelの場合

from pydantic import RootModel

class MyListInt(RootModel):
    root: list[int]

# ジェネリックを使った定義も可能
# class MyListInt(RootModel[list[int]]):
#    pass

リストをそのままいれることができるイメージ。

import json

data = json.loads("[1, 2, 3]") # ここが直感的
my_list_int = MyListInt(data)
print(my_list_int)
# root=[1, 2, 3]

これだけだと恩恵が感じられないので、複雑な例も見てみましょう

RootModelを使うと便利なところ

以下のようなチャット履歴を読み込む場合を考えます。roleuserassistantが入るものとします。

{
    "user_id": "123",
    "chat_history": [
        {"message": "Hello!", "role": "user"},
        {"message": "Hi there!", "role": "assistant"}
    ]
}

BaseModelを使うと不便

BaseModelを使うと次のようになるかと思います。

from enum import Enum
from pydantic import BaseModel


class Role(Enum):
    USER = "user"
    ASSISTANT = "assistant"


class Chat(BaseModel):
    message: str
    role: Role


class UserData(BaseModel):
    user_id: str
    chat_history: list[Chat]
import json

json_data = json.loads(
    """
        {
            "user_id": "123",
            "chat_history": [
                {"message": "Hello!", "role": "user"},
                {"message": "Hi there!", "role": "assistant"}
            ]
        }
    """
)
user_data = UserData(**json_data)

print(user_data)
# user_id='123' chat_history=[Chat(message='Hello!', role=<Role.USER: 'user'>), Chat(message='Hi there!', role=<Role.ASSISTANT: 'assistant'>)]

これは問題なく動作しますが、chat_historylist[Chat]としてしまうと、これ自体にメソッドをはやすことができず扱いづらくなってしまいます。

一方で、次のようにBaseModelを1つ増やした場合、扱いやすくはなりますがデータの受け渡し方が変わってしまいます。

from enum import Enum
from pydantic import BaseModel


class Role(Enum):
    USER = "user"
    ASSISTANT = "assistant"


class Chat(BaseModel):
    message: str
    role: Role


class ChatHistory(BaseModel): # 追加
    chat_history: list[Chat]


class UserData(BaseModel):
    user_id: str
    chat_history: ChatHistory

データがそのままだとエラーになります。

import json

json_data = json.loads(
    """
        {
            "user_id": "123",
            "chat_history": [
                {"message": "Hello!", "role": "user"},
                {"message": "Hi there!", "role": "assistant"}
            ]
        }
    """
)
user_data = UserData(**json_data)
# pydantic_core._pydantic_core.ValidationError: 1 validation error for UserData
# chat_history
#  Input should be a valid dictionary or instance of ChatHistory [type=model_type, input_value=[{'message': 'Hello!', 'r...', 'role': 'assistant'}], input_type=list]
#    For further information visit https://errors.pydantic.dev/2.9/v/model_type

print(user_data)

エラーを解消するにはデータを1段階ネストする必要があります。

import json

json_data = json.loads(
    """
        {
            "user_id": "123",
            "chat_history": {
                "chat_history": [
                    {"message": "Hello!", "role": "user"},
                    {"message": "Hi there!", "role": "assistant"}
                ]
            }
        }
    """
)
user_data = UserData(**json_data)

print(user_data)
# user_id='123' chat_history=ChatHistory(chat_history=[Chat(message='Hello!', role=<Role.USER: 'user'>), Chat(message='Hi there!', role=<Role.ASSISTANT: 'assistant'>)])

再度、jsonへシリアル化する際も成形する必要があります。(chat_historyの中にchat_historyがある)

print(user_data.model_dump_json())
# {"user_id":"123","chat_history":{"chat_history":[{"message":"Hello!","role":"user"},{"message":"Hi there!","role":"assistant"}]}}

validatorやserializerを駆使して何とかすることもできますが、労力がかかります。

RootModelを使うと便利

先ほどのChatHistoryをRootModelにすると、データをそのままにコレクションオブジェクトとして扱うことができます。

from enum import Enum
from pydantic import BaseModel, RootModel


class Role(Enum):
    USER = "user"
    ASSISTANT = "assistant"


class Chat(BaseModel):
    message: str
    role: Role


class ChatHistory(RootModel):
    root: list[Chat]


class UserData(BaseModel):
    user_id: str
    chat_history: ChatHistory
import json
json_data = json.loads('''
{
    "user_id": "123",
    "chat_history": [
        {"message": "Hello!", "role": "user"},
        {"message": "Hi there!", "role": "assistant"}
    ]
}
''')
user_data = UserData(**json_data)

print(user_data)
# user_id='123' chat_history=ChatHistory(root=[Chat(message='Hello!', role=<Role.USER: 'user'>), Chat(message='Hi there!', role=<Role.ASSISTANT: 'assistant'>)])

データにアクセスする際は、user_data.chat_history.rootとなることに注意してください。
再度、jsonへシリアル化する場合も元の形式を維持することができます。

print(user_data.model_dump_json())
# {"user_id":"123","chat_history":[{"message":"Hello!","role":"user"},{"message":"Hi there!","role":"assistant"}]}

おわりに

いかがだったでしょうか?
あまり使われていなそうなRootModelですが、痒いところに手が届いて便利ですよね。
みなさまのPydanticライフに少しでもお役に立てれば幸いです。

Discussion