📝

SageMaker ドメイン削除時に EFS も自動削除してみた

に公開

Amazon EFS auto-mounting in Studio - Amazon SageMaker AI

Amazon SageMaker AI supports automatically mounting a folder in an Amazon EFS volume for each user in a domain. Using this folder, users can share data between their own private spaces.

デフォルト設定で SageMaker ドメインを作成すると、SageMaker ドメインに紐づく EFS ファイルシステムも自動作成されます。
また、EFS ファイルシステムに紐づく 2 つのセキュリティグループも自動作成されます。

SageMaker ドメインを削除しても EFS ファイルシステムとセキュリティグループは自動削除されない仕様になっています。

そのため、今回は SageMaker ドメインの削除時に EFS ファイルシステムとセキュリティグループも自動削除する方法を紹介します。

構成

  1. EventBridge ルールで SageMaker ドメインの削除を検知
  2. EventBridge ルールから Lambda 関数を呼び出す
  3. Lambda 関数から EFS ファイルシステムとセキュリティグループを削除

前提

  • SageMaker ドメインはデフォルトのクイックセットアップで作成済み

1. Lambda 関数の作成

以下の設定で作成しました。

  • ランタイム: Python 3.13
  • 実行ロール: AdministratorAccess 権限を付与した IAM ロール
  • タイムアウト: 60 秒
  • コード: 以下の通り
コード
import boto3
import json
import time

efs_client = boto3.client('efs')
ec2_client = boto3.client('ec2')

def lambda_handler(event, context):
    print("Event received:", json.dumps(event))
    
    detail = event.get('detail', {})
    domain_id = detail.get('requestParameters', {}).get('domainId')
    
    if not domain_id:
        print("Domain ID not found in event")
        return {'statusCode': 400, 'body': 'Domain ID not found'}
    
    print(f"Processing domain deletion for: {domain_id}")
    
    try:
        efs_id = get_efs_id_from_tag(domain_id)
        if efs_id:
            print(f"Found EFS: {efs_id}")
            
            sg_ids = get_security_groups(efs_id)
            print(f"Found Security Groups from mount targets: {sg_ids}")

            tagged_sg_ids = get_security_groups_from_tag(domain_id)
            print(f"Found Security Groups from tags: {tagged_sg_ids}")

            all_sg_ids = list(set(sg_ids + tagged_sg_ids))
            print(f"All Security Groups: {all_sg_ids}")
            
            delete_mount_targets(efs_id)
            
            time.sleep(5)
            
            if all_sg_ids:
                delete_sg_rules(all_sg_ids)
            
            delete_security_groups(all_sg_ids)

            delete_efs(efs_id)
            
            print("Cleanup completed successfully")
            return {'statusCode': 200, 'body': 'Cleanup completed'}
        else:
            print("No EFS found for domain")
            return {'statusCode': 200, 'body': 'No EFS found'}
            
    except Exception as e:
        print(f"Error: {str(e)}")
        return {'statusCode': 500, 'body': str(e)}

def get_efs_id_from_tag(domain_id):
    try:
        response = efs_client.describe_file_systems()
        
        for fs in response.get('FileSystems', []):
            efs_id = fs['FileSystemId']

            tags_response = efs_client.list_tags_for_resource(ResourceId=efs_id)
            tags = tags_response.get('Tags', [])

            for tag in tags:
                if tag['Key'] == 'ManagedByAmazonSageMakerResource':
                    tag_value = tag['Value']
                    if tag_value.endswith(domain_id):
                        return efs_id
        
        return None
    except Exception as e:
        print(f"Error getting EFS ID from tag: {str(e)}")
        return None

def get_security_groups(efs_id):
    try:
        response = efs_client.describe_mount_targets(FileSystemId=efs_id)
        sg_ids = set()
        
        for mount_target in response.get('MountTargets', []):
            eni_id = mount_target.get('NetworkInterfaceId')
            if eni_id:
                eni_response = ec2_client.describe_network_interfaces(NetworkInterfaceIds=[eni_id])
                for eni in eni_response.get('NetworkInterfaces', []):
                    for sg in eni.get('Groups', []):
                        sg_ids.add(sg['GroupId'])
        
        return list(sg_ids)
    except Exception as e:
        print(f"Error getting security groups from mount targets: {str(e)}")
        return []

