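Lessons from My Previous Job #2: Reading Library Source Code
Introduction
Drawing on my experience at my previous job, I would like to share a few things I learned along the way. This is the second installment, following the previous post.
This time I will show how reading the source code of an existing library lets you extend its functionality more flexibly. For beginners, working with a library often comes down to memorizing the usage described in its documentation; that usage becomes a kind of "axiom", and they rarely dig any deeper. But a library has an implementation too, and digging past that premise lets you understand how it works and use it more flexibly.
I cannot help feeling that this habit of questioning premises that seem self-evident matters not only for getting better at programming but for how we think in general, so I want to write it down.
Below I describe some techniques and implementations that came out of that way of looking at things.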
Entity (QuerySet)
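Suppose you want to implement so-called logical (soft) deletion on what Clean Architecture calls an Entity and Django calls a Model. With logical deletion, a column records whether the row has been deleted (in the code below, a deleted_at timestamp that stays unset for live rows), and only rows that are not flagged as deleted are treated as existing data. To support this by extending the existing Django Model, I implemented it as follows, so that objects and the rest of the API can be used just like ordinary Django without having to think about logical deletion.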
from typing import List, Optional

from django.db import models

# QuerySet and BaseModel are the project's own base classes (shown later in this post).


class MBookQuerySet(QuerySet):
    def get_by_id(self, id_: str) -> Optional['MBook']:
        try:
            return self.get(id=id_)
        except MBook.DoesNotExist:
            return None

    def filter_id_in(self, id_list: List[str]) -> 'MBookQuerySet':
        return self.filter(id__in=id_list)

    def filter_eq_id(self, id_: str) -> 'MBookQuerySet':
        return self.filter(id=id_)


class MBook(BaseModel):
    '''
    Book volume in which a gadget appears
    '''
    series = models.CharField(
        max_length=8192,
        choices=BookSeriesEnum.choices()
    )
    volume = models.CharField(max_length=8192)
    mgadgets = models.ManyToManyField(
        'dorapi.MGadget',
        related_name='mgadgets',
        related_query_name='mgadget',
        through='dorapi.GadgetBook',
    )

    # Default manager hides soft-deleted rows; object_all exposes every row.
    objects = MBookQuerySet.as_soft_manager()
    object_all = MBookQuerySet.as_manager()
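The Django Model and its QuerySet live in a single file. As a rule, all ORM access goes through the methods defined on MBookQuerySet (adding new ones whenever a new operation is needed). The method names state what each query means, and the exception handling for calls such as get, which can fail when no row exists, is hidden from callers.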
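To make the idea concrete without a database, here is a minimal sketch in plain Python. InMemoryBookQuery and BookNotFound are hypothetical stand-ins for MBookQuerySet and MBook.DoesNotExist; they are not part of Django or of the original project. The point is only the shape: a named wrapper turns "raises when missing" into "returns None", and filter-style methods return the same query type so calls can be chained.

```python
from typing import Dict, List, Optional


class BookNotFound(Exception):
    """Hypothetical stand-in for MBook.DoesNotExist (illustration only)."""


class InMemoryBookQuery:
    """Hypothetical stand-in for MBookQuerySet, backed by a plain dict."""

    def __init__(self, rows: Dict[str, str]) -> None:
        self._rows = rows

    def get(self, id_: str) -> str:
        # Mirrors the ORM's get(): raises when nothing matches.
        try:
            return self._rows[id_]
        except KeyError:
            raise BookNotFound(id_) from None

    def get_by_id(self, id_: str) -> Optional[str]:
        # Named wrapper: callers receive None instead of handling the exception.
        try:
            return self.get(id_)
        except BookNotFound:
            return None

    def filter_id_in(self, id_list: List[str]) -> "InMemoryBookQuery":
        # Returns the same query type, so calls can be chained.
        kept = {k: v for k, v in self._rows.items() if k in id_list}
        return InMemoryBookQuery(kept)


books = InMemoryBookQuery({"b1": "Volume 1", "b2": "Volume 2"})
found = books.get_by_id("b1")
missing = books.get_by_id("zzz")
chained = books.filter_id_in(["b2", "zzz"]).get_by_id("b2")
```

Callers only ever see named methods and Optional results, which is what keeps the exception handling out of the rest of the application.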
The QuerySet used here is an extension of Django's standard QuerySet:
from datetime import datetime

from django.db import models


class QuerySet(models.QuerySet):
    def as_soft_manager(cls):
        # Same shape as models.QuerySet.as_manager, but built on SoftManager.
        from .manager import SoftManager
        manager = SoftManager.from_queryset(cls)()
        manager._built_with_as_manager = True
        return manager
    as_soft_manager.queryset_only = True
    as_soft_manager = classmethod(as_soft_manager)  # type: ignore

    def delete(self):
        # Soft delete: stamp deleted_at on every row instead of removing it.
        deleted_at = datetime.now()
        objs = list(self)
        for obj in objs:
            obj.deleted_at = deleted_at
        self.bulk_update(objs, ['deleted_at'])
        return objs
The as_soft_manager part is borrowed from the implementation of django.db.models.QuerySet:
    def as_manager(cls):
        # Address the circular dependency between `Queryset` and `Manager`.
        from django.db.models.manager import Manager
        manager = Manager.from_queryset(cls)()
        manager._built_with_as_manager = True
        return manager
    as_manager.queryset_only = True
    as_manager = classmethod(as_manager)
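The pattern above (flag the function with queryset_only, wrap it in classmethod, and let from_queryset build a manager class) can be sketched in plain Python. This is a deliberately simplified model of the idea, not Django's actual implementation: MiniQuerySet, MiniManager, and manager_from_queryset are hypothetical names, and the copying rule is reduced to "skip methods flagged queryset_only, and skip private names by default".

```python
class MiniQuerySet:
    """Hypothetical, heavily simplified stand-in for a QuerySet."""

    def __init__(self, rows):
        self.rows = list(rows)

    def filter_even(self):
        return MiniQuerySet(r for r in self.rows if r % 2 == 0)

    def count(self):
        return len(self.rows)

    def as_manager(cls):
        return manager_from_queryset(cls)()

    # Flag first, wrap second: the attribute is set on the plain function.
    as_manager.queryset_only = True
    as_manager = classmethod(as_manager)


class MiniManager:
    """Hypothetical manager base; generated subclasses set queryset_class."""

    queryset_class = MiniQuerySet
    source = (1, 2, 3, 4)

    def get_queryset(self):
        return self.queryset_class(self.source)


def manager_from_queryset(queryset_class, base=MiniManager):
    """Build a manager class that proxies the queryset's public methods."""

    def make_proxy(name):
        def proxy(self, *args, **kwargs):
            return getattr(self.get_queryset(), name)(*args, **kwargs)
        proxy.__name__ = name
        return proxy

    attrs = {"queryset_class": queryset_class}
    for name, member in vars(queryset_class).items():
        func = getattr(member, "__func__", member)  # unwrap classmethod
        if not callable(func):
            continue
        queryset_only = getattr(func, "queryset_only", None)
        # Skip methods flagged queryset_only, and private ones by default.
        if queryset_only or (queryset_only is None and name.startswith("_")):
            continue
        attrs[name] = make_proxy(name)
    return type(f"ManagerFrom{queryset_class.__name__}", (base,), attrs)


manager = MiniQuerySet.as_manager()
evens = manager.filter_even().count()
has_as_manager = hasattr(manager, "as_manager")
```

In this sketch the generated manager exposes filter_even and count but not as_manager itself, which is the role the queryset_only flag plays in the borrowed code.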
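delete does not actually remove rows; it updates the deleted_at field, which makes it a logical delete. SoftManager, in turn, excludes rows whose deleted_at is set from the default queryset: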
from django.db import models


class SoftManager(models.Manager):
    def get_queryset(self):
        # Only rows that have not been soft-deleted are visible by default.
        return self._queryset_class(
            model=self.model, using=self._db, hints=self._hints
        ).filter(deleted_at=None)
BaseModel does the same:
from datetime import datetime
import uuid

from django.db import models


class BaseModel(models.Model):
    id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
    create_date = models.DateTimeField(auto_now_add=True)
    update_date = models.DateTimeField(auto_now=True)
    deleted_at = models.DateTimeField(default=None, null=True, blank=True)

    class Meta:
        abstract = True

    def delete(self):
        self.deleted_at = datetime.now()
        self.save()

    def hard_delete(self):
        super().delete()
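The division of labor between objects, object_all, delete, and hard_delete can be pictured with a small in-memory model. Row and Table below are hypothetical and exist only for this illustration; no Django is involved.

```python
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import List, Optional


@dataclass
class Row:
    name: str
    deleted_at: Optional[datetime] = None


@dataclass
class Table:
    """Hypothetical in-memory table standing in for a database table."""

    rows: List[Row] = field(default_factory=list)

    def all_rows(self) -> List[Row]:
        # Counterpart of object_all: soft-deleted rows are included.
        return list(self.rows)

    def live_rows(self) -> List[Row]:
        # Counterpart of SoftManager.get_queryset(): deleted_at must be unset.
        return [r for r in self.rows if r.deleted_at is None]

    def soft_delete(self, row: Row) -> None:
        # Counterpart of BaseModel.delete(): stamp the row, keep it stored.
        row.deleted_at = datetime.now(timezone.utc)

    def hard_delete(self, row: Row) -> None:
        # Counterpart of BaseModel.hard_delete(): remove the row for real.
        self.rows.remove(row)


table = Table([Row("vol. 1"), Row("vol. 2"), Row("vol. 3")])
table.soft_delete(table.rows[0])
table.hard_delete(table.rows[2])
live = [r.name for r in table.live_rows()]
stored = [r.name for r in table.all_rows()]
```

After these two calls, the soft-deleted row is still stored but no longer visible through the default view, while the hard-deleted row is gone from both.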
These extensions only become possible once you think to read the library's source code.
Swagger
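Django REST framework can generate a Swagger (OpenAPI) schema automatically from your code, and the application I am using as an example customizes that as well. To display things such as query parameters and authentication headers, I again extended the library with its internal implementation as a reference. Making the schema aware of features the application added on its own means you can authenticate and specify query parameters directly from the UI, which makes the API easy to test.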
import warnings
from typing import Any, Dict

from rest_framework import exceptions
from rest_framework.fields import empty
from rest_framework.schemas.openapi import AutoSchema
from rest_framework.schemas.utils import is_list_view

# `serializers` is assumed to be the project's own module here: besides the
# usual rest_framework.serializers names it must expose BaseField.


class CustomSchema(AutoSchema):

    def get_tags(self, path, method):
        tag = getattr(self.view, 'TAG', [])
        # super().get_tags(path, method)
        return tag if len(tag) else ['/'.join(path.split('/')[1:3]) + '/']

    def get_operation(self, path, method):
        operation = super().get_operation(path, method)
        operation['security'] = self._get_securities(path, method)
        return operation

    def _get_securities(self, path, method):
        view = self.view
        schemas = []
        securities = [{key: [] for key in s.keys()} for s in schemas]
        for auth_class in view.authentication_classes:
            if hasattr(auth_class, 'security_schema'):
                securities.append({key: [] for key in auth_class.security_schema.keys()})
        return securities

    def get_filter_parameters(self, path, method):
        if not self.allows_filters(path, method):
            return []
        if hasattr(self.view, 'filter_set_dict'):
            filter_set_class = self.view.get_filter_set_class()
            return filter_set_class.get_schema_operation_parameters()
        else:
            return []

    # def get_request_serializer(self, path, method):
    #     return self._get_serializer(
    #         path, method, 'request')

    # def get_response_serializer(self, path, method):
    #     return self._get_serializer(
    #         path, method, 'response')

    def get_request_body(self, path, method):
        """
        Need to override get_serializer
        """
        if method not in ('PUT', 'PATCH', 'POST'):
            return {}

        self.request_media_types = self.map_parsers(path, method)

        serializer = self._get_serializer(
            path, method, 'request')

        if not isinstance(serializer, serializers.Serializer):
            item_schema = {}
        else:
            item_schema = self._get_reference(serializer)

        return {
            'content': {
                ct: {'schema': item_schema}
                for ct in self.request_media_types
            }
        }

    def get_responses(self, path, method):
        """
        Need to override get_serializer
        """
        if method == 'DELETE':
            return {
                '204': {
                    'description': ''
                }
            }

        self.response_media_types = self.map_renderers(path, method)

        serializer = self._get_serializer(
            path, method, 'response')

        if not isinstance(serializer, serializers.Serializer):
            item_schema = {}
        else:
            item_schema = self._get_reference(serializer)

        if is_list_view(path, method, self.view):
            response_schema = {
                'type': 'array',
                'items': item_schema,
            }
            paginator = self.get_paginator()
            if paginator:
                response_schema = paginator.get_paginated_response_schema(
                    response_schema)
        else:
            response_schema = item_schema
        status_code = '201' if method == 'POST' else '200'
        return {
            status_code: {
                'content': {
                    ct: {'schema': response_schema}
                    for ct in self.response_media_types
                },
                # description is a mandatory property,
                # https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#responseObject
                # TODO: put something meaningful into it
                'description': ""
            }
        }

    def _get_serializer(self, path, method, type_=''):
        view = self.view

        # BaseViewSet checker
        if not hasattr(view, 'serializer_alias_dict'):
            return super().get_serializer(path, method)
        else:
            try:
                serializer_class = view.get_serializer_class(type_)
                return serializer_class()
            except exceptions.APIException:
                warnings.warn('{}.get_serializer() raised an exception during '
                              'schema generation. Serializer fields will not be '
                              'generated for {} {}.'
                              .format(view.__class__.__name__, method, path))
                return None

    def get_operation_id(self, path, method):
        """
        Needed to avoid warning for operation ID
        """
        method_name = getattr(self.view, 'action', method.lower())
        if is_list_view(path, method, self.view):
            action = 'list'
        elif method_name not in self.method_mapping:
            action = self._to_camel_case(method_name)
        else:
            action = self.method_mapping[method.lower()]

        name = self.view.__class__.__name__

        return action + name

    def map_field(self, field):
        if isinstance(field, serializers.BaseField):
            return field.get_default_schema()
        return super().map_field(field)

    def map_serializer(self, serializer):
        """
        Needed to override
        """
        # Assuming we have a valid serializer instance.
        required = []
        properties = {}

        for field in serializer.fields.values():
            if isinstance(field, serializers.HiddenField):
                continue

            if field.required:
                required.append(field.field_name)

            schema: Dict[str, Any] = self.map_field(field)
            if field.read_only:
                schema['readOnly'] = True
            if field.write_only:
                schema['writeOnly'] = True
            if field.allow_null:
                schema['nullable'] = True
            super_class_default_cond = field.default is not None and field.default != empty and not callable(field.default)
            custom_cond = not isinstance(field, serializers.BaseField)
            if super_class_default_cond and custom_cond:
                schema['default'] = field.default
            if field.help_text:
                schema['description'] = str(field.help_text)
            self.map_field_validators(field, schema)

            properties[field.field_name] = schema

        result = {
            'type': 'object',
            'properties': properties
        }
        if required:
            result['required'] = required

        return result
Most of this code is borrowed from rest_framework.schemas.openapi.AutoSchema:
class AutoSchema(ViewInspector):

    def get_operation(self, path, method):
        operation = {}

        operation['operationId'] = self.get_operation_id(path, method)
        operation['description'] = self.get_description(path, method)

        parameters = []
        parameters += self.get_path_parameters(path, method)
        parameters += self.get_pagination_parameters(path, method)
        parameters += self.get_filter_parameters(path, method)
        operation['parameters'] = parameters

        request_body = self.get_request_body(path, method)
        if request_body:
            operation['requestBody'] = request_body
        operation['responses'] = self.get_responses(path, method)
        operation['tags'] = self.get_tags(path, method)

        return operation

    def _to_camel_case(self, snake_str):
        components = snake_str.split('_')
        # We capitalize the first letter of each component except the first one
        # with the 'title' method and join them together.
        return components[0] + ''.join(x.title() for x in components[1:])

    def get_filter_parameters(self, path, method):
        if not self.allows_filters(path, method):
            return []
        parameters = []
        for filter_backend in self.view.filter_backends:
            parameters += filter_backend().get_schema_operation_parameters(self.view)
        return parameters

    def get_request_body(self, path, method):
        if method not in ('PUT', 'PATCH', 'POST'):
            return {}

        self.request_media_types = self.map_parsers(path, method)

        serializer = self.get_request_serializer(path, method)

        if not isinstance(serializer, serializers.Serializer):
            item_schema = {}
        else:
            item_schema = self.get_reference(serializer)

        return {
            'content': {
                ct: {'schema': item_schema}
                for ct in self.request_media_types
            }
        }

    def get_responses(self, path, method):
        if method == 'DELETE':
            return {
                '204': {
                    'description': ''
                }
            }

        self.response_media_types = self.map_renderers(path, method)

        serializer = self.get_response_serializer(path, method)

        if not isinstance(serializer, serializers.Serializer):
            item_schema = {}
        else:
            item_schema = self.get_reference(serializer)

        if is_list_view(path, method, self.view):
            response_schema = {
                'type': 'array',
                'items': item_schema,
            }
            paginator = self.get_paginator()
            if paginator:
                response_schema = paginator.get_paginated_response_schema(response_schema)
        else:
            response_schema = item_schema
        status_code = '201' if method == 'POST' else '200'
        return {
            status_code: {
                'content': {
                    ct: {'schema': response_schema}
                    for ct in self.response_media_types
                },
                # description is a mandatory property,
                # https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#responseObject
                # TODO: put something meaningful into it
                'description': ""
            }
        }

    def get_serializer(self, path, method):
        view = self.view

        if not hasattr(view, 'get_serializer'):
            return None

        try:
            return view.get_serializer()
        except exceptions.APIException:
            warnings.warn('{}.get_serializer() raised an exception during '
                          'schema generation. Serializer fields will not be '
                          'generated for {} {}.'
                          .format(view.__class__.__name__, method, path))
            return None

    def get_operation_id(self, path, method):
        """
        Compute an operation ID from the view type and get_operation_id_base method.
        """
        method_name = getattr(self.view, 'action', method.lower())
        if is_list_view(path, method, self.view):
            action = 'list'
        elif method_name not in self.method_mapping:
            action = self._to_camel_case(method_name)
        else:
            action = self.method_mapping[method.lower()]

        name = self.get_operation_id_base(path, method, action)

        return action + name

    def map_field(self, field):

        # Nested Serializers, `many` or not.
        if isinstance(field, serializers.ListSerializer):
            return {
                'type': 'array',
                'items': self.map_serializer(field.child)
            }
        if isinstance(field, serializers.Serializer):
            data = self.map_serializer(field)
            data['type'] = 'object'
            return data

        # Related fields.
        if isinstance(field, serializers.ManyRelatedField):
            return {
                'type': 'array',
                'items': self.map_field(field.child_relation)
            }
        if isinstance(field, serializers.PrimaryKeyRelatedField):
            model = getattr(field.queryset, 'model', None)
            if model is not None:
                model_field = model._meta.pk
                if isinstance(model_field, models.AutoField):
                    return {'type': 'integer'}

        # ChoiceFields (single and multiple).
        # Q:
        # - Is 'type' required?
        # - can we determine the TYPE of a choicefield?
        if isinstance(field, serializers.MultipleChoiceField):
            return {
                'type': 'array',
                'items': self.map_choicefield(field)
            }

        if isinstance(field, serializers.ChoiceField):
            return self.map_choicefield(field)

        # ListField.
        if isinstance(field, serializers.ListField):
            mapping = {
                'type': 'array',
                'items': {},
            }
            if not isinstance(field.child, _UnvalidatedField):
                mapping['items'] = self.map_field(field.child)
            return mapping

        # DateField and DateTimeField type is string
        if isinstance(field, serializers.DateField):
            return {
                'type': 'string',
                'format': 'date',
            }

        if isinstance(field, serializers.DateTimeField):
            return {
                'type': 'string',
                'format': 'date-time',
            }

        # "Formats such as "email", "uuid", and so on, MAY be used even though undefined by this specification."
        # see: https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#data-types
        # see also: https://swagger.io/docs/specification/data-models/data-types/#string
        if isinstance(field, serializers.EmailField):
            return {
                'type': 'string',
                'format': 'email'
            }

        if isinstance(field, serializers.URLField):
            return {
                'type': 'string',
                'format': 'uri'
            }

        if isinstance(field, serializers.UUIDField):
            return {
                'type': 'string',
                'format': 'uuid'
            }

        if isinstance(field, serializers.IPAddressField):
            content = {
                'type': 'string',
            }
            if field.protocol != 'both':
                content['format'] = field.protocol
            return content

        if isinstance(field, serializers.DecimalField):
            if getattr(field, 'coerce_to_string', api_settings.COERCE_DECIMAL_TO_STRING):
                content = {
                    'type': 'string',
                    'format': 'decimal',
                }
            else:
                content = {
                    'type': 'number'
                }

            if field.decimal_places:
                content['multipleOf'] = float('.' + (field.decimal_places - 1) * '0' + '1')
            if field.max_whole_digits:
                content['maximum'] = int(field.max_whole_digits * '9') + 1
                content['minimum'] = -content['maximum']
            self._map_min_max(field, content)
            return content

        if isinstance(field, serializers.FloatField):
            content = {
                'type': 'number',
            }
            self._map_min_max(field, content)
            return content

        if isinstance(field, serializers.IntegerField):
            content = {
                'type': 'integer'
            }
            self._map_min_max(field, content)
            # 2147483647 is max for int32_size, so we use int64 for format
            if int(content.get('maximum', 0)) > 2147483647 or int(content.get('minimum', 0)) > 2147483647:
                content['format'] = 'int64'
            return content

        if isinstance(field, serializers.FileField):
            return {
                'type': 'string',
                'format': 'binary'
            }

        # Simplest cases, default to 'string' type:
        FIELD_CLASS_SCHEMA_TYPE = {
            serializers.BooleanField: 'boolean',
            serializers.JSONField: 'object',
            serializers.DictField: 'object',
            serializers.HStoreField: 'object',
        }
        return {'type': FIELD_CLASS_SCHEMA_TYPE.get(field.__class__, 'string')}

    def map_serializer(self, serializer):
        # Assuming we have a valid serializer instance.
        required = []
        properties = {}

        for field in serializer.fields.values():
            if isinstance(field, serializers.HiddenField):
                continue

            if field.required:
                required.append(field.field_name)

            schema = self.map_field(field)
            if field.read_only:
                schema['readOnly'] = True
            if field.write_only:
                schema['writeOnly'] = True
            if field.allow_null:
                schema['nullable'] = True
            if field.default is not None and field.default != empty and not callable(field.default):
                schema['default'] = field.default
            if field.help_text:
                schema['description'] = str(field.help_text)
            self.map_field_validators(field, schema)

            properties[field.field_name] = schema

        result = {
            'type': 'object',
            'properties': properties
        }
        if required:
            result['required'] = required

        return result
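Comparing the two, you can see that the additions are support for the security_schema attribute of the classes in view.authentication_classes, and support for FilterSet (which corresponds to the query string after the ? in a URL).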
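Stripped of the framework, the extension technique is ordinary method overriding: the parent builds the operation from hook methods, and the subclass replaces a hook or post-processes the result. The following is a minimal sketch under that reading; BaseSchema, CustomSchemaSketch, TokenAuth, SessionAuth, and BookView are hypothetical names invented for the illustration and are not Django REST framework classes.

```python
class BaseSchema:
    """Hypothetical stand-in for the library's schema class."""

    def __init__(self, view):
        self.view = view

    def get_operation(self, path, method):
        # The parent assembles the operation from overridable hooks.
        return {
            "operationId": method.lower() + type(self.view).__name__,
            "tags": self.get_tags(path, method),
        }

    def get_tags(self, path, method):
        return ["default"]


class TokenAuth:
    security_schema = {"BearerAuth": {"type": "http", "scheme": "bearer"}}


class SessionAuth:
    """An authentication class without security_schema is simply skipped."""


class CustomSchemaSketch(BaseSchema):
    def get_tags(self, path, method):
        # Prefer the view's TAG; otherwise use the first two path segments.
        tag = getattr(self.view, "TAG", [])
        return tag if tag else ["/".join(path.split("/")[1:3]) + "/"]

    def get_operation(self, path, method):
        # Reuse the parent's result, then add the security requirements.
        operation = super().get_operation(path, method)
        operation["security"] = [
            {name: [] for name in auth.security_schema}
            for auth in self.view.authentication_classes
            if hasattr(auth, "security_schema")
        ]
        return operation


class BookView:
    authentication_classes = [TokenAuth, SessionAuth]


operation = CustomSchemaSketch(BookView()).get_operation("/api/books/{id}/", "GET")
```

Knowing which hooks the parent calls, and in what order, is exactly the information that comes from reading the library's source rather than its documentation.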
security_schema is expected to have a shape like the following:
security_schema = {
    'BearerAuth': {
        'type': 'http',
        'scheme': 'bearer',
    }
}
This shape follows the OpenAPI specification:
https://swagger.io/specification/ (see the Operation Object / Security Requirement Object sections)
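In OpenAPI, the per-operation security list only names schemes; the definitions themselves live under components.securitySchemes, and each name used by an operation has to be declared there. The sketch below builds a tiny document by hand to show how the two places fit together. The path, title, and the undefined_requirements helper are hypothetical and only serve the illustration.

```python
import json

security_schema = {"BearerAuth": {"type": "http", "scheme": "bearer"}}

document = {
    "openapi": "3.0.2",
    "info": {"title": "Example API", "version": "1.0.0"},
    "paths": {
        "/books/": {
            "get": {
                "responses": {"200": {"description": ""}},
                # Security Requirement Object: scheme name -> list of scopes
                # (empty here, because this is an http bearer scheme).
                "security": [{name: [] for name in security_schema}],
            }
        }
    },
    # Security Scheme Objects are declared once, under components.
    "components": {"securitySchemes": security_schema},
}


def undefined_requirements(doc):
    """Return requirement names that have no matching security scheme."""
    defined = set(doc.get("components", {}).get("securitySchemes", {}))
    used = {
        name
        for item in doc["paths"].values()
        for operation in item.values()
        for requirement in operation.get("security", [])
        for name in requirement
    }
    return sorted(used - defined)


rendered = json.dumps(document["paths"]["/books/"]["get"]["security"])
dangling = undefined_requirements(document)
```

A check like undefined_requirements is one way to catch an authentication class whose security_schema was added to operations but never registered in the components section.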
FilterSet is implemented as follows (partly abridged):
class BaseFilterSet(metaclass=FilterSetMetaclass):
    @classmethod
    def get_schema_operation_parameters(cls):
        filters = []
        for k, f in getattr(cls, '_filters'):  # noqa
            filters += f.get_schema(k)
        return filters
Here f is an instance of a class that inherits from Filter, which looks like this (partly abridged):
@dataclass
class Filter(ABC):
    def get_schema(self, key):
        return [
            {
                'name': key,
                'required': False,
                'in': 'query',
                'description': self.get_filter_desc(),
                'schema': {
                    'type': 'string',
                },
            },
        ]
This also follows the OpenAPI specification.
https://swagger.io/specification/ (see the Components Object / Parameter Object sections)
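As a self-contained illustration of how per-filter schemas add up to an operation's parameter list, here is a sketch with a hypothetical ExactFilter. The field_name attribute, the description text, and collect_parameters are assumptions made for the example; the original Filter and BaseFilterSet are only partly shown above.

```python
from dataclasses import dataclass
from typing import Any, Dict, List


@dataclass
class ExactFilter:
    """Hypothetical filter; field_name and the description are assumptions."""

    field_name: str

    def get_filter_desc(self) -> str:
        return f"Exact match on {self.field_name}"

    def get_schema(self, key: str) -> List[Dict[str, Any]]:
        # One OpenAPI Parameter Object per query-string key.
        return [{
            "name": key,
            "required": False,
            "in": "query",
            "description": self.get_filter_desc(),
            "schema": {"type": "string"},
        }]


def collect_parameters(filters: Dict[str, ExactFilter]) -> List[Dict[str, Any]]:
    # Same accumulation as get_schema_operation_parameters above.
    parameters: List[Dict[str, Any]] = []
    for key, filter_ in filters.items():
        parameters += filter_.get_schema(key)
    return parameters


parameters = collect_parameters({
    "series": ExactFilter("series"),
    "volume": ExactFilter("volume"),
})
names = [p["name"] for p in parameters]
```

Each entry corresponds to one key that can appear after the ? in the URL, which is what lets the Swagger UI render an input box for it.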
The same goes for SchemaGenerator.
from rest_framework.schemas import openapi


class SchemaGenerator(openapi.SchemaGenerator):
    """
    SchemaGenerator
    https://github.com/encode/django-rest-framework/blob/master/rest_framework/schemas/openapi.py
    """

    def get_schema(self, request=None, public=False):
        schema = super().get_schema(request, public)
        # The parent only adds 'components' when it has component schemas,
        # so create the key if it is missing instead of indexing into it.
        schema.setdefault('components', {})['securitySchemes'] = self.get_secutiry_schemes()
        return schema

    def get_secutiry_schemes(self, request=None, public=False):
        schemes = {}

        paths, view_endpoints = self._get_paths_and_endpoints(request)

        if not paths:
            return schemes

        self._set_security_schemes_from_views(view_endpoints, schemes)
        self._set_security_schemes_from_headers(schemes)

        return schemes

    def _set_security_schemes_from_headers(self, schemes):
        pass

    def _set_security_schemes_from_views(self, view_endpoints, schemes):
        for _, _, view in view_endpoints:
            for auth_class in view.authentication_classes:
                if hasattr(auth_class, 'security_schema'):
                    schemes.update(auth_class.security_schema)
This, too, extends rest_framework.schemas.openapi:
class SchemaGenerator(BaseSchemaGenerator):

    def get_schema(self, request=None, public=False):
        """
        Generate a OpenAPI schema.
        """
        self._initialise_endpoints()
        components_schemas = {}

        # Iterate endpoints generating per method path operations.
        paths = {}
        _, view_endpoints = self._get_paths_and_endpoints(None if public else request)
        for path, method, view in view_endpoints:
            if not self.has_view_permissions(path, method, view):
                continue

            operation = view.schema.get_operation(path, method)
            components = view.schema.get_components(path, method)
            for k in components.keys():
                if k not in components_schemas:
                    continue
                if components_schemas[k] == components[k]:
                    continue
                warnings.warn('Schema component "{}" has been overriden with a different value.'.format(k))

            components_schemas.update(components)

            # Normalise path for any provided mount url.
            if path.startswith('/'):
                path = path[1:]
            path = urljoin(self.url or '/', path)

            paths.setdefault(path, {})
            paths[path][method.lower()] = operation

        self.check_duplicate_operation_id(paths)

        # Compile final schema.
        schema = {
            'openapi': '3.0.2',
            'info': self.get_info(),
            'paths': paths,
        }

        if len(components_schemas) > 0:
            schema['components'] = {
                'schemas': components_schemas
            }

        return schema
These also follow the OpenAPI specification:
https://swagger.io/specification/ (see the Components Object / Security Scheme Object sections)
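One detail worth noticing in the parent get_schema quoted above is that the 'components' key is only added when at least one component schema exists. A subclass that attaches securitySchemes therefore has to cope with the key being absent. The sketch below reproduces just that situation in plain Python; parent_get_schema, add_security_schemes, and TokenAuth are hypothetical stand-ins, not library functions.

```python
def parent_get_schema(component_schemas):
    """Hypothetical stand-in mirroring the parent: 'components' is optional."""
    schema = {"openapi": "3.0.2", "info": {}, "paths": {}}
    if component_schemas:
        schema["components"] = {"schemas": component_schemas}
    return schema


class TokenAuth:
    security_schema = {"BearerAuth": {"type": "http", "scheme": "bearer"}}


def add_security_schemes(schema, auth_classes):
    # Merge every security_schema dict found on the authentication classes.
    schemes = {}
    for auth_class in auth_classes:
        if hasattr(auth_class, "security_schema"):
            schemes.update(auth_class.security_schema)
    # setdefault covers the case where the parent emitted no 'components' key.
    schema.setdefault("components", {})["securitySchemes"] = schemes
    return schema


with_components = add_security_schemes(
    parent_get_schema({"Book": {"type": "object"}}), [TokenAuth]
)
without_components = add_security_schemes(parent_get_schema({}), [TokenAuth])
```

In both cases the result carries components.securitySchemes, and existing component schemas are left untouched.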
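Conclusion
Reading a library's internal implementation makes possible flexible extensions that go beyond its intended usage. This approach amounts to content coupling and leans heavily on internals that the library is free to change, so it may not be something to recommend wholeheartedly; still, it holds the potential for benefits beyond what you get from treating the library as a black box and using it only as intended. The idea of "caring about what is inside" did not strike me as obvious at all, which is why I wanted to share it.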