🔄

SageMaker Training Jobを使う理由を整理しつつ、Terraformで試してみた

に公開

はじめに

Fusicのレオナです。本ブログでは、KaggleのDigit RecognizerコンペのMNISTデータセットを使用して、Terraformでインフラを構築してSageMaker Training Jobを動かすまでをハンズオン形式で書いていきます。

SageMaker Training Job とは

SageMaker Training Jobは、AWSが提供するマネージドな機械学習モデルの学習実行サービスです。
学習に必要なコンピューティングリソース(CPU/GPU)を指定すると、SageMakerが自動でインスタンスを起動し、学習が終わるとインスタンスを停止します。従量課金制なので、常時起動のインスタンスを管理する必要がありません。

弊社の@tobariが執筆したブログもあわせてご確認ください。
https://zenn.dev/fusic/articles/ef4715ddad5fd9

SageMaker NotebookではなくTraining Job を使う理由

SageMakerにはNotebookインスタンスという選択肢もありますが、本ブログではTraining Jobを採用しています。理由は2つあります。

  • 学習コードをNotebook形式で管理したくない: .ipynb はdiffが見づらく、Gitでのコードレビューやバージョン管理に向いていません。学習コードは通常のPythonスクリプト(.py)として管理し、再現性を確保したい目的があります。
    • marimoを使えば課題は解消されます。
  • Notebookインスタンスの削除し忘れが怖い: Notebookインスタンスは明示的に削除しない限り課金が続く。検証後に削除し忘れて数日放置すると、想定外のコストが発生します。Training Jobは学習完了後に自動でインスタンスが停止するため、この心配がありません。

全体の流れ

今回のハンズオンは以下の流れで進めます。

  1. Terraformでインフラ構築 — S3バケットとIAMロールを作成
  2. Kaggleデータセット取得 — MNIST画像データをダウンロードしてS3にアップロード
  3. 学習スクリプト作成 — PyTorch CNNの学習コード(SageMaker対応)
  4. Training Job実行 — SageMaker Python SDKでジョブを起動
  5. 結果確認 — モデルアーティファクトをS3から取得してローカルで推論テスト
  6. 環境削除 — terraform destroy

前提条件

  • AWS CLI v2が設定済み
  • Terraform >= 1.5
  • Python >= 3.10 + UV(パッケージ管理)
  • Kaggle CLI + API Token
  • Mac環境で実行

Kaggle API Tokenはこちらをご覧くださいhttps://github.com/Kaggle/kaggle-cli/blob/main/docs/README.md#authentication

ディレクトリ構造

root/
├── terraform/
│   ├── terraform.tf               
│   ├── providers.tf               
│   ├── locals.tf                  
│   ├── variables.tf              
│   ├── main.tf                   
│   ├── outputs.tf                 
│   └── terraform.tfvars.example   
├── scripts/
│   ├── upload_data.py             # Kaggleデータ → S3アップロード
│   └── run_training.py            # Training Job起動
├── src/
│   └── train.py                   # PyTorch CNN学習スクリプト
└── pyproject.toml                 

実装

Step 1: Terraformでインフラ構築

SageMaker Training Jobに必要なAWSリソースは2つです。

  • S3バケット: 学習データの格納先とモデル出力先
  • IAMロール: SageMakerがS3・CloudWatch Logs・ECRにアクセスするための権限

バージョン制約・Provider設定

terraform/terraform.tf
terraform {
  required_version = ">= 1.5"

  required_providers {
    aws = {
      source  = "hashicorp/aws"
      version = "~> 6.0"
    }
  }
}
terraform/providers.tf
provider "aws" {
  region  = var.aws_region
  profile = var.aws_profile

  default_tags {
    tags = {
      Project     = var.project_name
      Environment = "experimental"
      ManagedBy   = "terraform"
    }
  }
}

変数定義

terraform/variables.tf
variable "aws_profile" {
  description = "AWS CLI プロファイル名(未指定の場合はデフォルトプロファイルを使用)"
  type        = string
  default     = null
}

variable "aws_region" {
  description = "AWS リージョン"
  type        = string
  default     = "ap-northeast-1"
}

variable "bucket_force_destroy" {
  description = "terraform destroy 時に S3 バケット内のオブジェクトも削除するか"
  type        = bool
  default     = true
}

