🐍

Pydantic の BaseModel を継承したクラスに AsyncMock を渡す

2023/06/01に公開

例えば Pydantic を使ってこんな感じでクラスを定義します。

from pydantic import BaseModel


class Hello(BaseModel):
    async def say():
        print("hello")

class Greeter(BaseModel):
    hello: Hello
    
    async def greet(self):
        self.hello.say()

Greeter を初期化するときに Greeter(hello=Hello()) のようなコードを書きますが、テストで Hello クラスをモックしたい場合 AsyncMock を使います。

AsyncMock は、Python 3.8 以降で利用できる非同期関数をモックするためのクラスです。非同期関数を呼び出す await 式を用いて呼び出すことができます。

from unittest.mock import AsyncMock

mock = AsyncMock(Hello)
greeter = Greeter(hello=mock)

この時、コードだけを読むと AsyncMock が渡ってそうですが、なんと実際は MagicMock が渡っています。

print(mock)    # <AsyncMock spec='Hello' id='281473807317200'>
print(greeter) # hello=<MagicMock name='mock._copy_and_set_values()' id='281473803463568'>

なので次のコードを実行すると TypeError: object MagicMock can't be used in 'await' expression といった例外が発生して実行に失敗します。

import asyncio

async def main():
    await greeter.greet()

asyncio.run(main())

見事に筆者はこれにハマってしまい 1 時間くらい時間を取られました。

失敗するコードの全体
from pydantic import BaseModel

class Hello(BaseModel):
    async def say(self):
        print("hello")

class Greeter(BaseModel):
    hello: Hello

    async def greet(self):
        await self.hello.say()

from unittest.mock import AsyncMock

mock = AsyncMock(Hello)
greeter = Greeter(hello=mock)


print(mock)    # <AsyncMock spec='Hello' id='281473807317200'>
print(greeter) # hello=<MagicMock name='mock._copy_and_set_values()' id='281473803463568'>

import asyncio

async def main():
    await greeter.greet()

asyncio.run(main())

回避方法

How to mock a pydantic BaseClass? で解決策を発見しました。

from unittest.mock import AsyncMock

mock = AsyncMock(Hello)
+ mock._copy_and_set_values.return_value = mock
greeter = Greeter(hello=mock)

このように変更するだけでちゃんと AsyncMock のまま値が渡り、例外が発生しないので動くようになります。

<AsyncMock spec='Hello' id='281473004374544'>
hello=<AsyncMock spec='Hello' id='281473004374544'>

Discussion