🙌

pytorchをSageMaker上で動かす方法

2023/12/13に公開

記事の対象者

sagemakerとpytorchを組み合わせて使ってみたいユーザー

できること

sagemakerAPIとpytorchの組み合わせによる簡易訓練、デプロイの流れを理解できる
カスタマイズする箇所がわかりやすいのでカスタムモデルの構築にスムーズに移行できる

関連ファイル

feature_extract_cifar10.py

検証日

2023-12-13

ディレクトリ構成

root直下
- [初心者向け]Amazon SageMakerでPyTorch.ipynb
- feature_extract_cifar10.py
## 必要モデルのinstall
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models

### 補足: sagemaker_trainingライブラリは今回は使用しない
## 環境変数の設定 

%%time

import sagemaker
import os
import boto3
import re
import numpy as np

sagemaker_session = sagemaker.Session()

role = sagemaker.get_execution_role()
region = boto3.Session().region_name

bucket='mlearning-bucket'
prefix = 'sagemaker/cnn-cifar10'
# customize to your bucket where you have stored the data
bucket_path = 'https://s3-{}.amazonaws.com/{}'.format(region,bucket)

### 補足: mlearning-bucketの箇所は自分の使用するsagemaker用のバケット名を指定する
## データセットのtransform設定
%%time
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

train_data = datasets.CIFAR10(root='../data', train=True, download=True, transform=transform)
test_data = datasets.CIFAR10(root='../data', train=False, download=True, transform=transform)

### 補足: torchvisionのdatasetsからデータのダウンロード
## s3へのデータのアップロード
%%time
inputs = sagemaker_session.upload_data(path='../data', bucket=bucket, key_prefix=prefix)
print('input spec (in this case, just an S3 path): {}'.format(inputs))

### 補足: s3のsagemaker/cnn-cifar10ディレクトリにアップロード
## sagemakerのpytorchモデルを用いた予測クラスの生成
from sagemaker.pytorch import PyTorch

hyper_param = {
    'epochs':100,
    'batch-size': 100,
    'lr': 0.01,
    'momentum': 0.9,
}

estimator = PyTorch(entry_point='feature_extract_cifar10.py',
                            hyperparameters=hyper_param,
                            role=role,
                            framework_version='1.2.0',
                            py_version='py3',
                            train_instance_count=2,
                            train_instance_type='ml.c5.xlarge')
### 補足: sagemakerのライブラリのアップデートに伴い、以下二つのオプションの設定が必要となったので、設定。;framework_version='1.2.0', py_version='py3'
## 訓練の実行
estimator.fit({'training': inputs}, logs=True)
## モデルのデプロイ
predictor = estimator.deploy(initial_instance_count=1, instance_type='ml.m4.xlarge')
## モデルの検証
import numpy as np
correct = 0
total = 0


## テストのための設定
test_loader = DataLoader(test_data, batch_size=100, shuffle=False)

with torch.no_grad():
    for data in test_loader:
        images, labels = data
        outputs = predictor.predict(images.numpy())
        _, predicted = torch.max(torch.from_numpy(np.array(outputs)), 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the 10000 test images: %d %%' % (
    100 * correct / total))
### 補足2:feature_extract_cifarの中身がs3のデータを使用する形になっていないので、s3の該当のデータをとってくるよう修正が必要

## モデルの後片付け
endpointName = "<<ここに上記で作り、返却されているsagemakerのendpoint名を記載>>"

sagemaker_client = boto3.client('sagemaker')
response = sagemaker_client.delete_endpoint(
    EndpointName=endpointName
)

## 補足:boto3はawsのコンテナであればデフォルトでinstallされている可能性が高い

最新コードは以下
https://github.com/tamae1111/publicPytorch/blob/main/arrangedAmazonSageMakerでPyTorch.ipynb

Discussion