アノテーションプラットフォームのタスク作成機能をリファクタリングした話
FastLabelでソフトウェアエンジニアをしている姉川です。
今回は、FastLabelのアノテーションツールにおける、タスクの作成機能のリファクタリングをした話をしたいと思います。
具体的なコードを記載している部分では、わかりやすさのため一部省略したり間違った使用法をしている箇所があるかもしれませんが、あらかじめご了承ください。
FastLabelのサービス概要
FastLabelとは、AIデータプラットフォームとして、アノテーションツールをはじめとした、AI開発を効率化するためのサービスを提供しています。
FastLabelの詳細な説明に関しては、こちらをご覧いただけたらと思います。
アノテーションの流れ
1. アノテーションプロジェクトの作成
- アノテーションを実施するためのプロジェクトを作成します。
- プロジェクトの種類には以下のようなものがあります。(基本的に、[データタイプ] - [アノテーションタイプ] の組み合わせでプロジェクトが定義されます。)
- 画像 - 矩形
- 画像 - セグメンテーション
- 動画 - 矩形
- 音声 - セグメンテーション
- ….
2. タスクの作成
- アノテーションは、タスクという単位で実施してきます。
- 例えば、画像: sample.pngに対して、アノテーションを実施したい場合、sample.pngをプラットフォームに取り込んで、その画像に対してアノテーションを実施するためのタスクを作成します。
3. アノテーションの実施
- 作成したタスクに対してアノテーションを実施していきます。
[本題]タスク作成機能のリファクタリング
前提知識
タスクのモデリングに関して
- プロジェクトに対して、複数のタスクが紐づきます。
- タスクに対して、複数のコンテンツが紐づきます。(1つのアノテーションタスクに複数のデータを必要とするタスクが存在します。(複数画像、マルチモーダルなど)
- コンテンツにはアノテーションを実施するデータのメタデータ情報等が保持されます。(画像のサイズ、音声の長さ、動画の長さ、データサイズなど)
例)Localでファイルをアップロードする際のタスク作成の処理フロー
モチベーション
- Batch Serverにおけるタスク作成処理に関して、大量のif分岐が発生しており、メンテナンス性が著しく低下していた。(約1000行の処理…)
- Project type × Storage typeの数だけif文が存在するイメージ。
- Project type: 画像、動画、音声、点群、マルチモーダル…
- Storage type: ローカル、S3、GCP、Azure….
- Project type × Storage typeの数だけif文が存在するイメージ。
リファクタリングの方針
リファクタリング前の処理イメージ(1つの関数に約1000行のロジックが記載されていた)
class TaskImportService:
def create(self, params):
# アップロードしたファイルや、外部連携先のデータからTaskを構築する処理
if params["storage_type"] == "local":
if params["project_type"] == "image_bbox":
...
elif params["project_type"] == "video_bbox":
...
elif params["storage_type"] == "s3":
...
# TaskのValidation処理(値ではなく、ビジネス的な側面でのValidation)
...
# Taskデータの永続化処理
...
Task作成の処理の流れとして、以下の3つの処理に分類できる。
- アップロードしたファイルや、外部連携先のデータからTaskオブジェクトを生成する処理
- TaskのValidationの処理
- Taskの永続化処理
また、1のTaskを構築するロジックや2のTaskのValidationの処理は、Storage typeによってロジックが切り替わっていることがわかる。
このような処理の特性から、「テンプレートメソッドパターン」と「ストラテジーパターン」を参考にしながら、リファクタリングを実施してみました。
リファクタリングの手順
テンプレートメソッドパターン、ストラテジーパターンを参考に、抽象に対してプログラミングを実施していく。
テンプレートメソッドパターンを参考に、処理の枠組みを定義しています。
また、終端処理として、completeメソッドを呼び出し側で呼ぶことによって、タスクの永続化処理が完了した後に実施したい処理を呼び出すことができるようになっています。
class BaseTaskImportStrategy(abc.ABC):
def __init__(self, task_subject: TaskSubject | None = None) -> None:
self._msg_code = ImportMessageCode.none
self._task_subject = task_subject
@abc.abstractmethod
def convert_to_tasks(self) -> tuple[list[Task], ImportMessageCode]:
raise NotImplementedError()
@abc.abstractmethod
def is_valid_tasks(self, tasks: list[Task]) -> bool:
raise NotImplementedError()
def is_valid(self) -> bool:
if not hasattr(self, "_tasks"):
tasks, msg_code = self.convert_to_tasks()
self._msg_code = msg_code
is_valid = self.is_valid_tasks(tasks)
if is_valid:
self._tasks = tasks
return is_valid
return True
def complete(self):
if self._task_subject:
self._task_subject.notify_observers(self.tasks)
@property
def tasks(self) -> list[Task]:
if not hasattr(self, "_tasks"):
raise AssertionError("call is_valid before get tasks")
return self._tasks
@property
def msg_code(self) -> ImportMessageCode:
if not hasattr(self, "_msg_code"):
raise AssertionError("call is_valid before get import_msg_code")
return self._msg_code
元々存在していたTaskImportServiceでTaskImportStrategyを呼び出す形に修正しました。
詳細には依存せずに、全て抽象に依存するようにしています。
class TaskImportService:
def __init__(
self,
injector: Injector,
import_strategy: BaseTaskImportStrategy,
):
self._import_strategy = import_strategy
self._task_repository = injector.get(TaskRepository)
def execute(self, params) -> TaskImportExecuteResult:
if not self._import_strategy.is_valid():
return ImportStatus.failed, 0, self._import_strategy.msg_code
tasks = self._import_strategy.tasks
self._task_repository.bulk_save(tasks)
self._import_strategy.complete()
return ImportStatus.completed, len(tasks), self._import_strategy.msg_code
BaseTaskImportStrategyの実装イメージ
また、Contentオブジェクトの生成において、Project typeごとにif文が発生していたので、ファクトリーメソッドパターンを参考にリファクタリングを実施しました。Contentオブジェクトの生成を呼び出し側と分離することで拡張性があるようにしました。
class LocalContentFactory(abc.ABC):
def __init__(
self,
local_folder_path: str,
local_content_file_uploader: LocalContentFileUploader,
):
self._file_infos = []
self._local_folder_path = local_folder_path
self._local_content_file_uploader: LocalContentFileUploader = (
local_content_file_uploader
)
def create(self, params: LocalContentCreateParams) -> list[Content]:
created_contents = self.create_contents(params)
contents = []
for content, file_key, local_file_path in created_contents:
self._file_infos.append(
{
"content": content,
"file_key": file_key,
"local_file_path": local_file_path,
}
)
contents.append(content)
return contents
@abc.abstractmethod
def create_contents(
self, params: LocalContentCreateParams
) -> list[tuple[Content, LocalFileKey, LocalFilePath]]:
raise NotImplementedError()
def upload_contents(self):
self._local_content_file_uploader.upload_contents(
self._file_infos,
)
実装イメージ
マルチモーダルプロジェクトのコンテンツオブジェクトの生成に関して
マルチモーダルプロジェクトでは、1つのタスクで、動画とCSVデータを利用するものがあります。このケースに対応するために、既存のFactoryに委譲をするような形で実装をしています。
class LocalMultiModalVideoTimeSeriesProjectContentFactory(LocalContentFactory):
def __init__(
self,
injector: Injector,
project_id: str,
local_folder_path: str,
):
self._video_factory = LocalVideoContentFactory(
injector, project_id, local_folder_path
)
self._csv_factory = LocalCSVContentFactory(
injector, project_id, local_folder_path
)
...
def create_contents(self, params: LocalContentCreateParams):
# videoとcsvデータ両面に対するValidatioin
...
if <videoの場合>:
# LocalVideoContentFactoryに処理を委譲
return self._video_factory.create_contents(params)
elif <csvの場合>:
# LocalCSVContentFactoryに処理を委譲
return self._csv_factory.create_contents(params)
else:
raise UnsupportedContentException()
まとめ
- 今回のリファクタリングを通じて、改めて、迅速に、安定的に価値を届けるためにソフトウェアがEasier to changeであることの重要性を学ぶことができて良い経験になりました。
- ETCであるためには、SOLID原則に基づく設計はとても有効で、SOLIDであるためにデザインパターンなどの設計パターンは大変参考になると感じました。
- リファクタリングをしていく中で、現状の仕様の都合上うまく抽象化できない場面も多くあり、多少無理をしたコードや、現状の仕様を調整して対応したケースも多々ありました。(実際はこの作業が一番時間がかかったところでした…)
- うまく抽象化できない場合は、そもそもの仕様がわかりずらい可能性があるので、理想的には、仕様を決定するタイミングで、技術的負債になりづらいような仕様で調整することが大切だと再認識しました。
最後までお読みいただきありがとうございました!
FastLabelではソフトウェアエンジニアを積極採用中です!
技術的負債の対応含め、様々な技術的なチャレンジができる環境があります。
カジュアル面談もやっていますので、興味がある方はご応募いただけたら嬉しいです!
Discussion