webアプリケーションに機械学習機能を組み込む
はじめに
チーム開発でwebアプリケーションに機械学習機能を組み込むタスクがあったため、備忘録としてまとめます。
今回の機械学習タスクは、awsのsqsとlambda上でのbertによる文章のポジティブ・ネガティブ判定です。
概要
機械学習タスクを組み込む際に問題になるのが推論時間だと思います。
一般的にAPIサーバで同期処理を行うとタイムアウトの可能性が出てきます。
そこで、一旦messaging queueに入れてから、lambdaで処理をし、dbを更新することで非同期処理を実現しています。
アーキテクチャ
省略していますが、大まかなアーキテクチャは以下の通りです。
apiサーバ
APIのサービスはapp runnerでNestJSで以下のようにAPIを構築しています(graphql)。
resolverから直接このサービスを呼び出しています。
import { SQSClient, SendMessageCommand } from '@aws-sdk/client-sqs';
import { Injectable } from '@nestjs/common';
import { ConfigService } from '@nestjs/config';
import { Message } from './dto/Message';
@Injectable()
export class SqsService {
private readonly sqsClient: SQSClient;
private readonly SQS_QUEUE_URL: string;
constructor(private readonly configService: ConfigService) {
this.SQS_QUEUE_URL = this.configService.get('SQS_QUEUE_URL', '');
this.sqsClient = new SQSClient({ region: this.configService.get('SQS_REGION') });
}
async publishToQueue(message: Message) {
const command = new SendMessageCommand({
QueueUrl: this.SQS_QUEUE_URL,
MessageBody: JSON.stringify(message),
MessageGroupId: message.id,
});
return await this.sqsClient.send(command);
}
}
Messaging Queue
メッセージキュー(MQ)は非同期処理を可能にする通信手段で、producer側から、MQにメッセージをパブリッシュします。メッセージはキューに溜まっていき、consumer側で、メッセージを受け取り、処理を開始します。
MQにはrabbit MQなどのオープンソースのものもありますが、今回はawsのsqsを使いました。
Lambda
イベントトリガーをsqsへのパブリッシュとし、パブリッシュが起こると発火します。
lambdaへのデプロイは、docker imageをECRにあげ、それを用いています。
lambdaではpythonで以下のように実装しました。
from __future__ import annotations
import json
from core.task import task
def handler(event, context) -> str:
params = json.loads(event['Records'][0]['body'])
input_text: str = params['content']
positive_degree: int = task(input_text) # ポジネガ判定
post_request() # APIにrequest投げる
taskでポジネガ判定をしています。(学習済みのBERTを使っているため今回は割愛します。)
taskの処理後、APIを叩きに行き、dbのpositive/negaiveのカラムをアップデートします。
大きなモデルを使う際は、メモリやエフェメラルストレージや、タイムアウトは長めに取っておきましょう。
また、クレデンシャル情報をparameter storeから引っ張ってくるなどする場合は適切なロールをつけておきましょう。
まとめ
本記事ではwebアプリケーションに機械学習機能を組み込む方法を紹介しました。
モデリングや予測などにも応用が可能です。
Discussion