def get_security_groups_from_tag(domain_id):
    try:
        sg_ids = []
        
        response = ec2_client.describe_security_groups(
            Filters=[
                {
                    'Name': 'tag:ManagedByAmazonSageMakerResource',
                    'Values': [f'arn:aws:sagemaker:*:*:domain/{domain_id}']
                }
            ]
        )
        
        for sg in response.get('SecurityGroups', []):
            sg_ids.append(sg['GroupId'])
        
        return sg_ids
    except Exception as e:
        print(f"Error getting security groups from tag: {str(e)}")
        return []

def delete_sg_rules(sg_ids):
    try:
        for attempt in range(3):
            print(f"Deleting SG rules - attempt {attempt + 1}")
            
            for sg_id in sg_ids:
                try:
                    response = ec2_client.describe_security_groups(GroupIds=[sg_id])
                    sg = response['SecurityGroups'][0]

                    for rule in sg.get('IpPermissions', []):
                        try:
                            ec2_client.revoke_security_group_ingress(
                                GroupId=sg_id,
                                IpPermissions=[rule]
                            )
                            print(f"Deleted ingress rule from {sg_id}")
                        except Exception as e:
                            print(f"Error deleting ingress rule from {sg_id}: {str(e)}")
                    
                    for rule in sg.get('IpPermissionsEgress', []):
                        try:
                            ec2_client.revoke_security_group_egress(
                                GroupId=sg_id,
                                IpPermissions=[rule]
                            )
                            print(f"Deleted egress rule from {sg_id}")
                        except Exception as e:
                            print(f"Error deleting egress rule from {sg_id}: {str(e)}")
                except Exception as e:
                    print(f"Error processing SG {sg_id}: {str(e)}")
            
            time.sleep(2)
    except Exception as e:
        print(f"Error deleting SG rules: {str(e)}")

def delete_security_groups(sg_ids):
    try:
        for attempt in range(5):
            remaining_sgs = []
            
            for sg_id in sg_ids:
                try:
                    ec2_client.delete_security_group(GroupId=sg_id)
                    print(f"Deleted security group: {sg_id}")
                except Exception as e:
                    print(f"Error deleting security group {sg_id} (attempt {attempt + 1}): {str(e)}")
                    remaining_sgs.append(sg_id)
            
            if not remaining_sgs:
                print("All security groups deleted successfully")
                break
            
            sg_ids = remaining_sgs
            if attempt < 4:
                print(f"Retrying deletion of remaining security groups...")
                time.sleep(3)
    except Exception as e:
        print(f"Error deleting security groups: {str(e)}")

def delete_mount_targets(efs_id):
    try:
        response = efs_client.describe_mount_targets(FileSystemId=efs_id)
        for mount_target in response.get('MountTargets', []):
            try:
                efs_client.delete_mount_target(MountTargetId=mount_target['MountTargetId'])
                print(f"Deleted mount target: {mount_target['MountTargetId']}")
            except Exception as e:
                print(f"Error deleting mount target: {str(e)}")
    except Exception as e:
        print(f"Error deleting mount targets: {str(e)}")

def delete_efs(efs_id):
    try:
        for attempt in range(10):
            try:
                efs_client.delete_file_system(FileSystemId=efs_id)
                print(f"Deleted EFS: {efs_id}")
                return
            except Exception as e:
                if "has mount targets" in str(e) and attempt < 9:
                    print(f"Mount targets still exist, retrying... (attempt {attempt + 1})")
                    time.sleep(3)
                else:
                    raise
    except Exception as e:
        print(f"Error deleting EFS: {str(e)}")

2. EventBridge ルールの作成

以下の設定で作成しました。

  • イベントパターン: 以下の通り
{
  "source": ["aws.sagemaker"],
  "detail-type": ["AWS API Call via CloudTrail"],
  "detail": {
    "eventSource": ["sagemaker.amazonaws.com"],
    "eventName": ["DeleteDomain"]
  }
}
  • ターゲット: 手順 1 で作成した Lambda 関数
  • IAM ロール: 自動作成された IAM ロール

3. 動作確認

SageMaker コンソールからドメインを削除します。

1 分程度経過後に EFS ファイルシステムとセキュリティグループが削除されていれば成功です。

まとめ

今回は SageMaker ドメインの削除時に EFS ファイルシステムとセキュリティグループも自動削除する方法を紹介しました。
どなたかの参考になれば幸いです。

参考資料

Discussion