Django ORMでN+1問題を検知する方法
はじめに
これは Django Advent Calendar 2021 14日目の記事です(空いていたので飛び入りで)。
ORMを使うときに注意が必要なのが、N+1問題と呼ばれるものです。これは、外部キーを使ってJOINで結合すれば1回のSQLで全て取得できるものが、コードの書き方がまずくて1回 + N回(レコード数)のSQLが発行されるという問題です。その原理から、N+1問題ではなく、1+N問題と呼んだ方がいいという話もあります。
この問題を解消するために、Djangoではselect_related()というQuerySetのメソッドが提供されています。これに外部キーを指定することでSQLのJOINで別テーブルの内容を取ってきます。外部キーの先の外部キーも指定可能です。
一方で、N+1問題を検知するのは難しいです。定番は Django Debug Toolbar を使う方法です。アクセスすると、次のような画面が出ます。これを使ってN+1問題を検知することができます。
これは確かに便利ですが、画面を表示しないと確認できないので、どうしても確認漏れが出てしまいます。CIで検知できればベストでしょう。
CIで検知できる方法はないか。そう考えて色々試してみた結果、簡単に組み込めるN+1検知のコードを見つけることができました。
ただ強力すぎて使いづらいかもしれませんが・・・
N+1問題を検知するコード
次のコードを入れれば、N+1問題を検知できます。
from unittest import mock
import pytest
from django.db.models.fields.related_descriptors import ForwardManyToOneDescriptor
@pytest.fixture(autouse=True)
def mock_nplusone():
mock.patch.object(ForwardManyToOneDescriptor, "get_object", side_effect=Exception)
あるいは次のようにコンテキストマネージャとして利用する手もあります。
with mock.patch.object(ForwardManyToOneDescriptor, "get_object", side_effect=Exception):
mockの使い方が良くないのか、これでも検知できないものもあるようですが、原理的には ForwardManyToOneDescriptor.get_object()
を潰すことで対応できます。
N+1問題があると、次のようなメッセージが出て落ちます。この場合、 self.project.mode
の箇所でN+1問題が起きています。
apps/task/next_action/models.py:248: in current_mode
elif self.project.mode:
venv/lib/python3.9/site-packages/django/db/models/fields/related_descriptors.py:187: in __get__
rel_obj = self.get_object(instance)
/opt/homebrew/Cellar/python@3.9/3.9.9/Frameworks/Python.framework/Versions/3.9/lib/python3.9/unittest/mock.py:1092: in __call__
return self._mock_call(*args, **kwargs)
/opt/homebrew/Cellar/python@3.9/3.9.9/Frameworks/Python.framework/Versions/3.9/lib/python3.9/unittest/mock.py:1096: in _mock_call
return self._execute_mock_call(*args, **kwargs)
例えば、先ほどのエラーは次のように書いている箇所で起きていました(正確にはprintではないですが)。
setting = Settings.objects.first()
print(setting.mode)
これは setting.mode
を取得する時点でクエリが走ります。そのため次のように書くといいです。
setting = Setting.objects.select_related("mode").first()
print(setting.mode)
ちょっとめんどいですが、SQL発行回数は1つ減りました。よしとしましょう。
N+1問題が検知できる理由
なぜ ForwardManyToOneDescriptor.get_object()
を潰せばN+1問題が検知できるのでしょうか。それは、ForwardManyToOneDescriptor
の実装にあります。
def get_object(self, instance):
qs = self.get_queryset(instance=instance)
# Assuming the database enforces foreign keys, this won't fail.
return qs.get(self.field.get_reverse_related_filter(instance))
def __get__(self, instance, cls=None):
"""
Get the related instance through the forward relation.
With the example above, when getting ``child.parent``:
- ``self`` is the descriptor managing the ``parent`` attribute
- ``instance`` is the ``child`` instance
- ``cls`` is the ``Child`` class (we don't need it)
"""
if instance is None:
return self
# The related instance is loaded from the database and then cached
# by the field on the model instance state. It can also be pre-cached
# by the reverse accessor (ReverseOneToOneDescriptor).
try:
rel_obj = self.field.get_cached_value(instance)
except KeyError:
has_value = None not in self.field.get_local_related_value(instance)
ancestor_link = instance._meta.get_ancestor_link(self.field.model) if has_value else None
if ancestor_link and ancestor_link.is_cached(instance):
# An ancestor link will exist if this field is defined on a
# multi-table inheritance parent of the instance's class.
ancestor = ancestor_link.get_cached_value(instance)
# The value might be cached on an ancestor if the instance
# originated from walking down the inheritance chain.
rel_obj = self.field.get_cached_value(ancestor, default=None)
else:
rel_obj = None
if rel_obj is None and has_value:
rel_obj = self.get_object(instance)
remote_field = self.field.remote_field
# If this is a one-to-one relation, set the reverse accessor
# cache on the related object to the current instance to avoid
# an extra SQL query if it's accessed later on.
if not remote_field.multiple:
remote_field.set_cached_value(rel_obj, instance)
self.field.set_cached_value(instance, rel_obj)
if rel_obj is None and not self.field.null:
raise self.RelatedObjectDoesNotExist(
"%s has no %s." % (self.field.model.__name__, self.field.name)
)
else:
return rel_obj
なんとなくの理解ですが、フィールドに値がキャッシュされている場合はその値を使い、なければ get_object()
で QuerySet から値を取得しているのが読み取れます。このQuerySet から値を取得している箇所がN+1問題が起きる原因です。なので、ここを通ったときに例外を投げるようにすれば対処可能です。
おわりに
このコードを入れることで、N+1問題を減らすことができました。しかし、ForwardManyToOneDescriptor.get_object()
を書き換えてExceptionを出すようにした場合と比べると、出ないことが多いです。おそらくmockの作り方に問題がありそうですが、理由は不明です。また、prefetch_related不足など、他のパフォーマンス問題には適用できていません。
ただ、これを適用することで多くのN+1問題が検知できました。試してみてはどうでしょうか。
Discussion