variable "project_name" {
  description = "プロジェクト名(リソース名のプレフィックスに使用)"
  type        = string
  default     = "sm-handson"
}

bucket_force_destroy = true にしておくと、terraform destroy 時にバケット内のオブジェクトも一緒に削除されます。検証用途なのでここでは true にしています。

ローカル変数

リソース名のプレフィックスは locals に切り出し、変更時に1箇所で済むようにしています。

terraform/locals.tf
locals {
  name_prefix = var.project_name
}

リソース定義

terraform/main.tf
# S3 バケット(学習データ・モデル出力)
resource "aws_s3_bucket" "sagemaker" {
  bucket_prefix = "${local.name_prefix}-"
  force_destroy = var.bucket_force_destroy
}

resource "aws_s3_bucket_server_side_encryption_configuration" "sagemaker" {
  bucket = aws_s3_bucket.sagemaker.id

  rule {
    apply_server_side_encryption_by_default {
      sse_algorithm = "AES256"
    }
  }
}

resource "aws_s3_bucket_public_access_block" "sagemaker" {
  bucket = aws_s3_bucket.sagemaker.id

  block_public_acls       = true
  block_public_policy     = true
  ignore_public_acls      = true
  restrict_public_buckets = true
}

# IAM ロール(SageMaker 実行用)
data "aws_iam_policy_document" "sagemaker_assume_role" {
  statement {
    effect  = "Allow"
    actions = ["sts:AssumeRole"]

    principals {
      type        = "Service"
      identifiers = ["sagemaker.amazonaws.com"]
    }
  }
}

resource "aws_iam_role" "sagemaker_execution" {
  name               = "${local.name_prefix}-sagemaker-execution-role"
  assume_role_policy = data.aws_iam_policy_document.sagemaker_assume_role.json
}

# IAM ポリシー: S3 アクセス(学習データ・モデル出力)
data "aws_iam_policy_document" "s3_access" {
  statement {
    effect = "Allow"
    actions = [
      "s3:DeleteObject",
      "s3:GetObject",
      "s3:ListBucket",
      "s3:PutObject",
    ]
    resources = [
      aws_s3_bucket.sagemaker.arn,
      "${aws_s3_bucket.sagemaker.arn}/*",
    ]
  }
}

resource "aws_iam_role_policy" "s3_access" {
  name   = "${local.name_prefix}-s3-access"
  role   = aws_iam_role.sagemaker_execution.id
  policy = data.aws_iam_policy_document.s3_access.json
}

# IAM ポリシー: CloudWatch Logs(学習ログ出力)
data "aws_iam_policy_document" "cloudwatch_logs" {
  statement {
    effect = "Allow"
    actions = [
      "logs:CreateLogGroup",
      "logs:CreateLogStream",
      "logs:DescribeLogStreams",
      "logs:PutLogEvents",
    ]
    resources = ["arn:aws:logs:*:*:log-group:/aws/sagemaker/*"]
  }
}

resource "aws_iam_role_policy" "cloudwatch_logs" {
  name   = "${local.name_prefix}-cloudwatch-logs"
  role   = aws_iam_role.sagemaker_execution.id
  policy = data.aws_iam_policy_document.cloudwatch_logs.json
}

# IAM ポリシー: ECR Pull(組み込みコンテナイメージ取得)
data "aws_iam_policy_document" "ecr_pull" {
  statement {
    effect = "Allow"
    actions = [
      "ecr:BatchCheckLayerAvailability",
      "ecr:BatchGetImage",
      "ecr:GetAuthorizationToken",
      "ecr:GetDownloadUrlForLayer",
    ]
    resources = ["*"]
  }
}

resource "aws_iam_role_policy" "ecr_pull" {
  name   = "${local.name_prefix}-ecr-pull"
  role   = aws_iam_role.sagemaker_execution.id
  policy = data.aws_iam_policy_document.ecr_pull.json
}

出力値

terraform/outputs.tf
output "sagemaker_role_arn" {
  description = "SageMaker 実行ロールの ARN"
  value       = aws_iam_role.sagemaker_execution.arn
}

