【PyTorch】How to Use Multiple GPUs in PyTorch
This time, I'll write about how to use multiple GPUs in PyTorch.
There are two ways to use multiple GPUs:
- DataParallel
- DistributedDataParallel (DDP)
The details are explained below.
1. Using DataParallel
This is a simpler option and works well for models that fit comfortably in memory on each GPU.
This splits the input batch across the GPUs, runs the forward pass in parallel, and gathers the results back on the main GPU. It is simple to implement, but less efficient because the model is replicated to every GPU on each forward pass and the outputs are gathered on a single (main) GPU.
・How to use
Just write the following before model = model.to(device). That's all.
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)
model = model.to(device)
・Example
import torch
import torch.nn as nn

model = YourModel()  # your own model definition

# Wrap the model with DataParallel
if torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs!")
    model = nn.DataParallel(model)

# Move the model to GPU
model = model.to('cuda')

# Define your loss and optimizer as usual
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Now proceed with your training loop
for data in dataloader:  # your own DataLoader
    inputs, labels = data
    inputs, labels = inputs.to('cuda'), labels.to('cuda')

    # Forward pass
    outputs = model(inputs)
    loss = criterion(outputs, labels)

    # Backward pass and optimization
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
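One practical note: nn.DataParallel wraps your model and keeps the original one in its .module attribute, so when saving a checkpoint it is safer to save the unwrapped state_dict. Below is a minimal sketch; the file name checkpoint.pth is just an example.

# Save the underlying model's weights so the checkpoint can be loaded
# later with or without the DataParallel wrapper.
if isinstance(model, nn.DataParallel):
    torch.save(model.module.state_dict(), "checkpoint.pth")  # example file name
else:
    torch.save(model.state_dict(), "checkpoint.pth")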
This is an easy and useful approach. However, it is less efficient for larger models or when you need more fine-grained control over parallelism.
2. Using DistributedDataParallel (DDP)
DistributedDataParallel is more efficient and is recommended for multi-GPU setups, especially for large-scale training. Although it requires setting up a distributed training environment, it provides better performance because each GPU runs its own process with its own model replica and only gradients are synchronized.
・Example
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler

def setup(rank, world_size):
    # Single-node settings; adjust the address/port for multi-node training
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "12355")
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

def cleanup():
    dist.destroy_process_group()

def train(rank, world_size):
    setup(rank, world_size)

    model = MyModel().to(rank)  # your own model definition
    model = DDP(model, device_ids=[rank])

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    # Each process sees a different shard of the dataset
    train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank)
    train_loader = DataLoader(dataset=train_dataset, sampler=train_sampler, batch_size=32)

    # For multiple epochs, call train_sampler.set_epoch(epoch) at the start of each epoch
    for data, target in train_loader:
        data, target = data.to(rank), target.to(rank)
        output = model(data)
        loss = criterion(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    cleanup()

if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)
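As a side note, instead of spawning the processes yourself with mp.spawn, you can launch this kind of script with torchrun, which sets RANK, LOCAL_RANK, and WORLD_SIZE as environment variables for each process. A minimal sketch of what the entry point could look like in that case (the script name train_ddp.py is just an example):

# Launch from the shell:
#   torchrun --nproc_per_node=4 train_ddp.py
import os
import torch
import torch.distributed as dist

def main():
    # torchrun sets these environment variables for each process
    local_rank = int(os.environ["LOCAL_RANK"])
    dist.init_process_group(backend="nccl")  # rank/world_size are read from the env
    torch.cuda.set_device(local_rank)
    # ... build the model, wrap it with DDP(device_ids=[local_rank]), and train as above ...
    dist.destroy_process_group()

if __name__ == "__main__":
    main()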
3. Summary
This time, I introduced two ways to use multiple GPUs in PyTorch.
For more efficient multi-GPU training, prefer DistributedDataParallel whenever possible.