Lessons from My Previous Job #2: Reading Library Source Code

Published on 2022/06/14

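Introduction

Drawing on my experience at my previous job, I would like to share some of what I learned there. This is the second post in the series, following the previous one.

This time I will show how reading the source code of an existing library lets you extend its functionality more flexibly. Beginners often do nothing more than memorize the usage described in a library's documentation; that usage becomes a kind of "axiom", and they never dig any deeper. But a library has an implementation too, and digging past that assumption lets you understand its mechanisms and use it more flexibly.

I have come to believe that this habit of questioning premises that look self-evident matters not only for getting better at programming but for thinking in general, so I want to write it down.

Below are some techniques and implementations that came out of this way of thinking.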

Entity (QuerySet)

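Suppose we want soft deletion (logical deletion) for what clean architecture calls an Entity and Django calls a Model. With soft deletion, each row keeps a marker column recording whether it has been deleted, and only rows whose marker is unset are treated as existing data; the code below uses a nullable deleted_at timestamp as that marker. To support this by extending the existing Django Model, I implemented it as follows, so that objects and the rest can be used just like ordinary Django without thinking about soft deletion.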

mbook_model.py
from typing import List, Optional

from django.db import models

# QuerySet, BaseModel, and BookSeriesEnum are this project's own classes
# (shown later in this article, or omitted from it).


class MBookQuerySet(QuerySet):

    def get_by_id(self, id_: str) -> Optional['MBook']:
        try:
            return self.get(id=id_)
        except MBook.DoesNotExist:
            return None

    def filter_id_in(self, id_list: List[str]) -> 'MBookQuerySet':
        return self.filter(id__in=id_list)

    def filter_eq_id(self, id_: str) -> 'MBookQuerySet':
        return self.filter(id=id_)


class MBook(BaseModel):
    '''
    掲載単行本
    '''
    
    series = models.CharField(
        max_length=8192,
        choices=BookSeriesEnum.choices()
    )
    volume = models.CharField(max_length=8192)
    mgadgets = models.ManyToManyField(
        'dorapi.MGadget',
        related_name='mgadgets',
        related_query_name='mgadget',
        through='dorapi.GadgetBook',
    )

    objects = MBookQuerySet.as_soft_manager()
    object_all = MBookQuerySet.as_manager()

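The Django Model and its QuerySet live in a single file. As a rule, all ORM access goes only through the methods defined on MBookQuerySet (adding new ones whenever they are needed). The method names make the intent of each query explicit, and the exception handling for calls such as get, where the row may not exist, is hidden inside them.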

The QuerySet used here is an extension of Django's stock QuerySet:

queryset.py
from django.db import models
from datetime import datetime


class QuerySet(models.QuerySet):

    def as_soft_manager(cls):
        from .manager import SoftManager
        manager = SoftManager.from_queryset(cls)()
        manager._built_with_as_manager = True
        return manager

    as_soft_manager.queryset_only = True
    as_soft_manager = classmethod(as_soft_manager)  # type: ignore

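    def delete(self):
        # Soft delete: stamp deleted_at on every matched row instead of issuing a DELETE.
        # Unlike Django's QuerySet.delete(), this returns the list of affected objects.
        deleted_at = datetime.now()
        objs = list(self)
        for obj in objs:
            obj.deleted_at = deleted_at

        self.bulk_update(objs, ['deleted_at'])
        return objs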

The as_soft_manager part is borrowed from the implementation of django.db.models.QuerySet:

query.py
def as_manager(cls):
    # Address the circular dependency between `Queryset` and `Manager`.
    from django.db.models.manager import Manager

    manager = Manager.from_queryset(cls)()
    manager._built_with_as_manager = True
    return manager

as_manager.queryset_only = True
as_manager = classmethod(as_manager)
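The idiom in this excerpt (set an attribute on the plain function, then wrap it in classmethod) can be tried outside Django. The following is a minimal sketch, not Django's actual code: MiniManager and MiniQuerySet are hypothetical stand-ins, and from_queryset here only imitates, in simplified form, how a manager class is built by proxying the public queryset methods while skipping those flagged queryset_only.

```python
import inspect


class MiniManager:
    """Simplified stand-in for django.db.models.Manager (hypothetical)."""

    _queryset_class = None

    def get_queryset(self):
        return self._queryset_class([1, 2, 3, 4])

    @classmethod
    def from_queryset(cls, queryset_class):
        def make_proxy(name):
            def proxy(self, *args, **kwargs):
                return getattr(self.get_queryset(), name)(*args, **kwargs)
            proxy.__name__ = name
            return proxy

        attrs = {"_queryset_class": queryset_class}
        for name, func in inspect.getmembers(queryset_class, inspect.isfunction):
            queryset_only = getattr(func, "queryset_only", None)
            # Skip methods flagged queryset_only, and unflagged private ones.
            if queryset_only or (queryset_only is None and name.startswith("_")):
                continue
            attrs[name] = make_proxy(name)
        return type("ManagerFrom" + queryset_class.__name__, (cls,), attrs)


class MiniQuerySet:
    """Simplified stand-in for django.db.models.QuerySet (hypothetical)."""

    def __init__(self, rows):
        self.rows = list(rows)

    def filter_even(self):
        return MiniQuerySet(r for r in self.rows if r % 2 == 0)

    def delete(self):
        self.rows.clear()

    delete.queryset_only = True  # keep delete() off the manager

    def as_manager(cls):
        return MiniManager.from_queryset(cls)()

    # Flag the plain function first, then wrap it in classmethod.
    as_manager.queryset_only = True
    as_manager = classmethod(as_manager)


manager = MiniQuerySet.as_manager()
print(manager.filter_even().rows)  # rows returned through the proxied method
print(hasattr(manager, "delete"))  # flagged methods are not proxied
```

Swapping MiniManager for a subclass that overrides get_queryset is the same move as_soft_manager makes with SoftManager.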

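delete does not actually remove rows; it updates the deleted_at field, which makes it a soft delete. SoftManager, in turn, returns only the rows whose deleted_at is None: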

manager.py
from django.db import models


class SoftManager(models.Manager):

    def get_queryset(self):
        return self._queryset_class(model=self.model, using=self._db, hints=self._hints).filter(deleted_at=None)

BaseModel overrides delete in the same way:

base_model.py
from datetime import datetime
import uuid

from django.db import models


class BaseModel(models.Model):
    id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
    create_date = models.DateTimeField(auto_now_add=True)
    update_date = models.DateTimeField(auto_now=True)
    deleted_at = models.DateTimeField(default=None, null=True, blank=True)

    class Meta:
        abstract = True

    def delete(self):
        self.deleted_at = datetime.now()
        self.save()

    def hard_delete(self):
        super().delete()
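To make the behaviour concrete without a Django project, here is a minimal sketch of the same idea using only the standard library's sqlite3. The table name mbook and the helper functions are hypothetical; the point is only that a soft delete is an UPDATE of deleted_at, and that the filtered manager corresponds to a WHERE deleted_at IS NULL condition.

```python
import sqlite3
from datetime import datetime, timezone

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE mbook (id TEXT PRIMARY KEY, volume TEXT, deleted_at TEXT)"
)
conn.executemany(
    "INSERT INTO mbook (id, volume) VALUES (?, ?)",
    [("a", "1"), ("b", "2")],
)


def soft_delete(book_id):
    """Mark a row as deleted instead of removing it (cf. BaseModel.delete)."""
    now = datetime.now(timezone.utc).isoformat()
    conn.execute("UPDATE mbook SET deleted_at = ? WHERE id = ?", (now, book_id))


def live_ids():
    """Rows a soft manager would return: deleted_at IS NULL."""
    rows = conn.execute(
        "SELECT id FROM mbook WHERE deleted_at IS NULL ORDER BY id"
    )
    return [row[0] for row in rows]


def all_ids():
    """Rows an unfiltered manager would return, deleted or not."""
    return [row[0] for row in conn.execute("SELECT id FROM mbook ORDER BY id")]


soft_delete("a")
print(live_ids())  # only the row that was not soft-deleted
print(all_ids())   # both rows are still physically present
```

In the article's terms, live_ids plays the role of objects and all_ids the role of object_all.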

None of these extensions would have been possible without the idea of reading the library's source code.

Swagger

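Django REST framework can generate a Swagger (OpenAPI) schema automatically from your code, and the application used as the example here customizes that as well. To add query parameters, authentication headers, and the like to the generated document, I again extended the library with its internal implementation as a reference. Supporting the features the application added on its own makes it possible to authenticate and specify query parameters from the UI, so the API is easy to test.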

custom_schema.py
import warnings
from typing import Any, Dict

from rest_framework import exceptions
from rest_framework.fields import empty
from rest_framework.schemas.openapi import AutoSchema
from rest_framework.schemas.utils import is_list_view

# `serializers` is presumably the project's own module (it exposes BaseField,
# which rest_framework.serializers does not), so its import is omitted here.


class CustomSchema(AutoSchema):

    def get_tags(self, path, method):
        tag = getattr(self.view, 'TAG', [])

        # super().get_tags(path, method)
        return tag if len(tag) else ['/'.join(path.split('/')[1:3]) + '/']

    def get_operation(self, path, method):
        operation = super().get_operation(path, method)
        operation['security'] = self._get_securities(path, method)

        return operation

    def _get_securities(self, path, method):
        view = self.view

        schemas = []
        
        securities = [{key: [] for key in s.keys()} for s in schemas]
        for auth_class in view.authentication_classes:
            if hasattr(auth_class, 'security_schema'):
                securities.append({key: []
                                  for key in auth_class.security_schema.keys()})

        return securities

    def get_filter_parameters(self, path, method):

        if not self.allows_filters(path, method):
            return []

        if hasattr(self.view, 'filter_set_dict'):
            filter_set_class = self.view.get_filter_set_class()
            return filter_set_class.get_schema_operation_parameters()
        else:
            return []

    # def get_request_serializer(self, path, method):
    #     return self._get_serializer(
    #         path, method, 'request')

    # def get_response_serializer(self, path, method):
    #     return self._get_serializer(
    #         path, method, 'response')

    def get_request_body(self, path, method):
        """
        Need to override get_serializer
        """
        if method not in ('PUT', 'PATCH', 'POST'):
            return {}

        self.request_media_types = self.map_parsers(path, method)

        serializer = self._get_serializer(
            path, method, 'request')

        if not isinstance(serializer, serializers.Serializer):
            item_schema = {}
        else:
            item_schema = self._get_reference(serializer)
        return {
            'content': {
                ct: {'schema': item_schema}
                for ct in self.request_media_types
            }
        }

    def get_responses(self, path, method):
        """
        Need to override get_serializer
        """
        if method == 'DELETE':
            return {
                '204': {
                    'description': ''
                }
            }

        self.response_media_types = self.map_renderers(path, method)

        serializer = self._get_serializer(
            path, method, 'response')

        if not isinstance(serializer, serializers.Serializer):
            item_schema = {}
        else:
            item_schema = self._get_reference(serializer)

        if is_list_view(path, method, self.view):
            response_schema = {
                'type': 'array',
                'items': item_schema,
            }
            paginator = self.get_paginator()
            if paginator:
                response_schema = paginator.get_paginated_response_schema(
                    response_schema)
        else:
            response_schema = item_schema
        status_code = '201' if method == 'POST' else '200'

        return {
            status_code: {
                'content': {
                    ct: {'schema': response_schema}
                    for ct in self.response_media_types
                },
                # description is a mandatory property,
                # https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#responseObject
                # TODO: put something meaningful into it
                'description': ""
            }
        }

    def _get_serializer(self, path, method, type_=''):
        view = self.view

        # BaseViewSet checker
        if not hasattr(view, 'serializer_alias_dict'):
            return super().get_serializer(path, method)
        else:
            try:
                serializer_class = view.get_serializer_class(type_)
                return serializer_class()
            except exceptions.APIException:
                warnings.warn('{}.get_serializer() raised an exception during '
                              'schema generation. Serializer fields will not be '
                              'generated for {} {}.'
                              .format(view.__class__.__name__, method, path))
                return None

    def get_operation_id(self, path, method):
        """
        Needed to avoid warning for operation ID
        """
        method_name = getattr(self.view, 'action', method.lower())
        if is_list_view(path, method, self.view):
            action = 'list'
        elif method_name not in self.method_mapping:
            action = self._to_camel_case(method_name)
        else:
            action = self.method_mapping[method.lower()]

        name = self.view.__class__.__name__

        return action + name

    def map_field(self, field):

        if isinstance(field, serializers.BaseField):
            return field.get_default_schema()

        return super().map_field(field)

    def map_serializer(self, serializer):
        """
        Needed to override
        """
        # Assuming we have a valid serializer instance.
        required = []
        properties = {}
        

        for field in serializer.fields.values():
            if isinstance(field, serializers.HiddenField):
                continue

            if field.required:
                required.append(field.field_name)

            schema: Dict[str, Any] = self.map_field(field)
            if field.read_only:
                schema['readOnly'] = True
            if field.write_only:
                schema['writeOnly'] = True
            if field.allow_null:
                schema['nullable'] = True

            super_class_default_cond = field.default is not None and field.default != empty and not callable(field.default)
            custom_cond = not isinstance(field, serializers.BaseField)

            if super_class_default_cond and custom_cond:
                schema['default'] = field.default
            if field.help_text:
                schema['description'] = str(field.help_text)
            self.map_field_validators(field, schema)

            properties[field.field_name] = schema

        result = {
            'type': 'object',
            'properties': properties
        }
        if required:
            result['required'] = required

        return result
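The fallback in get_tags, used when the view defines no TAG, derives the tag from the first two path segments. A small sketch of that expression in isolation follows; the function name default_tag and the example paths are hypothetical.

```python
def default_tag(path):
    """Same expression as the fallback in CustomSchema.get_tags."""
    return '/'.join(path.split('/')[1:3]) + '/'


# List and detail endpoints under the same prefix end up with the same tag,
# so they are grouped together in the Swagger UI.
print(default_tag('/api/books/'))
print(default_tag('/api/books/{id}/'))
```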

Most of the CustomSchema code is borrowed from rest_framework.schemas.openapi.AutoSchema:

openapi.py
class AutoSchema(ViewInspector):

    def get_operation(self, path, method):
        operation = {}

        operation['operationId'] = self.get_operation_id(path, method)
        operation['description'] = self.get_description(path, method)

        parameters = []
        parameters += self.get_path_parameters(path, method)
        parameters += self.get_pagination_parameters(path, method)
        parameters += self.get_filter_parameters(path, method)
        operation['parameters'] = parameters

        request_body = self.get_request_body(path, method)
        if request_body:
            operation['requestBody'] = request_body
        operation['responses'] = self.get_responses(path, method)
        operation['tags'] = self.get_tags(path, method)

        return operation

    def _to_camel_case(self, snake_str):
        components = snake_str.split('_')
        # We capitalize the first letter of each component except the first one
        # with the 'title' method and join them together.
        return components[0] + ''.join(x.title() for x in components[1:])

    def get_filter_parameters(self, path, method):
        if not self.allows_filters(path, method):
            return []
        parameters = []
        for filter_backend in self.view.filter_backends:
            parameters += filter_backend().get_schema_operation_parameters(self.view)
        return parameters

    def get_request_body(self, path, method):
        if method not in ('PUT', 'PATCH', 'POST'):
            return {}

        self.request_media_types = self.map_parsers(path, method)

        serializer = self.get_request_serializer(path, method)

        if not isinstance(serializer, serializers.Serializer):
            item_schema = {}
        else:
            item_schema = self.get_reference(serializer)

        return {
            'content': {
                ct: {'schema': item_schema}
                for ct in self.request_media_types
            }
        }

    def get_responses(self, path, method):
        if method == 'DELETE':
            return {
                '204': {
                    'description': ''
                }
            }

        self.response_media_types = self.map_renderers(path, method)

        serializer = self.get_response_serializer(path, method)

        if not isinstance(serializer, serializers.Serializer):
            item_schema = {}
        else:
            item_schema = self.get_reference(serializer)

        if is_list_view(path, method, self.view):
            response_schema = {
                'type': 'array',
                'items': item_schema,
            }
            paginator = self.get_paginator()
            if paginator:
                response_schema = paginator.get_paginated_response_schema(response_schema)
        else:
            response_schema = item_schema
        status_code = '201' if method == 'POST' else '200'
        return {
            status_code: {
                'content': {
                    ct: {'schema': response_schema}
                    for ct in self.response_media_types
                },
                # description is a mandatory property,
                # https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#responseObject
                # TODO: put something meaningful into it
                'description': ""
            }
        }

    def get_serializer(self, path, method):
        view = self.view

        if not hasattr(view, 'get_serializer'):
            return None

        try:
            return view.get_serializer()
        except exceptions.APIException:
            warnings.warn('{}.get_serializer() raised an exception during '
                          'schema generation. Serializer fields will not be '
                          'generated for {} {}.'
                          .format(view.__class__.__name__, method, path))
            return None

    
    def get_operation_id(self, path, method):
        """
        Compute an operation ID from the view type and get_operation_id_base method.
        """
        method_name = getattr(self.view, 'action', method.lower())
        if is_list_view(path, method, self.view):
            action = 'list'
        elif method_name not in self.method_mapping:
            action = self._to_camel_case(method_name)
        else:
            action = self.method_mapping[method.lower()]

        name = self.get_operation_id_base(path, method, action)

        return action + name

    def map_field(self, field):

        # Nested Serializers, `many` or not.
        if isinstance(field, serializers.ListSerializer):
            return {
                'type': 'array',
                'items': self.map_serializer(field.child)
            }
        if isinstance(field, serializers.Serializer):
            data = self.map_serializer(field)
            data['type'] = 'object'
            return data

        # Related fields.
        if isinstance(field, serializers.ManyRelatedField):
            return {
                'type': 'array',
                'items': self.map_field(field.child_relation)
            }
        if isinstance(field, serializers.PrimaryKeyRelatedField):
            model = getattr(field.queryset, 'model', None)
            if model is not None:
                model_field = model._meta.pk
                if isinstance(model_field, models.AutoField):
                    return {'type': 'integer'}

        # ChoiceFields (single and multiple).
        # Q:
        # - Is 'type' required?
        # - can we determine the TYPE of a choicefield?
        if isinstance(field, serializers.MultipleChoiceField):
            return {
                'type': 'array',
                'items': self.map_choicefield(field)
            }

        if isinstance(field, serializers.ChoiceField):
            return self.map_choicefield(field)

        # ListField.
        if isinstance(field, serializers.ListField):
            mapping = {
                'type': 'array',
                'items': {},
            }
            if not isinstance(field.child, _UnvalidatedField):
                mapping['items'] = self.map_field(field.child)
            return mapping

        # DateField and DateTimeField type is string
        if isinstance(field, serializers.DateField):
            return {
                'type': 'string',
                'format': 'date',
            }

        if isinstance(field, serializers.DateTimeField):
            return {
                'type': 'string',
                'format': 'date-time',
            }

        # "Formats such as "email", "uuid", and so on, MAY be used even though undefined by this specification."
        # see: https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#data-types
        # see also: https://swagger.io/docs/specification/data-models/data-types/#string
        if isinstance(field, serializers.EmailField):
            return {
                'type': 'string',
                'format': 'email'
            }

        if isinstance(field, serializers.URLField):
            return {
                'type': 'string',
                'format': 'uri'
            }

        if isinstance(field, serializers.UUIDField):
            return {
                'type': 'string',
                'format': 'uuid'
            }

        if isinstance(field, serializers.IPAddressField):
            content = {
                'type': 'string',
            }
            if field.protocol != 'both':
                content['format'] = field.protocol
            return content

        if isinstance(field, serializers.DecimalField):
            if getattr(field, 'coerce_to_string', api_settings.COERCE_DECIMAL_TO_STRING):
                content = {
                    'type': 'string',
                    'format': 'decimal',
                }
            else:
                content = {
                    'type': 'number'
                }

            if field.decimal_places:
                content['multipleOf'] = float('.' + (field.decimal_places - 1) * '0' + '1')
            if field.max_whole_digits:
                content['maximum'] = int(field.max_whole_digits * '9') + 1
                content['minimum'] = -content['maximum']
            self._map_min_max(field, content)
            return content

        if isinstance(field, serializers.FloatField):
            content = {
                'type': 'number',
            }
            self._map_min_max(field, content)
            return content

        if isinstance(field, serializers.IntegerField):
            content = {
                'type': 'integer'
            }
            self._map_min_max(field, content)
            # 2147483647 is max for int32_size, so we use int64 for format
            if int(content.get('maximum', 0)) > 2147483647 or int(content.get('minimum', 0)) > 2147483647:
                content['format'] = 'int64'
            return content

        if isinstance(field, serializers.FileField):
            return {
                'type': 'string',
                'format': 'binary'
            }

        # Simplest cases, default to 'string' type:
        FIELD_CLASS_SCHEMA_TYPE = {
            serializers.BooleanField: 'boolean',
            serializers.JSONField: 'object',
            serializers.DictField: 'object',
            serializers.HStoreField: 'object',
        }
        return {'type': FIELD_CLASS_SCHEMA_TYPE.get(field.__class__, 'string')}

    def map_serializer(self, serializer):
        # Assuming we have a valid serializer instance.
        required = []
        properties = {}

        for field in serializer.fields.values():
            if isinstance(field, serializers.HiddenField):
                continue

            if field.required:
                required.append(field.field_name)

            schema = self.map_field(field)
            if field.read_only:
                schema['readOnly'] = True
            if field.write_only:
                schema['writeOnly'] = True
            if field.allow_null:
                schema['nullable'] = True
            if field.default is not None and field.default != empty and not callable(field.default):
                schema['default'] = field.default
            if field.help_text:
                schema['description'] = str(field.help_text)
            self.map_field_validators(field, schema)

            properties[field.field_name] = schema

        result = {
            'type': 'object',
            'properties': properties
        }
        if required:
            result['required'] = required

        return result

You can see that the additions are support for the security_schema attribute of the classes in view.authentication_classes, and support for FilterSet (which corresponds to the query string after ? in the URL).

security_schema is expected to have a shape like the following:

security_schema = {
    'BearerAuth': {
        'type': 'http',
        'scheme': 'bearer',
    }
}

This shape follows the OpenAPI specification:
https://swagger.io/specification/ (see the Operation Object / Security Requirement Object sections)
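As a sketch of how such a security_schema is consumed, the two loops from CustomSchema._get_securities and SchemaGenerator._set_security_schemes_from_views can be reproduced with plain classes. BearerAuthentication and PlainAuthentication are hypothetical stand-ins for authentication classes, not classes from Django REST framework.

```python
class BearerAuthentication:
    """Hypothetical authentication class that declares a security_schema."""
    security_schema = {
        'BearerAuth': {
            'type': 'http',
            'scheme': 'bearer',
        }
    }


class PlainAuthentication:
    """Hypothetical authentication class without a security_schema."""


def security_requirements(auth_classes):
    """Per-operation 'security' list: scheme names mapped to empty scope lists."""
    return [
        {key: [] for key in auth_class.security_schema}
        for auth_class in auth_classes
        if hasattr(auth_class, 'security_schema')
    ]


def security_schemes(auth_classes):
    """Document-wide components.securitySchemes: merged scheme definitions."""
    schemes = {}
    for auth_class in auth_classes:
        if hasattr(auth_class, 'security_schema'):
            schemes.update(auth_class.security_schema)
    return schemes


classes = [BearerAuthentication, PlainAuthentication]
print(security_requirements(classes))
print(security_schemes(classes))
```

The operation refers to a scheme by name only, while the definition of that scheme lives once under components.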

As for FilterSet, it looks like this:

base_filterset.py
class BaseFilterSet(metaclass=FilterSetMetaclass):

    @classmethod
    def get_schema_operation_parameters(cls):

        filters = []

        for k, f in getattr(cls, '_filters'):  # noqa
            filters += f.get_schema(k)

        return filters

The above is partially abridged. Each f inherits from Filter, shown next:

base_filter.py
@dataclass
class Filter(ABC):

    def get_schema(self, key):
        return [
            {
                'name': key,
                'required': False,
                'in': 'query',
                'description': self.get_filter_desc(),
                'schema': {
                    'type': 'string',
                },
            },
        ]

This is also partially abridged, and it too follows the OpenAPI specification.
https://swagger.io/specification/ (see the Components Object / Parameter Object sections)
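A minimal runnable sketch of how a concrete filter could plug into this follows. It assumes get_filter_desc is an abstract method on Filter (the abridged original does not show it), and ContainsFilter and schema_parameters are hypothetical names introduced only for illustration.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass


@dataclass
class Filter(ABC):
    @abstractmethod
    def get_filter_desc(self) -> str:
        """Human-readable description shown in the Swagger UI."""

    def get_schema(self, key):
        # One OpenAPI Parameter Object per query key.
        return [{
            'name': key,
            'required': False,
            'in': 'query',
            'description': self.get_filter_desc(),
            'schema': {'type': 'string'},
        }]


@dataclass
class ContainsFilter(Filter):
    """Hypothetical concrete filter: substring match on one field."""
    field: str

    def get_filter_desc(self) -> str:
        return f'Substring match on {self.field}'


def schema_parameters(filters):
    """Flatten {query key: Filter} into a list of Parameter Objects."""
    parameters = []
    for key, f in filters.items():
        parameters += f.get_schema(key)
    return parameters


params = schema_parameters({'volume': ContainsFilter(field='volume')})
print(params[0]['name'], params[0]['in'], params[0]['description'])
```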

The same goes for SchemaGenerator.

schema_generator.py
from rest_framework.schemas import openapi


class SchemaGenerator(openapi.SchemaGenerator):
    """
    SchemaGenerator
    https://github.com/encode/django-rest-framework/blob/master/rest_framework/schemas/openapi.py
    """

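    def get_schema(self, request=None, public=False):
        schema = super().get_schema(request, public)

        # The parent only adds 'components' when at least one component schema
        # exists, so create the key if it is missing.
        components = schema.setdefault('components', {})
        components['securitySchemes'] = self.get_security_schemes()

        return schema

    def get_security_schemes(self, request=None, public=False):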
        schemes = {}

        paths, view_endpoints = self._get_paths_and_endpoints(request)
        if not paths:
            return schemes

        self._set_security_schemes_from_views(view_endpoints, schemes)
        self._set_security_schemes_from_headers(schemes)

        return schemes

    def _set_security_schemes_from_headers(self, schemes):
        pass

    def _set_security_schemes_from_views(self, view_endpoints, schemes):

        for _, _, view in view_endpoints:
            for auth_class in view.authentication_classes:
                if hasattr(auth_class, 'security_schema'):
                    schemes.update(auth_class.security_schema)

This likewise extends rest_framework.schemas.openapi:

openapi.py
class SchemaGenerator(BaseSchemaGenerator):

    def get_schema(self, request=None, public=False):
        """
        Generate a OpenAPI schema.
        """
        self._initialise_endpoints()
        components_schemas = {}

        # Iterate endpoints generating per method path operations.
        paths = {}
        _, view_endpoints = self._get_paths_and_endpoints(None if public else request)
        for path, method, view in view_endpoints:
            if not self.has_view_permissions(path, method, view):
                continue

            operation = view.schema.get_operation(path, method)
            components = view.schema.get_components(path, method)
            for k in components.keys():
                if k not in components_schemas:
                    continue
                if components_schemas[k] == components[k]:
                    continue
                warnings.warn('Schema component "{}" has been overriden with a different value.'.format(k))

            components_schemas.update(components)

            # Normalise path for any provided mount url.
            if path.startswith('/'):
                path = path[1:]
            path = urljoin(self.url or '/', path)

            paths.setdefault(path, {})
            paths[path][method.lower()] = operation

        self.check_duplicate_operation_id(paths)

        # Compile final schema.
        schema = {
            'openapi': '3.0.2',
            'info': self.get_info(),
            'paths': paths,
        }

        if len(components_schemas) > 0:
            schema['components'] = {
                'schemas': components_schemas
            }

        return schema

These too follow the OpenAPI specification:
https://swagger.io/specification/ (see the Components Object / Security Scheme Object sections)
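The two customizations have to agree with each other: the OpenAPI specification requires every name used in an operation's security list to be declared under components.securitySchemes. The following sketch checks that on a generated schema dict. The function name and the sample schema are hypothetical, and it assumes every entry under a path is an operation, as in the get_schema excerpt above.

```python
def undeclared_security_names(schema):
    """Return names used in operations' 'security' but missing from components."""
    declared = set(schema.get('components', {}).get('securitySchemes', {}))
    used = set()
    for operations in schema.get('paths', {}).values():
        for operation in operations.values():
            for requirement in operation.get('security', []):
                used.update(requirement)
    return sorted(used - declared)


schema = {
    'openapi': '3.0.2',
    'info': {'title': 'example', 'version': ''},
    'paths': {
        '/api/books/': {
            'get': {
                'operationId': 'listBooks',
                'security': [{'BearerAuth': []}],
                'responses': {},
            },
        },
    },
    'components': {
        'securitySchemes': {
            'BearerAuth': {'type': 'http', 'scheme': 'bearer'},
        },
    },
}

print(undeclared_security_names(schema))  # empty when both sides agree
```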

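Conclusion

Reading a library's internal implementation makes flexible extensions possible that go beyond the intended usage. This approach amounts to content coupling: it depends heavily on internals that the library is free to change, so it may not be something to recommend without reservation. Even so, it can yield benefits beyond what you get from treating the library as a black box and sticking to the intended usage. The idea of caring about what is inside did not seem at all obvious to me, which is why I wanted to share it.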