output "s3_bucket_name" {
  description = "学習データ・モデル出力用 S3 バケット名"
  value       = aws_s3_bucket.sagemaker.id
}

output "aws_region" {
  description = "使用リージョン"
  value       = var.aws_region
}

インフラのデプロイ

Terminal
cd terraform
cp terraform.tfvars.example terraform.tfvars
terraform.tfvars
aws_profile          = "your-profile"
aws_region           = "ap-northeast-1"
project_name         = "sm-handson"
bucket_force_destroy = true
Terminal
terraform init
terraform plan
terraform apply

terraform output で出力値を確認します。sagemaker_role_arns3_bucket_name は後のステップで使用します。

terraform output
# sagemaker_role_arn = "arn:aws:iam::123456789012:role/sm-handson-sagemaker-execution-role"
# s3_bucket_name     = "sm-handson-xxxxxxxxxxxx"

Step 2: Kaggleデータセット取得・S3アップロード

KaggleのDigit RecognizerコンペからMNISTデータセットを取得し、S3にアップロードします。

scripts/upload_data.py
"""Kaggle Digit Recognizer データセットをダウンロードして S3 にアップロードする"""

from __future__ import annotations

import argparse
import subprocess
import zipfile
from pathlib import Path

import boto3


def download_kaggle_data(output_dir: Path) -> None:
    """Kaggle API でデータをダウンロード"""
    output_dir.mkdir(parents=True, exist_ok=True)

    print("Downloading Kaggle Digit Recognizer dataset...")
    subprocess.run(
        [
            "kaggle", "competitions", "download",
            "-c", "digit-recognizer",
            "-p", str(output_dir),
        ],
        check=True,
    )

    # ZIP を展開
    zip_path = output_dir / "digit-recognizer.zip"
    if zip_path.exists():
        print(f"Extracting {zip_path}...")
        with zipfile.ZipFile(zip_path, "r") as z:
            z.extractall(output_dir)
        zip_path.unlink()


def upload_to_s3(local_dir: Path, bucket: str, prefix: str, profile: str | None = None) -> None:
    """ローカルファイルを S3 にアップロード"""
    session = boto3.Session(profile_name=profile) if profile else boto3.Session()
    s3 = session.client("s3")

    for file_path in local_dir.glob("*.csv"):
        key = f"{prefix}/{file_path.name}"
        print(f"Uploading {file_path.name} -> s3://{bucket}/{key}")
        s3.upload_file(str(file_path), bucket, key)


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("--bucket", required=True, help="S3 バケット名")
    parser.add_argument("--prefix", default="data/mnist", help="S3 プレフィックス")
    parser.add_argument("--profile", default=None, help="AWS CLI プロファイル名")
    parser.add_argument("--local-dir", default="./tmp/data")
    args = parser.parse_args()

    local_dir = Path(args.local_dir)
    download_kaggle_data(local_dir)
    upload_to_s3(local_dir, args.bucket, args.prefix, args.profile)
    print(f"\nData is ready at: s3://{args.bucket}/{args.prefix}/")


if __name__ == "__main__":
    main()

実行コマンドは以下のとおりです。

Terminal
BUCKET=$(cd terraform && terraform output -raw s3_bucket_name)

# 新形式トークン(KGAT_)の場合は環境変数で渡す
KAGGLE_API_TOKEN="KGAT_xxxxxxxxxxxx" \
  uv run --with kaggle --with boto3 --with pandas \
  python scripts/upload_data.py --bucket ${BUCKET} --profile your-profile

Step 3: 学習スクリプト作成

SageMaker Training Jobで実行する学習スクリプトを作成します。PyTorchのCNNでMNIST画像を10クラスに分類するシンプルなモデルです。

SageMaker Training Jobは、コンテナ内で学習スクリプトを実行する際に以下の環境変数を自動設定します。

環境変数 パス 用途
SM_CHANNEL_TRAIN /opt/ml/input/data/train 学習データの読み込み先
SM_MODEL_DIR /opt/ml/model 学習済みモデルの保存先
SM_OUTPUT_DATA_DIR /opt/ml/output/data メトリクス等の出力先

学習スクリプトではこれらの環境変数からパスを取得します。ハイパーパラメータはコマンドライン引数として渡されます。

