💡

Django ORMでN+1問題を検知する方法

2021/12/20に公開

はじめに

これは Django Advent Calendar 2021 14日目の記事です(空いていたので飛び入りで)。

ORMを使うときに注意が必要なのが、N+1問題と呼ばれるものです。これは、外部キーを使ってJOINで結合すれば1回のSQLで全て取得できるものが、コードの書き方がまずくて1回 + N回(レコード数)のSQLが発行されるという問題です。その原理から、N+1問題ではなく、1+N問題と呼んだ方がいいという話もあります。

N+1問題は1+N問題 - Qiita

この問題を解消するために、Djangoではselect_related()というQuerySetのメソッドが提供されています。これに外部キーを指定することでSQLのJOINで別テーブルの内容を取ってきます。外部キーの先の外部キーも指定可能です。

一方で、N+1問題を検知するのは難しいです。定番は Django Debug Toolbar を使う方法です。アクセスすると、次のような画面が出ます。これを使ってN+1問題を検知することができます。

これは確かに便利ですが、画面を表示しないと確認できないので、どうしても確認漏れが出てしまいます。CIで検知できればベストでしょう。

CIで検知できる方法はないか。そう考えて色々試してみた結果、簡単に組み込めるN+1検知のコードを見つけることができました。

ただ強力すぎて使いづらいかもしれませんが・・・

N+1問題を検知するコード

次のコードを入れれば、N+1問題を検知できます。

from unittest import mock

import pytest
from django.db.models.fields.related_descriptors import ForwardManyToOneDescriptor

@pytest.fixture(autouse=True)
def mock_nplusone():
    mock.patch.object(ForwardManyToOneDescriptor, "get_object", side_effect=Exception)

あるいは次のようにコンテキストマネージャとして利用する手もあります。

with mock.patch.object(ForwardManyToOneDescriptor, "get_object", side_effect=Exception):

mockの使い方が良くないのか、これでも検知できないものもあるようですが、原理的には ForwardManyToOneDescriptor.get_object() を潰すことで対応できます。

N+1問題があると、次のようなメッセージが出て落ちます。この場合、 self.project.mode の箇所でN+1問題が起きています。

apps/task/next_action/models.py:248: in current_mode
    elif self.project.mode:
venv/lib/python3.9/site-packages/django/db/models/fields/related_descriptors.py:187: in __get__
    rel_obj = self.get_object(instance)
/opt/homebrew/Cellar/python@3.9/3.9.9/Frameworks/Python.framework/Versions/3.9/lib/python3.9/unittest/mock.py:1092: in __call__
    return self._mock_call(*args, **kwargs)
/opt/homebrew/Cellar/python@3.9/3.9.9/Frameworks/Python.framework/Versions/3.9/lib/python3.9/unittest/mock.py:1096: in _mock_call
    return self._execute_mock_call(*args, **kwargs)

例えば、先ほどのエラーは次のように書いている箇所で起きていました(正確にはprintではないですが)。

setting = Settings.objects.first()
print(setting.mode)

これは setting.mode を取得する時点でクエリが走ります。そのため次のように書くといいです。

setting = Setting.objects.select_related("mode").first()
print(setting.mode)

ちょっとめんどいですが、SQL発行回数は1つ減りました。よしとしましょう。

N+1問題が検知できる理由

なぜ ForwardManyToOneDescriptor.get_object() を潰せばN+1問題が検知できるのでしょうか。それは、ForwardManyToOneDescriptor の実装にあります。

    def get_object(self, instance):
        qs = self.get_queryset(instance=instance)
        # Assuming the database enforces foreign keys, this won't fail.
        return qs.get(self.field.get_reverse_related_filter(instance))

    def __get__(self, instance, cls=None):
        """
        Get the related instance through the forward relation.
        With the example above, when getting ``child.parent``:
        - ``self`` is the descriptor managing the ``parent`` attribute
        - ``instance`` is the ``child`` instance
        - ``cls`` is the ``Child`` class (we don't need it)
        """
        if instance is None:
            return self

        # The related instance is loaded from the database and then cached
        # by the field on the model instance state. It can also be pre-cached
        # by the reverse accessor (ReverseOneToOneDescriptor).
        try:
            rel_obj = self.field.get_cached_value(instance)
        except KeyError:
            has_value = None not in self.field.get_local_related_value(instance)
            ancestor_link = instance._meta.get_ancestor_link(self.field.model) if has_value else None
            if ancestor_link and ancestor_link.is_cached(instance):
                # An ancestor link will exist if this field is defined on a
                # multi-table inheritance parent of the instance's class.
                ancestor = ancestor_link.get_cached_value(instance)
                # The value might be cached on an ancestor if the instance
                # originated from walking down the inheritance chain.
                rel_obj = self.field.get_cached_value(ancestor, default=None)
            else:
                rel_obj = None
            if rel_obj is None and has_value:
                rel_obj = self.get_object(instance)
                remote_field = self.field.remote_field
                # If this is a one-to-one relation, set the reverse accessor
                # cache on the related object to the current instance to avoid
                # an extra SQL query if it's accessed later on.
                if not remote_field.multiple:
                    remote_field.set_cached_value(rel_obj, instance)
            self.field.set_cached_value(instance, rel_obj)

        if rel_obj is None and not self.field.null:
            raise self.RelatedObjectDoesNotExist(
                "%s has no %s." % (self.field.model.__name__, self.field.name)
            )
        else:
            return rel_obj

なんとなくの理解ですが、フィールドに値がキャッシュされている場合はその値を使い、なければ get_object() で QuerySet から値を取得しているのが読み取れます。このQuerySet から値を取得している箇所がN+1問題が起きる原因です。なので、ここを通ったときに例外を投げるようにすれば対処可能です。

おわりに

このコードを入れることで、N+1問題を減らすことができました。しかし、ForwardManyToOneDescriptor.get_object() を書き換えてExceptionを出すようにした場合と比べると、出ないことが多いです。おそらくmockの作り方に問題がありそうですが、理由は不明です。また、prefetch_related不足など、他のパフォーマンス問題には適用できていません。

ただ、これを適用することで多くのN+1問題が検知できました。試してみてはどうでしょうか。

Discussion