
Multi-node training with accelerate


Introduction

What is accelerate?

Hugging Face accelerate is a library that lets the same PyTorch training code run on a single GPU, multiple GPUs, or multiple nodes with only minimal code changes, handling device placement and distributed launching for you.

Execution environment

| machine | Ubuntu | GPU 1 | GPU 2 | CUDA / driver | IP |
| --- | --- | --- | --- | --- | --- |
| main | 24.04 | TITAN X (Pascal) 12 GB | TITAN X (Pascal) 12 GB | CUDA 12.6 (nvidia-smi), driver 560.35.03 | 10.0.0.3 |
| sub | 24.04 | Quadro RTX 8000 48 GB (7.5) | Quadro RTX 8000 48 GB (7.5) | CUDA 12.6 (nvidia-smi), driver 560.35.03 | 10.0.0.6 |

Environment setup

  • Create an SSH key
    • Run on the main machine
    • ssh-keygen -t rsa -f ./acc_key
      • This generates acc_key and acc_key.pub, which the compose files below mount into the containers
  • docker image
    • Build this image on both machines under the name accelerate_img (e.g. docker build -t accelerate_img .), since the compose files below refer to it
    accelerate_img
        # Base PyTorch / CUDA image
        FROM pytorch/pytorch:2.3.0-cuda12.1-cudnn8-devel
        
        # Install the SSH server and sudo
        RUN apt-get update && apt-get install -y openssh-server sudo
        
        # Settings needed to accept SSH connections
        RUN mkdir /var/run/sshd
        # Allow SSH login as the root user
        RUN sed -i 's/#PermitRootLogin prohibit-password/PermitRootLogin yes/' /etc/ssh/sshd_config
        
        # Expose the SSH port
        EXPOSE 22
        
        # Install the required Python libraries
        RUN pip install transformers accelerate pandas numpy
        
        # Start the SSH server when the container starts
        CMD ["/usr/sbin/sshd", "-D"]
    
  • docker compose
    • network_mode: "host"
      • Without this setting, the container uses Docker's internal network
      • We want the parallel training to run over the host machine's network, so we add it
    • ipc: host
      • The shared memory inside a Docker container is often too small, so we make the container use the host machine's shared memory; otherwise errors like the following can occur
          [rank1]: Error while creating shared memory segment /dev/shm/nccl-n0KBY8 (size 9637888)
        
  • Run on the main machine
    • docker compose -f docker-compose_main.yaml up -d
      
      docker-compose_main.yaml
      services:
          deepspeed_test:
            image: accelerate_img
            container_name: acc_main
            
            volumes:
              - .:/acc_test
              - ./acc_key:/root/.ssh/id_rsa
              - ./acc_key.pub:/root/.ssh/authorized_keys
            
            network_mode: "host"
            ipc: host
            
            deploy:
              resources:
                reservations:
                  devices:
                    - driver: nvidia
                      device_ids: ['0', '1']
                      capabilities: [gpu]
            
            environment:
                TZ: Asia/Tokyo
                
            tty: true
            command: sleep infinity
      
  • Run on the sub machine
    • docker compose -f docker-compose_sub.yaml up -d
      
      docker-compose_sub.yaml
      services:
          deepspeed_test:
            image: accelerate_img
            container_name: acc_sub
            
            volumes:
              - .:/acc_test
              - ./acc_key.pub:/root/.ssh/authorized_keys
        
            network_mode: "host"
            ipc: host
            
            deploy:
              resources:
                reservations:
                  devices:
                    - driver: nvidia
                      device_ids: ['0', '1']
                      capabilities: [gpu]
            
            environment:
                TZ: Asia/Tokyo
                
            tty: true
            command: sleep infinity
      
  • Run the following bash script on each machine
    py_librarly.sh
      pip install transformers==4.53.0
      pip install fugashi
      pip install unidic_lite
      pip install pandas
      pip install torch==2.7.1
      # torchvision 0.22.1 is the release that matches torch 2.7.1
      pip install torchvision==0.22.1
    
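Before moving on, it is worth a quick sanity check that the containers see the GPUs and the host's shared memory, and that the two machines can reach each other over the 10.0.0.x network. A minimal sketch from the main machine, assuming the container names from the compose files above:

    # Are both GPUs visible inside the container?
    docker exec acc_main nvidia-smi
    # Is ipc: host in effect? /dev/shm should show the host's size, not Docker's 64 MB default
    docker exec acc_main df -h /dev/shm
    # Is the sub node reachable over the private network? (run on the host)
    ping -c 3 10.0.0.6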

Program

classifier_acceraleter_BERT.py

import torch
import pandas as pd
import numpy as np
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from torch.optim import AdamW
from torch.utils.data import TensorDataset, random_split, DataLoader, RandomSampler, SequentialSampler
from accelerate import Accelerator
from torch.utils.data.distributed import DistributedSampler

def main():
    # 1. Initialize the Accelerator
    accelerator = Accelerator()

    MODEL_NAME = 'tohoku-nlp/bert-base-japanese-whole-word-masking'
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    model = AutoModelForSequenceClassification.from_pretrained(
        MODEL_NAME,
        torch_dtype="auto",
        num_labels=2,
        output_attentions=False,
        output_hidden_states=False,
        use_safetensors=True
    )
    
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        model.config.pad_token_id = model.config.eos_token_id


    try:
        df = pd.read_csv("all_text.tsv", delimiter='\t', header=None, names=['media_name', 'label', 'NaN', 'sentence'])
    except FileNotFoundError:
        if accelerator.is_main_process:
            print("Error: all_text.tsv not found. Please check the file path.")
        exit()

    df = df.dropna(subset=['sentence', 'label'])
    df['label'] = pd.to_numeric(df['label'], errors='coerce')
    df = df.dropna(subset=['label'])
    df['label'] = df['label'].astype(int)

    sentences = df.sentence.values
    labels = df.label.values

    input_ids = []
    attention_masks = []
    MAX_LEN = 128

    for sent in sentences:
        encoded_dict = tokenizer.encode_plus(
            sent,
            add_special_tokens=True,
            max_length=MAX_LEN,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )
        input_ids.append(encoded_dict['input_ids'])
        attention_masks.append(encoded_dict['attention_mask'])

    input_ids = torch.cat(input_ids, dim=0)
    attention_masks = torch.cat(attention_masks, dim=0)
    labels = torch.tensor(labels)

    dataset = TensorDataset(input_ids, attention_masks, labels)

    train_size = int(0.9 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

    if accelerator.is_main_process:
        print(f'Number of training data: {train_size}')
        print(f'Number of validation data: {val_size}')

    batch_size = 124

    train_dataloader = DataLoader(
        train_dataset,
        sampler=RandomSampler(train_dataset),
        batch_size=batch_size
    )
    
    validation_dataloader = DataLoader(
        val_dataset,
        sampler=SequentialSampler(val_dataset),
        batch_size=batch_size
    )
    print(f'Number of batches in training dataloader: {len(train_dataloader)}')
    print(f'Number of batches in validation dataloader: {len(validation_dataloader)}')
    optimizer = AdamW(model.parameters(), lr=2e-5)
    print(f'Wrapping starts...')

    # 2. Prepare the model, optimizer, and dataloaders
    model, optimizer, train_dataloader, validation_dataloader = accelerator.prepare(
        model, optimizer, train_dataloader, validation_dataloader
    )
    print(f'Device to use: {accelerator.device}')
    max_epoch = 1

    if accelerator.is_main_process:
        print("\nStarting training...")

    for epoch in range(max_epoch):
        print(f"Starting epoch {epoch + 1}")
        # --- Training loop ---
        model.train()
        total_train_loss = 0
        for i, batch in enumerate(train_dataloader):
            print(f'Processing batch {i + 1}/{len(train_dataloader)}...')
            optimizer.zero_grad()
            outputs = model(
                input_ids=batch[0],
                attention_mask=batch[1],
                labels=batch[2]
            )
            loss = outputs.loss
            total_train_loss += loss.item()
            
            # Use accelerator.backward(loss)
            accelerator.backward(loss)
            
            optimizer.step()
        
        avg_train_loss = total_train_loss / len(train_dataloader)

        # --- Validation loop ---
        model.eval()
        total_val_loss = 0
        if len(validation_dataloader) > 0:
            with torch.no_grad():
                for batch in validation_dataloader:
                    outputs = model(
                        input_ids=batch[0],
                        attention_mask=batch[1],
                        labels=batch[2]
                    )
                    loss = outputs.loss
                    total_val_loss += loss.item()
        
        avg_val_loss = total_val_loss / len(validation_dataloader) if len(validation_dataloader) > 0 else 0

        # Display logs only on the main process
        if accelerator.is_main_process:
            print(f'\nEpoch {epoch + 1}/{max_epoch}')
            print(f'  Train Loss: {avg_train_loss:.4f}')
            if avg_val_loss > 0:
                print(f'  Valid Loss: {avg_val_loss:.4f}')

    if accelerator.is_main_process:
        print("\nTraining complete.")
    
    # The prediction part below needs adjustment depending on the implementation, 
    # but it is a basic conversion example.
    # To get accurate accuracy, you need to aggregate the results of all processes with accelerator.gather().
    if val_size > 0 and accelerator.is_main_process:
        model.eval()
        print("\nRunning predictions on validation data...")
        for batch in validation_dataloader:
            with torch.no_grad():
                outputs = model(
                    input_ids=batch[0],
                    attention_mask=batch[1]
                )
            
            logits = outputs.logits
            logits_df = pd.DataFrame(logits.cpu().numpy(), columns=['logit_0', 'logit_1'])
            pred_labels = np.argmax(logits.cpu().numpy(), axis=1)
            pred_df = pd.DataFrame(pred_labels, columns=['pred_label'])
            label_df = pd.DataFrame(batch[2].cpu().numpy(), columns=['true_label'])
            accuracy_df = pd.concat([logits_df, pred_df, label_df], axis=1)
            print("\nSample prediction results for a validation batch:")
            print(accuracy_df.head())
            accuracy = (accuracy_df['pred_label'] == accuracy_df['true_label']).mean()
            print(f"\nAccuracy for this batch: {accuracy:.4f}")
            break  # Evaluate only on the first batch


if __name__ == '__main__':
    main()
  • If you do not use a config file, you can instead pass the options directly to accelerate launch (see the sketch below)
    • num_processes
      • How many GPUs (processes) to use
    • multi_gpu
      • Required when training on more than one GPU
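
A minimal sketch of such a config-less launch, assuming the two GPUs of a single machine as in the setup above (--multi_gpu and --num_processes are standard accelerate launch flags):

    # Launch on two GPUs of this machine without a saved config
    accelerate launch --multi_gpu --num_processes 2 classifier_acceraleter_BERT.py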

1node2GPU

  • Create the accelerate config
    Creating the accelerate config
    # accelerate config
      ----------------------------------------------------------------------------------------------------------------------------------------
    In which compute environment are you running?
      This machine
      ----------------------------------------------------------------------------------------------------------------------------------------
    Which type of machine are you using?
      multi-GPU
      How many different machines will you use (use more than 1 for multi-node training)? [1]: 1
      Should distributed operations be checked while running for errors? This can avoid timeout issues but will be slower. [yes/NO]: NO
      Do you wish to optimize your script with torch dynamo?[yes/NO]:NO
      Do you want to use DeepSpeed? [yes/NO]: NO
      Do you want to use FullyShardedDataParallel? [yes/NO]: NO
      Do you want to use Megatron-LM ? [yes/NO]: NO
      How many GPU(s) should be used for distributed training? [1]:2
      What GPU(s) (by id) should be used for training on this machine as a comma-separated list? [all]:all
      Would you like to enable numa efficiency? (Currently only supported on NVIDIA hardware). [yes/NO]: NO
      ----------------------------------------------------------------------------------------------------------------------------------------
    Do you wish to use mixed precision?
      no
      accelerate configuration saved at /root/.cache/huggingface/accelerate/default_config.yaml
    
default_config.yaml

compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
downcast_bf16: 'no'
enable_cpu_affinity: false
gpu_ids: all
machine_rank: 0
main_training_function: main
mixed_precision: 'no'
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
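
Before launching, you can double-check what accelerate launch will pick up; the accelerate CLI can print the environment together with the saved default config:

    # Show the accelerate environment and the default config that will be used
    accelerate env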

Execution results
# accelerate launch classifier_acceraleter_BERT.py
ipex flag is deprecated, will be removed in Accelerate v1.10. From 2.7.0, PyTorch has all needed optimizations for Intel CPU and XPU.
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at tohoku-nlp/bert-base-japanese-whole-word-masking and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at tohoku-nlp/bert-base-japanese-whole-word-masking and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Number of training data: 1593
Number of validation data: 177
Number of batches in training dataloader: 13
Number of batches in validation dataloader: 2
Wrapping starts...
Number of batches in training dataloader: 13
Number of batches in validation dataloader: 2
Wrapping starts...
Device to use: cuda:0

Starting training...
Starting epoch 1
Device to use: cuda:1
Starting epoch 1
Processing batch 1/7...
Processing batch 1/7...
Processing batch 2/7...
Processing batch 2/7...
Processing batch 3/7...
Processing batch 3/7...
Processing batch 4/7...
Processing batch 4/7...
Processing batch 5/7...
Processing batch 5/7...
Processing batch 6/7...
Processing batch 6/7...
Processing batch 7/7...
Processing batch 7/7...

Epoch 1/1
  Train Loss: 0.5767
  Valid Loss: 0.3585

Training complete.

Running predictions on validation data...

Sample prediction results for a validation batch:
    logit_0   logit_1  pred_label  true_label
0 -0.947559  0.437343           1           1
1  0.163911 -0.242660           0           0
2 -0.900511  0.311918           1           1
3 -1.229280  0.196599           1           1
4  0.801448 -0.742624           0           0

Accuracy for this batch: 0.9597
[rank0]:[W821 08:41:59.691866964 ProcessGroupNCCL.cpp:1250] Warning: WARNING: process group has NOT been destroyed before we destruct ProcessGroupNCCL. On normal program exit, the application should call destroy_process_group to ensure that any pending NCCL operations have finished in this process. In rare cases this process can exit before this point and block the progress of another member of the process group. This constraint has always been present,  but this warning has only been added since PyTorch 2.4 (function operator())

2node4GPU

When running on two nodes, the command has to be executed on both machines.

  • On the main machine: machine_rank: 0
  • On the sub machine: machine_rank: 1
  • num_machines
    • The number of nodes to use
  • num_processes
    • The total number of GPUs across all nodes
    • These values can also be passed as command-line flags; see the sketch at the end of this section
  • The main machine
    • Create the config file
      configfile
        accelerate config
            --------------------------------------------------------------------------------------------------------
          In which compute environment are you running?
            This machine
            --------------------------------------------------------------------------------------------------------
          Which type of machine are you using?                                                                  
            multi-GPU
            How many different machines will you use (use more than 1 for multi-node training)? [1]: 2
            --------------------------------------------------------------------------------------------------------
      What is the rank of this machine?                                                                     
        0
        What is the IP address of the machine that will host the main process? 10.0.0.3
        What is the port you will use to communicate with the main process? 29500
        Are all the machines on the same local network? Answer `no` if nodes are on the cloud and/or on different network hosts [YES/no]: YES                                                                         
        Should distributed operations be checked while running for errors? This can avoid timeout issues but will be slower. [yes/NO]: NO                                                                             
        Do you wish to optimize your script with torch dynamo?[yes/NO]:NO
        Do you want to use DeepSpeed? [yes/NO]: NO
        Do you want to use FullyShardedDataParallel? [yes/NO]: NO
        Do you want to use Megatron-LM ? [yes/NO]: NO
        How many GPU(s) should be used for distributed training? [1]:4
        What GPU(s) (by id) should be used for training on this machine as a comma-separated list? [all]:all
        Would you like to enable numa efficiency? (Currently only supported on NVIDIA hardware). [yes/NO]: NO
        --------------------------------------------------------------------------------------------------------
      Do you wish to use mixed precision?
        no
        accelerate configuration saved at /root/.cache/huggingface/accelerate/default_config.yaml
      
      default_config.yaml

      compute_environment: LOCAL_MACHINE
      debug: false
      distributed_type: MULTI_GPU
      downcast_bf16: 'no'
      enable_cpu_affinity: false
      gpu_ids: all
      machine_rank: 0
      main_process_ip: 10.0.0.3
      main_process_port: 29500
      main_training_function: main
      mixed_precision: 'no'
      num_machines: 2
      num_processes: 4
      rdzv_backend: static
      same_network: true
      tpu_env: []
      tpu_use_cluster: false
      tpu_use_sudo: false
      use_cpu: false

  • The sub machine
    • The sub machine's accelerate config only changes machine_rank
      default_config.yaml on the sub machine

      compute_environment: LOCAL_MACHINE
      debug: false
      distributed_type: MULTI_GPU
      downcast_bf16: 'no'
      enable_cpu_affinity: false
      gpu_ids: all
      machine_rank: 1
      main_process_ip: 10.0.0.3
      main_process_port: 29500
      main_training_function: main
      mixed_precision: 'no'
      num_machines: 2
      num_processes: 4
      rdzv_backend: static
      same_network: true
      tpu_env: []
      tpu_use_cluster: false
      tpu_use_sudo: false
      use_cpu: false

  • Running the program
    • Run the command on both the main and sub machines
    • The network interface that NCCL should use has to be specified explicitly via an environment variable
    • In this case it is eno1 on the main machine and enp4s0 on the sub machine
      • The 10.0.0.x network runs over these interfaces
      ip a on main

      2: eno1: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1500 qdisc mq state UP group default qlen 1000
          link/ether 0c:c4:7a:99:e1:2c brd ff:ff:ff:ff:ff:ff
          altname enp5s0
          inet 10.0.0.3/24 brd 10.0.0.255 scope global eno1
             valid_lft forever preferred_lft forever
          inet6 fe80::ec4:7aff:fe99:e12c/64 scope link
             valid_lft forever preferred_lft forever

      ip a on sub

      2: enp4s0: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1500 qdisc fq_codel state UP group default qlen 1000
          link/ether f8:32:e4:9d:47:87 brd ff:ff:ff:ff:ff:ff
          inet 10.0.0.6/24 brd 10.0.0.255 scope global enp4s0
             valid_lft forever preferred_lft forever
          inet6 fe80::fa32:e4ff:fe9d:4787/64 scope link
             valid_lft forever preferred_lft forever

    • Prefix the command with NCCL_SOCKET_IFNAME=<interface name>
    • Run on main
      main
      NCCL_SOCKET_IFNAME=eno1 \
      accelerate \
      launch \
      classifier_acceraleter_BERT.py
      
    • Run on sub
      sub
      NCCL_SOCKET_IFNAME=enp4s0 \
      accelerate \
      launch \
      classifier_acceraleter_BERT.py
      
  • It is also possible to specify which accelerate config file to use
    • Run on main
      main
      NCCL_SOCKET_IFNAME=eno1 \
      accelerate \
      launch \
      --config_file ./default_config.yaml \
      classifier_acceraleter_BERT.py
      
    • Run on sub
      sub
      NCCL_SOCKET_IFNAME=enp4s0 \
      accelerate \
      launch \
      --config_file ./default_config.yaml \
      classifier_acceraleter_BERT.py
      
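The values stored in the config file can also be overridden directly on the command line. A sketch of a flag-only launch from the main node, assuming the flags simply mirror the config shown above (on the sub node, --machine_rank would be 1 and NCCL_SOCKET_IFNAME=enp4s0; NCCL_DEBUG=INFO is optional and only adds extra NCCL connection logs):

    # Flag-based multi-node launch from the main node (machine_rank 0)
    NCCL_SOCKET_IFNAME=eno1 NCCL_DEBUG=INFO \
    accelerate launch \
      --multi_gpu \
      --num_machines 2 \
      --num_processes 4 \
      --machine_rank 0 \
      --main_process_ip 10.0.0.3 \
      --main_process_port 29500 \
      classifier_acceraleter_BERT.py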

Conclusion

  • I tried running distributed data-parallel training with accelerate
  • So far data parallelism works, but I could not get model parallelism to run
  • Next I would like to try model parallelism with DeepSpeed or Megatron-LM