src/train.py
"""SageMaker Training Job 用の学習スクリプト"""

from __future__ import annotations

import argparse
import json
import os
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, random_split


# Dataset
class MNISTDataset(Dataset):
    """Kaggle Digit Recognizer CSV を読み込む Dataset"""

    def __init__(self, csv_path: str) -> None:
        df = pd.read_csv(csv_path)

        if "label" in df.columns:
            self.labels = torch.tensor(df["label"].values, dtype=torch.long)
            self.pixels = torch.tensor(
                df.drop(columns=["label"]).values, dtype=torch.float32
            )
        else:
            self.labels = None
            self.pixels = torch.tensor(df.values, dtype=torch.float32)

        # 正規化 (0-255 → 0-1) して 1x28x28 に reshape
        self.pixels = self.pixels / 255.0
        self.pixels = self.pixels.view(-1, 1, 28, 28)

    def __len__(self) -> int:
        return len(self.pixels)

    def __getitem__(self, idx: int):
        if self.labels is not None:
            return self.pixels[idx], self.labels[idx]
        return self.pixels[idx]


# Model
class SimpleCNN(nn.Module):
    """Conv → Pool → Conv → Pool → FC → FC"""

    def __init__(self, num_classes: int = 10, dropout: float = 0.25) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout1 = nn.Dropout2d(dropout)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pool(F.relu(self.conv1(x)))  # 28x28 → 14x14
        x = self.pool(F.relu(self.conv2(x)))  # 14x14 → 7x7
        x = self.dropout1(x)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.dropout2(x)
        x = self.fc2(x)
        return x

def train_one_epoch(model, loader, optimizer, criterion, device):
    model.train()
    total_loss = 0.0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * images.size(0)
    return total_loss / len(loader.dataset)


def evaluate(model, loader, criterion, device):
    model.eval()
    total_loss = 0.0
    correct = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            total_loss += criterion(outputs, labels).item() * images.size(0)
            correct += (outputs.argmax(dim=1) == labels).sum().item()
    n = len(loader.dataset)
    return total_loss / n, correct / n


def main() -> None:
    parser = argparse.ArgumentParser()

    # ハイパーパラメータ
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--lr", type=float, default=0.001)
    parser.add_argument("--dropout", type=float, default=0.25)
    parser.add_argument("--val-ratio", type=float, default=0.2)

    # SageMaker 環境変数(ローカル実行時はデフォルト値を使用)
    parser.add_argument("--data-dir", type=str,
                        default=os.environ.get("SM_CHANNEL_TRAIN", "./data"))
    parser.add_argument("--model-dir", type=str,
                        default=os.environ.get("SM_MODEL_DIR", "./model"))
    parser.add_argument("--output-dir", type=str,
                        default=os.environ.get("SM_OUTPUT_DATA_DIR", "./output"))

    args = parser.parse_args()
    print(f"Hyperparameters: {vars(args)}")

    device = torch.device("cpu")

    # データ読み込み
    train_csv = Path(args.data_dir) / "train.csv"
    dataset = MNISTDataset(str(train_csv))

    # Train / Validation 
    val_size = int(len(dataset) * args.val_ratio)
    train_size = len(dataset) - val_size
    train_ds, val_ds = random_split(dataset, [train_size, val_size])

    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False)

    # モデル
    model = SimpleCNN(num_classes=10, dropout=args.dropout).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    criterion = nn.CrossEntropyLoss()

    # 学習ループ
    best_val_acc = 0.0
    for epoch in range(1, args.epochs + 1):
        train_loss = train_one_epoch(model, train_loader, optimizer, criterion, device)
        val_loss, val_acc = evaluate(model, val_loader, criterion, device)

        print(f"Epoch {epoch}/{args.epochs} | "
              f"Train Loss: {train_loss:.4f} | "
              f"Val Loss: {val_loss:.4f} | "
              f"Val Acc: {val_acc:.4f}")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            model_path = Path(args.model_dir) / "model.pth"
            model_path.parent.mkdir(parents=True, exist_ok=True)
            torch.save(model.state_dict(), str(model_path))

    # メトリクス保存
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    with open(output_dir / "metrics.json", "w") as f:
        json.dump({"best_val_accuracy": round(best_val_acc, 4)}, f, indent=2)

    print(f"Training complete. Best validation accuracy: {best_val_acc:.4f}")


