
Training an LLM on Trainium and Deploying It to Inferentia

Published 2024/09/02

Introduction

This article explains how to train and serve a large language model (LLM) using AWS Trainium and Inferentia.

What are Trainium and Inferentia?

First, a brief overview of Trainium and Inferentia.

Trainium and Inferentia are machine learning accelerators custom-designed by AWS.

Trainium is a device optimized for training workloads. Each Trainium device carries two second-generation Neuron cores (NeuronCore-v2). EC2 instances equipped with Trainium devices are called Trn1 instances and come in several sizes depending on the number of devices installed.

Inferentia is a device optimized for inference workloads. Like Trainium, the second-generation Inferentia2 carries two NeuronCore-v2 cores per device. The difference from Trainium is the number of NeuronLink interconnects, which translates into different device-to-device bandwidth. EC2 instances equipped with Inferentia2 devices are called Inf2 instances and likewise come in several sizes depending on the number of devices installed.

The Neuron SDK is the SDK for running workloads on Trn1/Inf2 instances. NeuronX Distributed is a distributed-training library built on the Neuron SDK, and Transformers NeuronX is an LLM inference library. Optimum Neuron, developed by HuggingFace, wraps NeuronX Distributed and Transformers NeuronX behind a Transformers-compatible interface. Text Generation Inference (TGI) is a toolkit for deploying LLMs, and NeuronX TGI is its Trainium/Inferentia-enabled variant.

Model and dataset

This article uses Tanuki-8B as the example for training and inference on Trainium/Inferentia. Tanuki-8B is a model released as a deliverable of the GENIAC Matsuo Lab LLM development project. Because its architecture is equivalent to Llama's, the existing ecosystem can be reused as-is.

For data, we use the MinnadeChat dataset and ichikara-instruction. The MinnadeChat dataset was hand-annotated by the developers of Tanuki-8B. ichikara-instruction is a dataset provided by RIKEN AIP.

Prerequisites

The following steps provision resources with the AWS CLI and CloudFormation. If you have not yet installed and configured the AWS CLI, set it up beforehand by following the AWS documentation.

Distributed training with ParallelCluster

We train Tanuki-8B in a distributed fashion on ParallelCluster.

Environment setup

First, set up the environment. The cluster layout is shown in the figure below.

VPC

Create the VPC. Write a CloudFormation template like the following.

vpc.yaml
AWSTemplateFormatVersion: "2010-09-09"
Description: "CloudFormation template to deploy a VPC"

Parameters:
  CidrBlock:
    Type: String
    Description: CIDR Block
    Default: 10.0.0.0/16

  VPCName:
    Type: String
    Description: Name of your VPC

  SubnetsAZ1:
    Type: AWS::EC2::AvailabilityZone::Name
    Description: Availability zone in which the subnets will be created.

  SubnetsAZ2:
    Type: AWS::EC2::AvailabilityZone::Name
    Description: Availability zone in which the subnets will be created.

Resources:
  VPC:
    Type: AWS::EC2::VPC
    Properties:
      EnableDnsSupport: true
      EnableDnsHostnames: true
      CidrBlock: !Ref CidrBlock
      Tags:
        - Key: Name
          Value: !Ref VPCName

  FlowLogsRole:
    Type: AWS::IAM::Role
    Properties:
      AssumeRolePolicyDocument:
        Version: "2012-10-17"
        Statement:
          - Effect: Allow
            Principal:
              Service: vpc-flow-logs.amazonaws.com
            Action: sts:AssumeRole
      Policies:
        - PolicyName: flowlogs-policy
          PolicyDocument:
            Version: "2012-10-17"
            Statement:
              - Effect: Allow
                Action:
                  - logs:CreateLogStream
                  - logs:PutLogEvents
                  - logs:DescribeLogGroups
                  - logs:DescribeLogStreams
                Resource: !GetAtt FlowLogsGroup.Arn

  FlowLogsGroup:
    Type: AWS::Logs::LogGroup
    Properties:
      RetentionInDays: 7

  FlowLogVPC:
    Type: AWS::EC2::FlowLog
    Properties:
      DeliverLogsPermissionArn: !GetAtt FlowLogsRole.Arn
      LogGroupName: FlowLogsGroup
      ResourceId: !Ref VPC
      ResourceType: VPC
      TrafficType: ALL

  InternetGateway:
    Type: AWS::EC2::InternetGateway

  GatewayToInternet:
    Type: AWS::EC2::VPCGatewayAttachment
    Properties:
      VpcId: !Ref VPC
      InternetGatewayId: !Ref InternetGateway

  NATGateway:
    Type: AWS::EC2::NatGateway
    Properties:
      AllocationId: !GetAtt ElasticIP.AllocationId
      SubnetId: !Ref PublicSubnet1

  ElasticIP:
    Type: AWS::EC2::EIP
    Properties:
      Domain: vpc

  PublicSubnet1:
    Type: AWS::EC2::Subnet
    DependsOn: VPC
    Properties:
      MapPublicIpOnLaunch: true
      VpcId: !Ref VPC
      CidrBlock: !Select [0, !Cidr [!GetAtt VPC.CidrBlock, 3, 14]]
      AvailabilityZone: !Ref SubnetsAZ1
      Tags:
        - Key: Name
          Value: !Join [" ", [!Ref VPCName, "Public Subnet -", !Ref SubnetsAZ1]]

  PublicSubnet2:
    Type: AWS::EC2::Subnet
    DependsOn: VPC
    Properties:
      MapPublicIpOnLaunch: true
      VpcId: !Ref VPC
      CidrBlock: !Select [1, !Cidr [!GetAtt VPC.CidrBlock, 3, 14]]
      AvailabilityZone: !Ref SubnetsAZ2
      Tags:
        - Key: Name
          Value: !Join [" ", [!Ref VPCName, "Public Subnet -", !Ref SubnetsAZ2]]

  PrivateSubnet1:
    Type: AWS::EC2::Subnet
    DependsOn: VPC
    Properties:
      VpcId: !Ref VPC
      CidrBlock: !Select [2, !Cidr [!GetAtt VPC.CidrBlock, 3, 14]]
      AvailabilityZone: !Ref SubnetsAZ1
      Tags:
        - Key: Name
          Value:
            !Join [" ", [!Ref VPCName, "Private Subnet -", !Ref SubnetsAZ1]]

  PublicRouteTable:
    Type: AWS::EC2::RouteTable
    Properties:
      VpcId: !Ref VPC

  PublicRoute:
    Type: AWS::EC2::Route
    Properties:
      RouteTableId: !Ref PublicRouteTable
      DestinationCidrBlock: 0.0.0.0/0
      GatewayId: !Ref InternetGateway

  PrivateRouteTable:
    Type: AWS::EC2::RouteTable
    Properties:
      VpcId: !Ref VPC

  PrivateRouteToInternet:
    Type: AWS::EC2::Route
    Properties:
      RouteTableId: !Ref PrivateRouteTable
      DestinationCidrBlock: 0.0.0.0/0
      NatGatewayId: !Ref NATGateway

  PublicSubnet1RouteTableAssociation:
    Type: AWS::EC2::SubnetRouteTableAssociation
    Properties:
      SubnetId: !Ref PublicSubnet1
      RouteTableId: !Ref PublicRouteTable

  PublicSubnet2RouteTableAssociation:
    Type: AWS::EC2::SubnetRouteTableAssociation
    Properties:
      SubnetId: !Ref PublicSubnet2
      RouteTableId: !Ref PublicRouteTable

  PrivateSubnet1RTAssociation:
    Type: AWS::EC2::SubnetRouteTableAssociation
    Properties:
      SubnetId: !Ref PrivateSubnet1
      RouteTableId: !Ref PrivateRouteTable

  DefaultSecurityGroup:
    Type: AWS::EC2::SecurityGroup
    Properties:
      GroupDescription: Default Security group
      VpcId: !Ref VPC

  DefaultSecurityGroupIngress:
    Type: AWS::EC2::SecurityGroupIngress
    Properties:
      GroupId: !Ref DefaultSecurityGroup
      IpProtocol: -1
      FromPort: -1
      ToPort: -1
      SourceSecurityGroupId: !Ref DefaultSecurityGroup

  DefaultSecurityGroupEgress:
    Type: AWS::EC2::SecurityGroupEgress
    Properties:
      GroupId: !Ref DefaultSecurityGroup
      IpProtocol: -1
      FromPort: -1
      ToPort: -1
      DestinationSecurityGroupId: !Ref DefaultSecurityGroup

  VPCESecurityGroup:
    Type: AWS::EC2::SecurityGroup
    Properties:
      GroupDescription: Security group for VPC Endpoint
      SecurityGroupIngress:
        - IpProtocol: tcp
          FromPort: 443
          ToPort: 443
          SourceSecurityGroupId: !Ref DefaultSecurityGroup
      SecurityGroupEgress:
        - IpProtocol: -1
          FromPort: -1
          ToPort: -1
          CidrIp: 0.0.0.0/0
      VpcId: !Ref VPC

  S3Endpoint:
    Type: AWS::EC2::VPCEndpoint
    Properties:
      RouteTableIds:
        - !Ref PublicRouteTable
        - !Ref PrivateRouteTable
      ServiceName: !Sub com.amazonaws.${AWS::Region}.s3
      VpcEndpointType: Gateway
      VpcId: !Ref VPC

  LogsEndpoint:
    Type: AWS::EC2::VPCEndpoint
    Properties:
      VpcEndpointType: Interface
      PrivateDnsEnabled: true
      ServiceName: !Sub com.amazonaws.${AWS::Region}.logs
      VpcId: !Ref VPC
      SubnetIds:
        - !Ref PrivateSubnet1
      SecurityGroupIds:
        - !Ref VPCESecurityGroup

Outputs:
  VPC:
    Value: !Ref VPC
    Description: ID of the VPC
    Export:
      Name: !Sub ${AWS::StackName}-VPC
  PublicSubnet1:
    Value: !Ref PublicSubnet1
    Description: ID of the public subnet
    Export:
      Name: !Sub ${AWS::StackName}-PublicSubnet1
  PublicSubnet2:
    Value: !Ref PublicSubnet2
    Description: ID of the public subnet
    Export:
      Name: !Sub ${AWS::StackName}-PublicSubnet2
  PrivateSubnet1:
    Value: !Ref PrivateSubnet1
    Description: ID of the private subnets
    Export:
      Name: !Sub ${AWS::StackName}-PrivateSubnet1
  DefaultSecurityGroup:
    Value: !Ref DefaultSecurityGroup
    Description: ID of the default security group

Set the VPC name, region, and availability zones (AZs) as environment variables. The AZs chosen here are zones where both Trn1 and Inf2 are available.[1]

export VPC_NAME=llm-vpc
export REGION=us-west-2
export AZ1=$(aws ec2 describe-availability-zones \
    --region ${REGION} \
    --filters "Name=zone-id,Values=usw2-az4" \
    --query "AvailabilityZones[].ZoneName" \
    --output text)
export AZ2=$(aws ec2 describe-availability-zones \
    --region ${REGION} \
    --filters "Name=zone-id,Values=usw2-az1" \
    --query "AvailabilityZones[].ZoneName" \
    --output text)

Provision the resources.

aws cloudformation deploy \
    --region ${REGION} \
    --capabilities CAPABILITY_IAM \
    --template-file vpc.yaml \
    --stack-name ${VPC_NAME} \
    --parameter-overrides \
        VPCName=${VPC_NAME} \
        SubnetsAZ1=${AZ1} \
        SubnetsAZ2=${AZ2}

S3

Create the S3 bucket. Write a CloudFormation template like the following.

s3.yaml
AWSTemplateFormatVersion: "2010-09-09"
Description: "CloudFormation template to create an S3 bucket with customizable name"

Parameters:
  BucketName:
    Type: String
    Description: "Name of the S3 bucket to be created"

Resources:
  S3Bucket:
    Type: "AWS::S3::Bucket"
    DeletionPolicy: Delete
    Properties:
      BucketName: !Ref BucketName
      AccessControl: Private
      BucketEncryption:
        ServerSideEncryptionConfiguration:
          - ServerSideEncryptionByDefault:
              SSEAlgorithm: AES256

Outputs:
  BucketName:
    Description: "Name of the newly created S3 bucket"
    Value: !Ref S3Bucket
  BucketARN:
    Description: "ARN of the newly created S3 bucket"
    Value: !GetAtt S3Bucket.Arn

Set the bucket name as an environment variable.

timestamp=$(date +%s)
random_string=$(openssl rand -hex 3)
export BUCKET_NAME=llm-bucket-${timestamp}-${random_string}

Provision the resources.

aws cloudformation deploy \
    --region ${REGION} \
    --capabilities CAPABILITY_IAM \
    --template-file s3.yaml \
    --stack-name ${BUCKET_NAME} \
    --parameter-overrides \
        BucketName=${BUCKET_NAME}

ParallelCluster

Create the ParallelCluster. Write a CloudFormation template like the following.

pcluster.yaml
AWSTemplateFormatVersion: "2010-09-09"
Description: "CloudFormation template to create a ParallelCluster"

Parameters:
  KeyName:
    Type: String
    Description: "Name of the key pair to be created"

  PublicSubnetId:
    Type: String
    Description: "ID of the VPC public subnet"

  PrivateSubnetId:
    Type: String
    Description: "ID of the VPC private subnet"

  Spot:
    Type: String
    Description: "Use Spot Instances if true, On-Demand if false"
    Default: "false"
    AllowedValues:
      - "true"
      - "false"

  NeuronVersion:
    Type: String
    Description: "Version of Neuron SDK"
    Default: v2.19.0

Mappings:
  ParallelCluster:
    Constants:
      Version: 3.10.1

Conditions:
  UseSpotInstances: !Equals
    - !Ref Spot
    - "true"

Resources:
  KeyPair:
    Type: "AWS::EC2::KeyPair"
    Properties:
      KeyName: !Ref KeyName

  PclusterClusterProvider:
    Type: AWS::CloudFormation::Stack
    Properties:
      TemplateURL: !Sub
        - https://${AWS::Region}-aws-parallelcluster.s3.${AWS::Region}.${AWS::URLSuffix}/parallelcluster/${Version}/templates/custom_resource/cluster.yaml
        - { Version: !FindInMap [ParallelCluster, Constants, Version] }

  PclusterCluster:
    Type: Custom::PclusterCluster
    Properties:
      ServiceToken: !GetAtt [PclusterClusterProvider, Outputs.ServiceToken]
      ClusterName: !Sub "c-${AWS::StackName}"
      ClusterConfiguration:
        Region: !Ref AWS::Region
        Image:
          Os: ubuntu2004
        HeadNode:
          InstanceType: c5.4xlarge
          Networking:
            SubnetId: !Ref PublicSubnetId
          Ssh:
            KeyName: !Ref KeyPair
          LocalStorage:
            RootVolume:
              Size: 1024
          CustomActions:
            OnNodeConfigured:
              Script: !Sub s3://neuron-s3/pcluster/post-install-scripts/neuron-installation/${NeuronVersion}/u20/pt/install_neuron.sh
          Iam:
            S3Access:
              - BucketName: neuron-s3
                EnableWriteAccess: false
        Scheduling:
          Scheduler: slurm
          SlurmSettings:
            QueueUpdateStrategy: DRAIN
          SlurmQueues:
            - Name: compute1
              CapacityType: !If
                - UseSpotInstances
                - SPOT
                - ONDEMAND
              ComputeSettings:
                LocalStorage:
                  RootVolume:
                    Size: 1024
                  EphemeralVolume:
                    MountDir: /local_storage
              ComputeResources:
                - Efa:
                    Enabled: true
                  InstanceType: trn1.32xlarge
                  MaxCount: 8
                  MinCount: 0
                  Name: queue1-i1
              Networking:
                SubnetIds:
                  - !Ref PrivateSubnetId
                PlacementGroup:
                  Enabled: true
              CustomActions:
                OnNodeConfigured:
                  Script: !Sub s3://neuron-s3/pcluster/post-install-scripts/neuron-installation/${NeuronVersion}/u20/pt/install_neuron.sh
              Iam:
                S3Access:
                  - BucketName: neuron-s3
                    EnableWriteAccess: false
        SharedStorage:
          - MountDir: /fsx
            Name: pclusterfsx
            StorageType: FsxLustre
            FsxLustreSettings:
              DeploymentType: PERSISTENT_2
              DataCompressionType: LZ4
              StorageCapacity: 1200
              PerUnitStorageThroughput: 125

Outputs:
  KeyPairId:
    Description: "The ID of the key pair"
    Value: !GetAtt KeyPair.KeyPairId
  HeadNodeIp:
    Description: The Public IP address of the HeadNode
    Value: !GetAtt [PclusterCluster, headNode.publicIpAddress]

Fetch the IDs of the public and private subnets created earlier and set them as environment variables.

export PUBLIC_SUBNET_ID=$(aws cloudformation describe-stacks \
    --region ${REGION} \
    --stack-name ${VPC_NAME} \
    --query "Stacks[0].Outputs[?OutputKey=='PublicSubnet1'].OutputValue" \
    --output text)
export PRIVATE_SUBNET_ID=$(aws cloudformation describe-stacks \
    --region ${REGION} \
    --stack-name ${VPC_NAME} \
    --query "Stacks[0].Outputs[?OutputKey=='PrivateSubnet1'].OutputValue" \
    --output text)

Set the cluster name and key pair name as environment variables.

export PCLUSTER_NAME=llm-pcluster
export KEY_NAME=my-key-pair

Provision the resources.

aws cloudformation deploy \
    --region ${REGION} \
    --capabilities CAPABILITY_NAMED_IAM CAPABILITY_AUTO_EXPAND \
    --template-file pcluster.yaml \
    --stack-name ${PCLUSTER_NAME} \
    --parameter-overrides \
        KeyName=${KEY_NAME} \
        PublicSubnetId=${PUBLIC_SUBNET_ID} \
        PrivateSubnetId=${PRIVATE_SUBNET_ID}
aws cloudformation wait stack-create-complete \
    --region ${REGION} \
    --stack-name ${PCLUSTER_NAME}

Associate the FSx file system with the S3 bucket created earlier.

fsx_id=$(aws cloudformation describe-stacks \
    --region ${REGION} \
    --stack-name c-${PCLUSTER_NAME} \
    --query "Stacks[0].Outputs[?OutputKey=='FSXIds'].OutputValue" \
    --output text)
aws fsx create-data-repository-association \
    --region ${REGION} \
    --file-system-id ${fsx_id} \
    --file-system-path / \
    --data-repository-path s3://${BUCKET_NAME} \
    --s3 "AutoImportPolicy={Events=[NEW,CHANGED,DELETED]},AutoExportPolicy={Events=[NEW,CHANGED,DELETED]}" \
    --batch-import-meta-data-on-create

Retrieve the private key generated for the key pair from Parameter Store.

key_pair_id=$(aws cloudformation describe-stacks \
    --region ${REGION} \
    --stack-name ${PCLUSTER_NAME} \
    --query "Stacks[0].Outputs[?OutputKey=='KeyPairId'].OutputValue" \
    --output text)
aws ssm get-parameter \
    --region ${REGION} \
    --name /ec2/keypair/${key_pair_id} \
    --with-decryption \
    --query "Parameter.Value" \
    --output text \
    > ~/.ssh/${KEY_NAME}.pem
chmod 400 ~/.ssh/${KEY_NAME}.pem

Append an entry to your SSH config file.

head_node_ip=$(aws cloudformation describe-stacks \
    --region ${REGION} \
    --stack-name ${PCLUSTER_NAME} \
    --query "Stacks[0].Outputs[?OutputKey=='HeadNodeIp'].OutputValue" \
    --output text)
cat <<EOF >> ~/.ssh/config
Host ${PCLUSTER_NAME}
    HostName ${head_node_ip}
    User ubuntu
    IdentityFile ~/.ssh/${KEY_NAME}.pem
EOF

Connect to the head node over SSH. The remaining steps are performed on the head node.

ssh ${PCLUSTER_NAME}

Preparation

Activate the virtual environment.

source ~/aws_neuron_venv_pytorch/bin/activate

Download the sample code as of Neuron SDK v2.19.0. The repository has no release tags, so the version is pinned by commit SHA.

commit_sha=0f2a90f6ba2dc8fb12833d85e48732ca36717611

wget https://raw.githubusercontent.com/aws-neuron/neuronx-distributed/${commit_sha}/examples/training/llama/tp_zero1_llama_hf_pretrain/logger.py
wget https://raw.githubusercontent.com/aws-neuron/neuronx-distributed/${commit_sha}/examples/training/llama/tp_zero1_llama_hf_pretrain/tp_zero1_llama_hf_pretrain.py

wget https://raw.githubusercontent.com/aws-neuron/neuronx-distributed/${commit_sha}/examples/training/llama/lr.py
wget https://raw.githubusercontent.com/aws-neuron/neuronx-distributed/${commit_sha}/examples/training/llama/modeling_llama_nxd.py
wget https://raw.githubusercontent.com/aws-neuron/neuronx-distributed/${commit_sha}/examples/training/llama/requirements.txt
wget https://raw.githubusercontent.com/aws-neuron/neuronx-distributed/${commit_sha}/examples/training/llama/training_utils.py

wget https://raw.githubusercontent.com/aws-neuron/neuronx-distributed/${commit_sha}/examples/training/checkpoint_converter.py

Install the packages.

pip install -r requirements.txt

Preprocessing the dataset

Preprocess MinnadeChat and ichikara-instruction for supervised fine-tuning (SFT).

Create the following script.

get_dataset.py
import argparse
import os
from itertools import chain

from datasets import (
    Dataset,
    Features,
    Sequence,
    Value,
    concatenate_datasets,
    load_dataset,
)
from transformers import AutoTokenizer


def minnade_to_oasst(row):
    row["message_id"] = row.pop("id")
    row["text"] = row.pop("body") or ""
    row["replies"] = []
    return row


def read_dataset_message_trees(dataset_name: str, split: str, revision: str):
    dataset = load_dataset(dataset_name, split=split, revision=revision)
    dataset = dataset.sort("created_at")

    trees: list[dict] = []
    for row in dataset:
        row = minnade_to_oasst(row)
        if row["parent_id"] is None:
            tree_dict = {
                "message_tree_id": row["message_id"],
                "prompt": row,
            }
            trees.append(tree_dict)
        else:
            for tree_dict in trees:

                def add_child(node: dict, new_node: dict):
                    if new_node["parent_id"] == node["message_id"]:
                        node["replies"].append(new_node)
                        return
                    for i, _ in enumerate(node["replies"]):
                        add_child(node["replies"][i], new_node)

                add_child(tree_dict["prompt"], row)

    return trees


def create_threads(node, threads, parents=None):
    parents = parents or []
    if not node:
        return
    thread = parents + [node]
    if not thread[-1]["replies"]:
        threads.append(thread)
    if node["replies"]:
        parents = thread
        for c in node["replies"]:
            create_threads(c, threads, parents)


def gen_thread(dataset_name: str, split: str, revision: str):
    trees = read_dataset_message_trees(dataset_name, split, revision)

    threads: list[list] = []
    for tree in trees:
        create_threads(tree["prompt"], threads)

    for thread in threads:
        if thread[0]["role"] == "system":
            for i, m in enumerate(thread):
                if i == 0:
                    continue
                if i % 2 == 0:
                    assert m["role"] == "assistant", m
                else:
                    assert m["role"] == "user", m
        else:
            for i, m in enumerate(thread):
                if i % 2 == 0:
                    assert m["role"] == "user", m
                else:
                    assert m["role"] == "assistant", m

        if thread[-1]["role"] == "user":
            thread = thread[:-1]
        if thread[-1]["role"] == "system":
            thread = thread[:-1]

        if thread:
            yield {
                "messages": [{"role": m["role"], "content": m["text"]} for m in thread]
            }


def load_minnade_dataset():
    return Dataset.from_generator(
        gen_thread,
        gen_kwargs={
            "dataset_name": "minnade/chat-daily",
            "split": "train",
            "revision": "2024-07-25",
        },
    )


def load_ichikara_dataset():
    dataset = load_dataset(
        "p1atdev/ichikara-instruction", "20231221-003", split="train"
    )
    return dataset.map(
        lambda example: {
            "messages": [
                {"role": "user", "content": example["text"]},
                {"role": "assistant", "content": example["output"]},
            ]
        },
        remove_columns=dataset.column_names,
    )


def main(args):
    save_path = os.path.expanduser(args.data_dir)
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    block_size = args.block_size
    features = Features(
        {
            "input_ids": Sequence(feature=Value(dtype="int32")),
            "labels": Sequence(feature=Value(dtype="int32")),
        }
    )

    tokenizer = AutoTokenizer.from_pretrained(args.model_name)

    BOS = [tokenizer.bos_token_id]
    EOS = [tokenizer.eos_token_id]
    BINST = tokenizer.encode("[INST]", add_special_tokens=False)
    EINST = tokenizer.encode("[/INST]", add_special_tokens=False)
    BSYS = tokenizer.encode("<<SYS>>\n", add_special_tokens=False)
    ESYS = tokenizer.encode("\n<</SYS>>\n\n", add_special_tokens=False)

    def tokenize(example):
        input_ids = []
        labels = []

        if example["messages"][0]["role"] == "system":
            system = example["messages"][0]["content"]
            messages = example["messages"][1:]
        else:
            system = None
            messages = example["messages"]

        for i, message in enumerate(messages):
            if message["role"] == "user":
                if i == 0 and system:
                    tokens = (
                        BOS
                        + BINST
                        + BSYS
                        + tokenizer.encode(system, add_special_tokens=False)
                        + ESYS
                        + tokenizer.encode(message["content"], add_special_tokens=False)
                        + EINST
                    )
                else:
                    tokens = (
                        BOS
                        + BINST
                        + tokenizer.encode(message["content"], add_special_tokens=False)
                        + EINST
                    )
                input_ids += tokens
                labels += [-100] * len(tokens)
            else:
                tokens = (
                    tokenizer.encode(message["content"], add_special_tokens=False) + EOS
                )
                input_ids += tokens
                labels += tokens

        return {"input_ids": input_ids, "labels": labels}

    def group_texts(examples):
        concatenated_examples = {
            k: list(chain.from_iterable(values for values in examples[k]))
            for k in features.keys()
        }
        total_length = len(concatenated_examples[list(features.keys())[0]])
        total_length = (total_length // block_size) * block_size
        result = {
            k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
            for k, t in concatenated_examples.items()
        }
        return result

    dataset = concatenate_datasets([load_minnade_dataset(), load_ichikara_dataset()])
    dataset = (
        dataset.map(
            tokenize,
            remove_columns=dataset.column_names,
            features=features,
        )
        .shuffle(seed=42)
        .map(
            group_texts,
            batched=True,
        )
        .shuffle(seed=42)
        .filter(lambda example: not all([e < 0 for e in example["labels"]]))
        .map(lambda example: {**example, "attention_mask": True})
    )
    print(dataset)
    dataset.save_to_disk(save_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name",
        type=str,
        help="Model name.",
    )
    parser.add_argument(
        "--data_dir",
        type=str,
        help="Pre-tokenized dataset directory.",
    )
    parser.add_argument(
        "--block_size",
        type=int,
        default=8192,
        help="Block size.",
    )

    args = parser.parse_args()
    main(args)
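The tree-flattening step in the script above (`create_threads`) turns each message tree into root-to-leaf conversation threads. A self-contained toy run of the same logic, with the function copied from the script and a hypothetical two-reply tree:

```python
def create_threads(node, threads, parents=None):
    # Same logic as in get_dataset.py: walk the tree and emit one
    # thread (list of messages) per leaf.
    parents = parents or []
    if not node:
        return
    thread = parents + [node]
    if not thread[-1]["replies"]:
        threads.append(thread)
    if node["replies"]:
        parents = thread
        for c in node["replies"]:
            create_threads(c, threads, parents)


# A prompt with two alternative assistant replies yields two threads,
# one per leaf (made-up data for illustration).
tree = {
    "message_id": "root", "role": "user", "text": "hi",
    "replies": [
        {"message_id": "a", "role": "assistant", "text": "hello", "replies": []},
        {"message_id": "b", "role": "assistant", "text": "hey", "replies": []},
    ],
}
threads = []
create_threads(tree, threads)
print(len(threads))  # 2
```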

Set the model name and the dataset output path as environment variables.

export HF_MODEL_NAME=weblab-GENIAC/Tanuki-8B-dpo-v1.0
export DATA_PATH=/fsx/data/minnade-ichikara

Run the script to create the dataset.

python get_dataset.py \
    --model_name ${HF_MODEL_NAME} \
    --data_dir ${DATA_PATH}

Here we use the Llama 2 Chat prompt format.
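As a reference for what that layout looks like, here is a minimal sketch that renders one conversation as a string, with the special tokens written out literally (the actual `tokenize` function above works on token IDs, and exact spacing may differ):

```python
def llama2_chat(messages, system=None):
    """Render messages in a Llama 2 Chat-style layout.

    User turns are wrapped in [INST] ... [/INST]; an optional system
    prompt goes inside <<SYS>> markers on the first user turn. During
    SFT only the assistant turns receive labels.
    """
    out = ""
    for i, m in enumerate(messages):
        if m["role"] == "user":
            if i == 0 and system:
                out += (
                    f"<s>[INST] <<SYS>>\n{system}\n<</SYS>>\n\n"
                    f"{m['content']} [/INST]"
                )
            else:
                out += f"<s>[INST] {m['content']} [/INST]"
        else:  # assistant
            out += f" {m['content']} </s>"
    return out


print(llama2_chat(
    [{"role": "user", "content": "Hello"},
     {"role": "assistant", "content": "Hi there"}]
))
```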

Converting the checkpoint

Create the following script.

convert_checkpoints.py
import torch
import torch_xla.utils.serialization as xser
from transformers import AutoConfig, AutoModelForCausalLM

from checkpoint_converter import CheckpointConverterBase


class CheckpointConverterLlama(CheckpointConverterBase):
    def load_partial_xser(self, args, tp_rank, pp_rank):
        filename = self.get_input_filename(args, tp_rank, pp_rank, 1)
        partial_state = xser.load(filename)
        partial_state = {k: v.to(torch.bfloat16) for k, v in partial_state.items()}
        return partial_state

    def save_full(self, args, full_state):
        config = AutoConfig.from_pretrained(args.config)
        with torch.device("meta"):
            model = AutoModelForCausalLM.from_config(config)
        model.load_state_dict(full_state, assign=True)
        model.save_pretrained(args.output_dir)

    def load_full_state(self, args):
        model = AutoModelForCausalLM.from_pretrained(args.input_dir, torch_dtype="auto")
        if args.vocab_size > 0:
            with torch.no_grad():
                model.resize_token_embeddings(args.vocab_size)
        return model.state_dict()


if __name__ == "__main__":
    checkpoint_converter = CheckpointConverterLlama()
    parser = checkpoint_converter.get_arg_parser()
    parser.add_argument(
        "--vocab_size", type=int, default=-1, help="Vocabulary size of the model"
    )
    args, _ = parser.parse_known_args()
    checkpoint_converter.run(args)

Download the model's configuration file.

export MODEL_CONFIG_PATH=./tanuki-8b/config.json
mkdir -p ./tanuki-8b
curl https://huggingface.co/${HF_MODEL_NAME}/raw/main/config.json \
    | jq '. + {"vocab_size": 128256, "sequence_parallel_enabled": false, "selective_checkpoint_enabled": false, "move_model_to_device": false}' \
    > ${MODEL_CONFIG_PATH}

Set the environment variables.

export CHECKPOINT_DIR=/fsx/checkpoints
export TP_DEGREE=32
export KV_REPLICATOR=4

Run the script to convert the Transformers checkpoint into the NeuronX Distributed format.

python convert_checkpoints.py \
    --input_dir ${HF_MODEL_NAME} \
    --output_dir ${CHECKPOINT_DIR}/pretrained_weight \
    --config ${MODEL_CONFIG_PATH} \
    --tp_size ${TP_DEGREE} \
    --kv_size_multiplier ${KV_REPLICATOR} \
    --qkv_linear True \
    --convert_from_full_state True \
    --vocab_size 128256

Here we change the vocabulary size to 128256, the same as Llama 3. Aligning the vocabulary size with Llama 3 makes Tanuki-8B architecturally identical to Llama 3 8B. This lets us reuse the Llama 3 8B Neuron Model Cache and skip model compilation at inference time.
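A quick sanity check of my own (not from the Neuron docs): the padded vocabulary also divides evenly by the tensor-parallel degree used in this article, which sharded (parallel) embeddings require:

```python
VOCAB_SIZE = 128256  # Llama 3 vocabulary size used for padding
TP_DEGREE = 32       # tensor-parallel degree used in this article

# A vocab-parallel embedding splits the vocabulary across TP ranks,
# so the size must divide evenly.
assert VOCAB_SIZE % TP_DEGREE == 0
print(VOCAB_SIZE // TP_DEGREE)  # embedding rows per rank: 4008
```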

Running the training

Create the following script.

tp_zero1_tanuki_8b.sh
#!/bin/bash

#############################################
# User defined parameters and env vars

SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )

export NEURON_CC_FLAGS="--model-type transformer --distribution-strategy=llm-training --cache_dir=$SCRIPT_DIR/neuron_compile_cache/"
export NEURON_FUSE_SOFTMAX=1

# Async Runtime
export NEURON_RT_ASYNC_EXEC_MAX_INFLIGHT_REQUESTS=3

# HOST OOM
export MALLOC_ARENA_MAX=64

# TP degree
: ${TP_DEGREE:=32}
# KV replication size
: ${KV_REPLICATOR:=4}
# 0: bf16; 1: mixed precision
USE_MIX_PRECISION=1
# 0: use pure DP; 1: use ZeRO-1
USE_ZERO_1=1
# global batch size
GBS=8
# micro batch size
MBS=1
# number of steps to run
TOTAL_STEPS=100
# warmup steps
WARMUP_STEPS=10
# Model config path
: ${MODEL_CONFIG_PATH:="$SCRIPT_DIR/tanuki-8b"}
# Data path
: ${DATA_PATH:="$SCRIPT_DIR/data/minnade-ichikara"}
# sequence length
SEQ_LEN=8192
# Checkpoint directory
: ${CHECKPOINT_DIR:="$SCRIPT_DIR/checkpoints"}

#############################################

export NUM_NEURONCORES=32
NODE_ID=0
WORLD_SIZE=1
DISTRIBUTED_ARGS="--nproc_per_node $NUM_NEURONCORES"
if [ ! -z "$SLURM_NTASKS" ]; then
    WORLD_SIZE=$SLURM_NTASKS
    NODE_ID=$SLURM_NODEID
    MASTER_ADDRESS=(`scontrol show hostnames $SLURM_JOB_NODELIST`)
    DISTRIBUTED_ARGS="--nproc_per_node $NUM_NEURONCORES --nnodes $WORLD_SIZE --node_rank $NODE_ID --master_addr $MASTER_ADDRESS --master_port 44000"
    if [ $NODE_ID -eq 0 ]; then
        echo "WORLD_SIZE=$WORLD_SIZE"
        echo "NODE_ID=$NODE_ID"
        echo "MASTER_ADDRESS=$MASTER_ADDRESS"
        echo "DISTRIBUTED_ARGS=$DISTRIBUTED_ARGS"
    fi
    export FI_EFA_USE_DEVICE_RDMA=1
    export FI_PROVIDER=efa
fi

echo "WORLD_SIZE=$WORLD_SIZE"
echo "NODE_ID=$NODE_ID"
echo "MASTER_ADDRESS=$MASTER_ADDRESS"

sudo sysctl -w net.ipv4.ip_local_reserved_ports=44000,48620

export NEURON_RT_NUM_CORES=32
export NUM_NEURONCORES=$NEURON_RT_NUM_CORES
export TPU_NUM_DEVICES=$NEURON_RT_NUM_CORES
export TPU_CHIPS_PER_HOST_BOUNDS=$NEURON_RT_NUM_CORES

#############################################

EXTRA_ARGS=" "
if [ $USE_MIX_PRECISION -gt 0 ]; then
    EXTRA_ARGS+=" --use_mix_precision"
fi
if [ $USE_ZERO_1 -gt 0 ]; then
    EXTRA_ARGS+=" --use_zero_1"
fi

DP=$(($NEURON_RT_NUM_CORES * $WORLD_SIZE / $TP_DEGREE))
ACC_STEPS=$(($GBS / $MBS / $DP))


if [ "${NEURON_EXTRACT_GRAPHS_ONLY:-0}" -gt 0 ]; then
    STEPS_THIS_RUN=2
    OUTPUT_LOG=log_compile-$NODE_ID.log
elif [ -v PERF_TEST ] && [ $PERF_TEST -gt 0 ]; then
    STEPS_THIS_RUN=100
    OUTPUT_LOG=log_exe-$NODE_ID.log
else
    STEPS_THIS_RUN=-1
    OUTPUT_LOG=log_exe-$NODE_ID.log
fi

echo TP_DEGREE=$TP_DEGREE
echo KV_REPLICATOR=$KV_REPLICATOR
echo USE_MIX_PRECISION=$USE_MIX_PRECISION
echo USE_ZERO_1=$USE_ZERO_1
echo GBS=$GBS
echo MBS=$MBS
echo TOTAL_STEPS=$TOTAL_STEPS
echo WARMUP_STEPS=$WARMUP_STEPS
echo MODEL_CONFIG_PATH=$MODEL_CONFIG_PATH
echo DATA_PATH=$DATA_PATH
echo SEQ_LEN=$SEQ_LEN
echo CHECKPOINT_DIR=$CHECKPOINT_DIR

echo EXTRA_ARGS=$EXTRA_ARGS
echo DP=$DP
echo ACC_STEPS=$ACC_STEPS
echo STEPS_THIS_RUN=$STEPS_THIS_RUN
echo OUTPUT_LOG=$OUTPUT_LOG

torchrun $DISTRIBUTED_ARGS \
    tp_zero1_llama_hf_pretrain.py \
    --model_path $MODEL_CONFIG_PATH \
    --data_dir $DATA_PATH \
    --tensor_parallel_size $TP_DEGREE \
    --batch_size $MBS \
    --steps_this_run $STEPS_THIS_RUN \
    --max_steps $TOTAL_STEPS \
    --warmup_steps $WARMUP_STEPS \
    --lr 1e-5 \
    --weight_decay 0.1 \
    --beta1 0.9 \
    --beta2 0.999 \
    --grad_accum_usteps $ACC_STEPS \
    --print_grad_norm \
    --seq_len $SEQ_LEN \
    --sequence_parallel_enabled \
    --selective_checkpoint_enabled \
    --logging_interval 10 \
    --qkv_linear \
    --kv_replicator $KV_REPLICATOR \
    --use_flash_attention 1 \
    --checkpoint_freq $TOTAL_STEPS \
    --checkpoint_dir $CHECKPOINT_DIR \
    --pretrained_weight \
    $EXTRA_ARGS |& tee $OUTPUT_LOG
exit ${PIPESTATUS[0]}
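To make the batch arithmetic in the script concrete, here is the same computation worked out in plain Python for a single trn1.32xlarge node (32 NeuronCores) with the values set above:

```python
NEURON_RT_NUM_CORES = 32  # NeuronCores per trn1.32xlarge node
WORLD_SIZE = 1            # number of nodes
TP_DEGREE = 32            # tensor-parallel degree
GBS, MBS = 8, 1           # global / micro batch size

# Data-parallel degree: total cores divided by the tensor-parallel degree.
dp = NEURON_RT_NUM_CORES * WORLD_SIZE // TP_DEGREE
# Gradient-accumulation steps needed to reach the global batch size.
acc_steps = GBS // MBS // dp

print(dp, acc_steps)  # 1 8
```

With one node and TP spanning all 32 cores there is a single data-parallel replica, so the global batch of 8 is reached entirely through gradient accumulation.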

Make the script executable.

chmod +x tp_zero1_tanuki_8b.sh

Use neuron_parallel_compile to pre-compile the graphs.

sbatch --exclusive \
    --nodes 1 \
    --cpus-per-task 128 \
    --wrap="srun neuron_parallel_compile ./tp_zero1_tanuki_8b.sh"

Once pre-compilation has finished, run the actual training by executing the same command without neuron_parallel_compile.

sbatch --exclusive \
    --nodes 1 \
    --cpus-per-task 128 \
    --wrap="srun ./tp_zero1_tanuki_8b.sh"

Converting the checkpoint

Set the model output directory as an environment variable.

export MODEL_OUTPUT_DIR=/fsx/models/tanuki-8b-sft

Run the script to convert the NeuronX Distributed checkpoint to the Transformers format.

latest_checkpoint=$(ls ${CHECKPOINT_DIR} | sort -t_ -k2 -rn | head -n1)
python convert_checkpoints.py \
    --input_dir ${CHECKPOINT_DIR}/${latest_checkpoint}/model \
    --output_dir ${MODEL_OUTPUT_DIR} \
    --config ${MODEL_CONFIG_PATH} \
    --tp_size ${TP_DEGREE} \
    --kv_size_multiplier ${KV_REPLICATOR} \
    --qkv_linear True \
    --load_xser True \
    --convert_to_full_state True
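The `latest_checkpoint` line above picks the checkpoint directory with the highest step number from names of the form `<prefix>_<step>`. The same selection expressed in Python (the directory names here are illustrative):

```python
# Checkpoint directory names of the form <prefix>_<step> (illustrative).
checkpoints = ["step_500", "step_1000", "step_100"]

# Equivalent of `ls | sort -t_ -k2 -rn | head -n1`: compare numerically on
# the field after "_" and take the largest.
latest = max(checkpoints, key=lambda name: int(name.split("_")[1]))

print(latest)  # -> step_1000
```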

Rewrite the chat template and save the tokenizer files.

temp="{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' '  + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}"
curl https://huggingface.co/${HF_MODEL_NAME}/raw/main/tokenizer_config.json \
    | jq --arg temp "${temp}" '. + {"chat_template": $temp}' \
    > ${MODEL_OUTPUT_DIR}/tokenizer_config.json
curl https://huggingface.co/${HF_MODEL_NAME}/raw/main/special_tokens_map.json -o ${MODEL_OUTPUT_DIR}/special_tokens_map.json
curl https://huggingface.co/${HF_MODEL_NAME}/raw/main/tokenizer.json -o ${MODEL_OUTPUT_DIR}/tokenizer.json
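The jq invocation above merges a `chat_template` key into `tokenizer_config.json`. The same update in Python (the config content is a stand-in, and the template string is abridged; the full Jinja template is the one defined in the shell snippet):

```python
import json

# Stand-in for the downloaded tokenizer_config.json.
config = json.loads('{"tokenizer_class": "PreTrainedTokenizerFast"}')

# Abridged template; the full template is the string defined above.
template = "{% if messages[0]['role'] == 'system' %}...{% endfor %}"

# Equivalent of `jq '. + {"chat_template": $temp}'`.
config["chat_template"] = template

print(json.dumps(config, ensure_ascii=False))
```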

Deploying the inference environment to ECS

Deploy the trained Tanuki-8B to ECS. NeuronX TGI is used to serve the model.

Environment setup

First, set up the environment. The cluster configuration is shown in the figure below.

ECS

Create an ECS cluster. Create a CloudFormation template like the following.

ecs.yaml
AWSTemplateFormatVersion: "2010-09-09"
Description: "CloudFormation template to create an ECS Cluster"

Parameters:
  ClusterName:
    Type: String
    Description: "Name of the ECS Cluster to be created"

  VpcId:
    Type: String
    Description: "ID of the VPC"

  PublicSubnetId1:
    Type: String
    Description: "ID of the VPC public subnet"

  PublicSubnetId2:
    Type: String
    Description: "ID of the VPC public subnet"

  PrivateSubnetId1:
    Type: String
    Description: "ID of the VPC private subnet"

  DefaultSecurityGroupId:
    Type: String
    Description: "ID of the default security group"

  LatestECSOptimizedAMI:
    Type: AWS::SSM::Parameter::Value<AWS::EC2::Image::Id>
    Description: "AMI ID"
    Default: /aws/service/ecs/optimized-ami/amazon-linux-2023/neuron/recommended/image_id

  InstanceType:
    Type: String
    Description: "Instance type"
    Default: inf2.xlarge

  Spot:
    Type: String
    Description: "Use Spot Instances if true, On-Demand if false"
    Default: "false"
    AllowedValues:
      - "true"
      - "false"

  Image:
    Type: String
    Description: "URI of the image"
    Default: ghcr.io/huggingface/neuronx-tgi:0.0.23

  BucketName:
    Type: String
    Description: "Name of the S3 bucket"

  ModelName:
    Type: String
    Description: "Name of the model"

Conditions:
  UseSpotInstances: !Equals
    - !Ref Spot
    - "true"

Resources:
  ECSSecurityGroup:
    Type: AWS::EC2::SecurityGroup
    Properties:
      GroupDescription: Security group for ECS
      SecurityGroupEgress:
        - IpProtocol: -1
          FromPort: -1
          ToPort: -1
          CidrIp: 0.0.0.0/0
      VpcId: !Ref VpcId

  ALBSecurityGroup:
    Type: AWS::EC2::SecurityGroup
    Properties:
      GroupDescription: Security group for ALB
      SecurityGroupIngress:
        - IpProtocol: tcp
          FromPort: 80
          ToPort: 80
          CidrIp: 0.0.0.0/0
      VpcId: !Ref VpcId

  ECSCluster:
    Type: AWS::ECS::Cluster
    Properties:
      ClusterName: !Ref ClusterName
      ClusterSettings:
        - Name: containerInsights
          Value: enabled
      Configuration:
        ExecuteCommandConfiguration:
          Logging: DEFAULT

  EcsInstanceRole:
    Type: AWS::IAM::Role
    Properties:
      AssumeRolePolicyDocument:
        Statement:
          - Effect: Allow
            Principal:
              Service:
                - ec2.amazonaws.com
            Action:
              - sts:AssumeRole
      Path: /
      ManagedPolicyArns:
        - arn:aws:iam::aws:policy/service-role/AmazonEC2ContainerServiceforEC2Role
        - arn:aws:iam::aws:policy/AmazonS3FullAccess

  IamRoleInstanceProfile:
    Type: AWS::IAM::InstanceProfile
    Properties:
      Roles:
        - !Ref EcsInstanceRole

  ECSLaunchTemplate:
    Type: AWS::EC2::LaunchTemplate
    Properties:
      LaunchTemplateData:
        ImageId: !Ref LatestECSOptimizedAMI
        SecurityGroupIds:
          - !Ref DefaultSecurityGroupId
          - !Ref ECSSecurityGroup
        IamInstanceProfile:
          Name: !Ref IamRoleInstanceProfile
        BlockDeviceMappings:
          - DeviceName: /dev/xvda
            Ebs:
              VolumeSize: 500
        UserData:
          Fn::Base64: !Sub |
            #!/bin/bash
            echo ECS_CLUSTER=${ClusterName} >> /etc/ecs/ecs.config
            sudo yum install -y https://s3.amazonaws.com/mountpoint-s3-release/latest/x86_64/mount-s3.rpm
            sudo mkdir /s3
            sudo mount-s3 --allow-other ${BucketName} /s3

  ECSAutoScalingGroup:
    Type: AWS::AutoScaling::AutoScalingGroup
    DependsOn:
      - ECSCluster
      - EcsInstanceRole
    Properties:
      MinSize: 0
      MaxSize: 3
      DesiredCapacity: 1
      MixedInstancesPolicy:
        LaunchTemplate:
          LaunchTemplateSpecification:
            LaunchTemplateId: !Ref ECSLaunchTemplate
            Version: !GetAtt ECSLaunchTemplate.LatestVersionNumber
          Overrides:
            - InstanceType: !Ref InstanceType
        InstancesDistribution:
          OnDemandPercentageAboveBaseCapacity: !If
            - UseSpotInstances
            - 0
            - 100
          SpotAllocationStrategy: price-capacity-optimized
      VPCZoneIdentifier:
        - !Ref PrivateSubnetId1

  EC2CapacityProvider:
    Type: AWS::ECS::CapacityProvider
    Properties:
      AutoScalingGroupProvider:
        AutoScalingGroupArn: !Ref ECSAutoScalingGroup
        ManagedScaling:
          Status: ENABLED
          TargetCapacity: 100
        ManagedTerminationProtection: DISABLED

  ClusterCPAssociation:
    Type: AWS::ECS::ClusterCapacityProviderAssociations
    Properties:
      Cluster: !Ref ClusterName
      CapacityProviders:
        - !Ref EC2CapacityProvider
      DefaultCapacityProviderStrategy:
        - Base: 0
          Weight: 1
          CapacityProvider: !Ref EC2CapacityProvider

  ECSTaskExecutionRole:
    Type: AWS::IAM::Role
    Properties:
      AssumeRolePolicyDocument:
        Statement:
          - Effect: Allow
            Principal:
              Service:
                - ecs-tasks.amazonaws.com
            Action:
              - sts:AssumeRole
      ManagedPolicyArns:
        - arn:aws:iam::aws:policy/service-role/AmazonECSTaskExecutionRolePolicy

  ECSLogGroup:
    Type: AWS::Logs::LogGroup
    Properties:
      RetentionInDays: 7

  ECSTaskDefinition:
    Type: AWS::ECS::TaskDefinition
    Properties:
      ContainerDefinitions:
        - Command:
            - --port
            - 8080
            - --model-id
            - !Ref ModelName
            - --max-batch-size
            - 1
            - --max-input-length
            - 3164
            - --max-total-tokens
            - 4096
          Essential: true
          Image: !Ref Image
          LogConfiguration:
            LogDriver: awslogs
            Options:
              awslogs-group: !Ref ECSLogGroup
              awslogs-region: !Ref AWS::Region
              awslogs-stream-prefix: ecs
          MemoryReservation: 1024
          MountPoints:
            - ContainerPath: /s3
              ReadOnly: true
              SourceVolume: s3
          Name: tgi
          PortMappings:
            - AppProtocol: http
              ContainerPort: 8080
              Protocol: tcp
          Privileged: true
      ExecutionRoleArn: !Ref ECSTaskExecutionRole
      IpcMode: host
      NetworkMode: bridge
      PlacementConstraints:
        - Expression: attribute:ecs.os-type == linux
          Type: memberOf
        - Expression: !Sub attribute:ecs.instance-type == ${InstanceType}
          Type: memberOf
      RequiresCompatibilities:
        - EC2
      Volumes:
        - Host:
            SourcePath: /s3
          Name: s3

  TargetGroup:
    Type: AWS::ElasticLoadBalancingV2::TargetGroup
    Properties:
      HealthCheckEnabled: true
      HealthCheckIntervalSeconds: 5
      HealthCheckPath: /health
      HealthCheckProtocol: HTTP
      HealthCheckTimeoutSeconds: 3
      HealthyThresholdCount: 2
      Port: 8080
      Protocol: HTTP
      ProtocolVersion: HTTP1
      TargetType: instance
      UnhealthyThresholdCount: 2
      VpcId: !Ref VpcId

  LoadBalancer:
    Type: AWS::ElasticLoadBalancingV2::LoadBalancer
    Properties:
      IpAddressType: ipv4
      LoadBalancerAttributes:
        - Key: idle_timeout.timeout_seconds
          Value: 600
      Scheme: internet-facing
      SecurityGroups:
        - !Ref DefaultSecurityGroupId
        - !Ref ALBSecurityGroup
      Subnets:
        - !Ref PublicSubnetId1
        - !Ref PublicSubnetId2
      Type: application

  ALBListener:
    Type: AWS::ElasticLoadBalancingV2::Listener
    Properties:
      DefaultActions:
        - TargetGroupArn: !Ref TargetGroup
          Type: forward
      LoadBalancerArn: !Ref LoadBalancer
      Port: 80
      Protocol: HTTP

  ECSService:
    Type: AWS::ECS::Service
    DependsOn: ALBListener
    Properties:
      Cluster: !Ref ECSCluster
      DesiredCount: 1
      EnableECSManagedTags: true
      HealthCheckGracePeriodSeconds: 3000
      LoadBalancers:
        - ContainerName: tgi
          ContainerPort: 8080
          TargetGroupArn: !Ref TargetGroup
      SchedulingStrategy: REPLICA
      ServiceName: tgi
      TaskDefinition: !Ref ECSTaskDefinition

Outputs:
  ECSCluster:
    Description: "The created cluster."
    Value: !Ref ECSCluster
  DNSName:
    Description: "The DNS name of the load balancer."
    Value: !GetAtt LoadBalancer.DNSName

Set the cluster name, VPC ID, subnet IDs, and security group ID as environment variables.

export ECS_CLUSTER_NAME=neuronx-tgi
export VPC_ID=$(aws cloudformation describe-stacks \
    --region ${REGION} \
    --stack-name ${VPC_NAME} \
    --query "Stacks[0].Outputs[?OutputKey=='VPC'].OutputValue" \
    --output text)
export PUBLIC_SUBNET_ID1=$(aws cloudformation describe-stacks \
    --region ${REGION} \
    --stack-name ${VPC_NAME} \
    --query "Stacks[0].Outputs[?OutputKey=='PublicSubnet1'].OutputValue" \
    --output text)
export PUBLIC_SUBNET_ID2=$(aws cloudformation describe-stacks \
    --region ${REGION} \
    --stack-name ${VPC_NAME} \
    --query "Stacks[0].Outputs[?OutputKey=='PublicSubnet2'].OutputValue" \
    --output text)
export PRIVATE_SUBNET_ID1=$(aws cloudformation describe-stacks \
    --region ${REGION} \
    --stack-name ${VPC_NAME} \
    --query "Stacks[0].Outputs[?OutputKey=='PrivateSubnet1'].OutputValue" \
    --output text)
export DEFAULT_SG_ID=$(aws cloudformation describe-stacks \
    --region ${REGION} \
    --stack-name ${VPC_NAME} \
    --query "Stacks[0].Outputs[?OutputKey=='DefaultSecurityGroup'].OutputValue" \
    --output text)

Provision the resources and deploy TGI.

aws cloudformation deploy \
    --region ${REGION} \
    --capabilities CAPABILITY_IAM \
    --template-file ecs.yaml \
    --stack-name ${ECS_CLUSTER_NAME} \
    --parameter-overrides \
        ClusterName=${ECS_CLUSTER_NAME} \
        VpcId=${VPC_ID} \
        PublicSubnetId1=${PUBLIC_SUBNET_ID1} \
        PublicSubnetId2=${PUBLIC_SUBNET_ID2} \
        PrivateSubnetId1=${PRIVATE_SUBNET_ID1} \
        DefaultSecurityGroupId=${DEFAULT_SG_ID} \
        BucketName=${BUCKET_NAME} \
        ModelName=/s3/models/tanuki-8b-sft
aws cloudformation wait stack-create-complete \
    --region ${REGION} \
    --stack-name ${ECS_CLUSTER_NAME}

Calling the API

Retrieve the ALB's DNS name and set it as an environment variable. A dummy value is used for the API key.

dns_name=$(aws cloudformation describe-stacks \
    --region ${REGION} \
    --stack-name ${ECS_CLUSTER_NAME} \
    --query "Stacks[0].Outputs[?OutputKey=='DNSName'].OutputValue" \
    --output text)
export OPENAI_BASE_URL=http://${dns_name}/v1
export OPENAI_API_KEY=dummy

Call the API.

curl ${OPENAI_BASE_URL}/chat/completions \
    -s \
    -X POST \
    -H 'Content-Type: application/json' \
    -d '{
  "model": "tgi",
  "messages": [
    {"role": "system", "content": "あなたは親切なアシスタントです。"},
    {"role": "user", "content": "こんにちは。おもしろい話をしてください。"}
  ],
  "max_tokens": 1024,
  "temperature": 0.3,
  "top_p": 0.3,
  "stream": true
}' \
| sed 's/data://' \
| sed 's/\[DONE\]//' \
| jq --stream -j --unbuffered 'fromstream(1|truncate_stream(inputs))[0].delta.content // ""'

You can also use the OpenAI client library.

from openai import OpenAI

client = OpenAI()

stream = client.chat.completions.create(
    model="tgi",
    messages=[
        {"role": "system", "content": "あなたは親切なアシスタントです。" },
        {"role": "user", "content": "こんにちは。怖い話をしてください。"}
    ],
    max_tokens=1024,
    temperature=0.3,
    top_p=0.3,
    stream=True
)
for chunk in stream:
    print(chunk.choices[0].delta.content or "", end="")

Deleting the resources

Delete the resources.

aws cloudformation delete-stack --region ${REGION} --stack-name ${PCLUSTER_NAME}
aws cloudformation delete-stack --region ${REGION} --stack-name ${ECS_CLUSTER_NAME}
aws cloudformation delete-stack --region ${REGION} --stack-name ${VPC_NAME}

Note that the bucket must be emptied before the S3 bucket's stack can be deleted.

aws cloudformation delete-stack --region ${REGION} --stack-name ${BUCKET_NAME}
Footnotes
  1. Two availability zones are created here, but ParallelCluster and ECS only use AZ1. Two AZs are created because an ALB requires at least two.

KARAKURI Techblog