if __name__ == "__main__":
    main()

Step 4: Training Job実行

SageMaker Python SDKの PyTorch Estimatorを使ってTraining Jobを起動します。組み込みコンテナを使う場合、framework_versionpy_version を指定するだけでコンテナイメージが自動選択されます。SageMaker DLCで公開されているバージョンの確認が必要になります。
https://github.com/aws/deep-learning-containers

scripts/run_training.py
"""SageMaker Training Job を起動するスクリプト"""

from __future__ import annotations

import argparse
from datetime import datetime

import boto3
import sagemaker
from sagemaker.pytorch import PyTorch


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("--role-arn", required=True, help="SageMaker 実行ロール ARN")
    parser.add_argument("--bucket", required=True, help="S3 バケット名")
    parser.add_argument("--region", default="ap-northeast-1")
    parser.add_argument("--profile", default=None, help="AWS CLI プロファイル名")
    parser.add_argument("--prefix", default="data/mnist")
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--lr", type=float, default=0.001)
    parser.add_argument("--instance-type", default="ml.m5.large")
    parser.add_argument("--wait", action="store_true", help="Job完了まで待機")
    args = parser.parse_args()

    sess = sagemaker.Session(
        boto_session=boto3.Session(
            region_name=args.region,
            profile_name=args.profile,
        )
    )

    train_input = f"s3://{args.bucket}/{args.prefix}"
    output_path = f"s3://{args.bucket}/output"
    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    job_name = f"mnist-cnn-{timestamp}"

    print(f"Job name    : {job_name}")
    print(f"Instance    : {args.instance_type}")
    print(f"Train data  : {train_input}")
    print(f"Output      : {output_path}")

    # PyTorch Estimator(ECR 組み込みコンテナ使用)
    estimator = PyTorch(
        entry_point="train.py",       # 学習スクリプト
        source_dir="src",             # スクリプトのディレクトリ
        role=args.role_arn,
        instance_count=1,
        instance_type=args.instance_type,
        framework_version="2.0.1",    # PyTorch バージョン
        py_version="py310",           # Python バージョン
        output_path=output_path,
        sagemaker_session=sess,
        hyperparameters={
            "epochs": args.epochs,
            "batch-size": args.batch_size,
            "lr": args.lr,
            "dropout": 0.25,
            "val-ratio": 0.2,
        },
        max_run=3600,
    )

    estimator.fit(
        inputs={"train": train_input},
        job_name=job_name,
        wait=args.wait,
        logs="All" if args.wait else None,
    )

    if args.wait:
        print(f"\nModel artifact: {estimator.model_data}")
    else:
        print(f"\nTraining job submitted: {job_name}")
        print(f"  aws sagemaker describe-training-job --training-job-name {job_name}")


if __name__ == "__main__":
    main()

entry_pointsource_dirを指定すると、SageMaker SDKがsrc/ディレクトリを自動でtar.gzに固めてS3にアップロードし、コンテナ内で展開してtrain.pyを実行します。

Terminal
ROLE_ARN=$(cd terraform && terraform output -raw sagemaker_role_arn)
BUCKET=$(cd terraform && terraform output -raw s3_bucket_name)

uv run --with "sagemaker>=2.200,<3" --with "botocore[crt]" --with boto3 \
  python scripts/run_training.py \
  --role-arn ${ROLE_ARN} \
  --bucket ${BUCKET} \
  --profile your-profile \
  --epochs 5 \
  --batch-size 64 \
  --lr 0.001 \

Step 5: 結果確認

AWS マネジメントコンソールからSageMaker AIを選択します。



上記画像の通り、Completedが表示されていたら、学習が正常に終了しています。

環境削除

Terminal
cd terraform
terraform destroy

bucket_force_destroy = true にしているので、S3バケット内のオブジェクトもまとめて削除されます。

最後に

SageMaker Training Jobの基本的な流れを、Terraform + 組み込みPyTorchコンテナで一通り動かしてみました。GPU付きのインスタンスが停止されたかどうかを気にしなくて良いのはコスト観点からとてもメリットだと感じてます。

Fusic 技術ブログ

Discussion