Getting Started with Distributed Data Parallel in PyTorch: A Beginner's Guide

Learn Multi GPU Training with DDP: Step by Step Tutorial and Tips for Deep Learning Scaling

Posted by JacksonCakes on August 20, 2023

Introduction

With the launch of cutting-edge models like ChatGPT, the world has been witnessing a remarkable shift towards the development of Large Language Models (LLMs). Take, for example, Meta’s recent release of LLama 2, an open-source LLM with a staggering 70 billion parameters. Now, even if you were to harness the most powerful GPU available today, the A100 with an 80GB memory capacity, this gigantic model simply wouldn’t fit. Given a 16-bit floating-point precision (fp16), the computation stands at:

$70{,}000{,}000{,}000 \times 2 \text{ bytes} = 140 \text{ GB}$

That means it would require 140 GB of memory just to fit this model on a GPU – mind-blowing! So how do industry giants like Meta manage to train models of such size? The magic word here is distributed training.
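
As a quick sanity check, here is the same back-of-the-envelope calculation in Python (a minimal sketch; the only inputs are the parameter count and 2 bytes per fp16 value):

params = 70_000_000_000        # LLama 2 70B parameter count
bytes_per_param = 2            # fp16 stores each parameter in 2 bytes
memory_gb = params * bytes_per_param / 1e9
print(f"{memory_gb:.0f} GB just to hold the weights")  # -> 140 GB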

The concept of distributed training is elegantly straightforward: if a model can’t fit into a single GPU, why not divide it across multiple GPUs? However, the application of distributed training extends beyond just accommodating large models. Often, even when the model itself fits within GPU memory, we train it on data in batches. A higher batch size can improve training efficiency but also increases memory consumption. We can overcome this memory bottleneck by splitting a large batch into smaller ones that fit within a single GPU’s capacity and then training across a cluster of GPUs in parallel – for instance, a global batch of 256 can be processed as eight per-GPU batches of 32. When it comes to distributed training, parallelism plays a key role, and there are two primary methods to consider:

Model Parallelism:

  • Used when the model is so large that it cannot fit into a single GPU, causing GPU memory to be the bottleneck.
  • Involves dividing different parts of the model across multiple GPUs.
  • This method can be challenging to implement, as it requires careful synchronization and communication between the different sections of the model (it is also tailored specifically to each model architecture), making it a complex solution.

Data Parallelism (MAIN):

  • Used when GPU compute is the constraint, meaning that the model fits into a single GPU, but training is too slow. A way to speed up the training is to increase the batch size, allowing the model to process more data simultaneously. While a larger batch size can lead to increased memory consumption, we could use data parallelism to distribute the batches of data across multiple GPUs.
  • Alternatively, we could decrease the batch size so that the model fits into memory, and then apply data parallelism.
  • Involves dividing the training data into batches and distributing these batches across several GPUs.
  • Each GPU maintains a full copy of the model and calculates the gradients using its particular subset of the data.
  • The gradients are then averaged across all GPUs and applied to update the model.
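
To make the last two bullets concrete, here is a rough sketch of the gradient averaging that data parallelism relies on. DDP performs this reduction for you automatically (and overlaps it with the backward pass), so you never write this by hand; it is shown only to illustrate the idea:

import torch
import torch.distributed as dist

def average_gradients(model: torch.nn.Module) -> None:
    """Manually average gradients across all processes (illustration only)."""
    world_size = dist.get_world_size()
    for param in model.parameters():
        if param.grad is not None:
            # Sum this parameter's gradient over every GPU, then divide,
            # so every replica applies exactly the same update.
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
            param.grad /= world_size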

In this blog, our primary focus will be on data parallelism, as it’s often more relevant to typical industry applications and is generally easier to implement. Model parallelism, though an essential concept, is less commonly used, particularly since many real-world models are not that large. We will come back to this complex topic in the future.

Data Parallel (DP) vs Distributed Data Parallel (DDP)

When we talk about data parallelism, there are two methods we can use: Data Parallel (DP) and Distributed Data Parallel (DDP).

DP:

  • Single-process, multi-thread, and only works on a single machine (but multi-gpu).
  • Utilizing threads instead of processes. Threads vs Processes
  • Can be slower than DDP even on a single machine due to Python's GIL, which prevents threads from executing in true parallel.
  • Simple to implement. You could simply wrap your model with:
    
    model = NeuralNetwork()
    # make it DP
    model = nn.DataParallel(model)
    

DDP:

  • Multi-process and works on both single and multi-node GPU.
  • The model is broadcast once at DDP construction time instead of in every forward pass (as DP does), which also helps speed up training.
  • More complex to set up, especially across multiple nodes.

How to Start DDP with PyTorch?

Before diving into an example of how to convert a standard PyTorch training script to Distributed Data Parallel (DDP), it’s essential to understand a few key concepts:

  • World Size: This refers to the total number of processes in the distributed group. In the context of DDP, it represents the total count of GPUs across all machines.
  • Rank: Each process in distributed training has a unique identifier called a rank. Rank 0 is typically considered the “master”, responsible for coordinating the others. The global rank runs from 0 up to the total number of GPUs minus one and is unique across all machines. The local rank is the GPU index within a single machine; for example, it runs from 0 to 3 on a four-GPU machine, and again from 0 to 3 on a second four-GPU machine.
  • Backend: PyTorch supports various backends for communication between processes, such as NCCL and Gloo. As a rule of thumb, NCCL is used for GPU training and Gloo for CPU training.
  • Data Sampler: Utilizing a distributed sampler like torch.utils.data.distributed.DistributedSampler ensures that the dataset is partitioned across the GPUs, so that different GPUs do not process identical portions of the data.
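
To see how these concepts map onto code, here is a small sketch (the numbers in the comments assume two machines with four GPUs each; the calls are only valid after the process group has been initialized, as shown in Step 2 below):

import os
import torch.distributed as dist

world_size = dist.get_world_size()          # e.g. 8 = 2 machines x 4 GPUs
global_rank = dist.get_rank()               # 0..7, unique across all machines
local_rank = int(os.environ["LOCAL_RANK"])  # 0..3, unique within one machine
is_master = global_rank == 0                # rank 0 coordinates logging/checkpointing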

Recipe for Converting Your Script into a Distributed Training Script

Step 1: Import required libraries:

import os
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group

Step 2: Initialize a process group and set the current GPU device (so that each process knows which GPU to use):

init_process_group(backend="nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

Step 3: Get the local rank through the environment variable:

local_rank = int(os.environ["LOCAL_RANK"])

Step 4: Set your model, source, and target to the local rank:

# Normal training script 
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
source = source.to(device)
target = target.to(device)

# DDP
model = model.to(local_rank)  # use local_rank from the previous step
source = source.to(local_rank)
target = target.to(local_rank)

Step 5: Wrap your model with PyTorch DDP:

self.model = DDP(self.model, device_ids=[local_rank])
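
Note that the DDP wrapper exposes the original model through its .module attribute; this is why the checkpoint-saving code later in this post calls self.model.module.state_dict() rather than self.model.state_dict(). A short illustration:

# The underlying, unwrapped model is reachable via .module,
# e.g. when saving a checkpoint:
state_dict = self.model.module.state_dict()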

Step 6: Ensure that the data processed by each GPU is unique:

train_set = MyTrainDataset(32000)  # load your dataset
train_sampler = DistributedSampler(train_set)
train_loader = DataLoader(
    train_set,
    batch_size=batch_size,
    pin_memory=True,
    shuffle=False,  # shuffling is handled by the DistributedSampler
    sampler=train_sampler
)

# In your training loop
def train(max_epochs: int):
    for epoch in range(epochs_run, max_epochs):  # epochs_run > 0 when resuming from a snapshot
        # ensure the data is split and shuffled differently among the workers at each new epoch
        train_sampler.set_epoch(epoch)
        ...

Additional Utility Functions

Saving Model Checkpoints: To prevent loss of progress if training fails, you can save a model checkpoint every X epochs. Note: save checkpoints only when local_rank == 0, so that the other GPUs do not write redundant copies.

def _save_snapshot(self, epoch):
	snapshot = {
		"MODEL_STATE": self.model.module.state_dict(),
		"EPOCHS_RUN": epoch,
	}
	torch.save(snapshot, self.snapshot_path)
	print(f"Epoch {epoch} | Training snapshot saved at {self.snapshot_path}")

if self.local_rank == 0:
    self._save_snapshot(epoch)

Resuming Training from Checkpoints: You can resume training from a saved checkpoint.

def _load_snapshot(self, snapshot_path):
        loc = f"cuda:{self.local_rank}"
        snapshot = torch.load(snapshot_path, map_location=loc)
        self.model.load_state_dict(snapshot["MODEL_STATE"])
        self.epochs_run = snapshot["EPOCHS_RUN"]
        print(f"Resuming training from snapshot at Epoch {self.epochs_run}")

Now you have your model wrapped in DDP, ready for efficient distributed training. Here is the full example training script.

import os
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group

def ddp_setup():
    init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

class MyTrainDataset(Dataset):
    def __init__(self, size):
        self.size = size
        self.data = [(torch.rand(20), torch.rand(1)) for _ in range(size)]

    def __len__(self):
        return self.size
    
    def __getitem__(self, index):
        return self.data[index]
        
class Trainer:
    def __init__(
        self,
        model: torch.nn.Module,
        train_data: DataLoader,
        optimizer: torch.optim.Optimizer,
        save_every: int,
        snapshot_path: str,
    ) -> None:
        self.local_rank = int(os.environ["LOCAL_RANK"])
        self.global_rank = int(os.environ["RANK"])
        self.model = model.to(self.local_rank)
        self.train_data = train_data
        self.optimizer = optimizer
        self.save_every = save_every
        self.epochs_run = 0
        self.snapshot_path = snapshot_path
        if os.path.exists(snapshot_path):
            print("Loading snapshot")
            self._load_snapshot(snapshot_path)

        self.model = DDP(self.model, device_ids=[self.local_rank])

    def _load_snapshot(self, snapshot_path):
        loc = f"cuda:{self.local_rank}"
        snapshot = torch.load(snapshot_path, map_location=loc)
        self.model.load_state_dict(snapshot["MODEL_STATE"])
        self.epochs_run = snapshot["EPOCHS_RUN"]
        print(f"Resuming training from snapshot at Epoch {self.epochs_run}")

    def _run_batch(self, source, targets):
        self.optimizer.zero_grad()
        output = self.model(source)
        loss = F.cross_entropy(output, targets)
        loss.backward()
        self.optimizer.step()

    def _run_epoch(self, epoch):
        b_sz = len(next(iter(self.train_data))[0])
        print(f"[GPU{self.global_rank}] Epoch {epoch} | Batchsize: {b_sz} | Steps: {len(self.train_data)}")
        self.train_data.sampler.set_epoch(epoch)
        for source, targets in self.train_data:
            source = source.to(self.local_rank)
            targets = targets.to(self.local_rank)
            self._run_batch(source, targets)

    def _save_snapshot(self, epoch):
        snapshot = {
            "MODEL_STATE": self.model.module.state_dict(),
            "EPOCHS_RUN": epoch,
        }
        torch.save(snapshot, self.snapshot_path)
        print(f"Epoch {epoch} | Training snapshot saved at {self.snapshot_path}")

    def train(self, max_epochs: int):
        for epoch in range(self.epochs_run, max_epochs):
            self._run_epoch(epoch)
            if self.local_rank == 0 and epoch % self.save_every == 0:
                self._save_snapshot(epoch)


def load_train_objs():
    train_set = MyTrainDataset(32000)  # load your dataset
    model = torch.nn.Linear(20, 1)  # load your model
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    return train_set, model, optimizer


def prepare_dataloader(dataset: Dataset, batch_size: int):
    return DataLoader(
        dataset,
        batch_size=batch_size,
        pin_memory=True,
        shuffle=False,
        sampler=DistributedSampler(dataset)
    )

def main(save_every: int, total_epochs: int, batch_size: int, snapshot_path: str = "snapshot.pt"):
    ddp_setup()
    dataset, model, optimizer = load_train_objs()
    train_data = prepare_dataloader(dataset, batch_size)
    trainer = Trainer(model, train_data, optimizer, save_every, snapshot_path)
    trainer.train(total_epochs)
    destroy_process_group()


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(description='simple distributed training job')
    parser.add_argument('total_epochs', type=int, help='Total epochs to train the model')
    parser.add_argument('save_every', type=int, help='How often to save a snapshot')
    parser.add_argument('--batch_size', default=32, type=int, help='Input batch size on each device (default: 32)')
    args = parser.parse_args()
    
    main(args.save_every, args.total_epochs, args.batch_size)

To execute the script, PyTorch provides a simple command-line utility, torchrun, to launch distributed training across various local or remote environments.

Single Node, Multi-GPU

If you have multiple GPUs on a single machine and want to run your training script across all of them, you can use torchrun with the --nproc_per_node option. For example, to run on 4 GPUs on the local machine, you would use:

torchrun --nproc_per_node=4 ddp_script.py 1000 10

This would launch 4 instances of ddp_script.py, one per GPU (the positional arguments 1000 and 10 are the script's total_epochs and save_every).
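
If you want to verify what torchrun injects into each process, a minimal sanity-check script (the file name sanity_check.py is just an assumption for this example) can print the environment variables that the training script relies on:

# sanity_check.py (hypothetical helper, not part of the training script)
import os

# torchrun sets these variables for every process it launches
print(
    f"RANK={os.environ['RANK']} "
    f"LOCAL_RANK={os.environ['LOCAL_RANK']} "
    f"WORLD_SIZE={os.environ['WORLD_SIZE']}"
)

Running torchrun --nproc_per_node=4 sanity_check.py should print four lines, with LOCAL_RANK going from 0 to 3 and WORLD_SIZE equal to 4.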

Multi-Node, Multi-GPU

For multi-node, multi-GPU training, you need to specify the number of nodes (machines) and the number of GPUs per node. Here's an example. On the master node (the main machine):

torchrun --nproc_per_node=4 \
         --nnodes=2 \
         --node_rank=0 \
         --master_addr="192.168.1.2" \
         --master_port=1234 \
         ddp_script.py 1000 10

On the other node (note the different node_rank):

torchrun --nproc_per_node=4 \
         --nnodes=2 \
         --node_rank=1 \
         --master_addr="192.168.1.2" \
         --master_port=1234 \
         ddp_script.py 1000 10

Here’s what the parameters mean:

  • --nproc_per_node=4: Number of GPUs per node (e.g., 4 GPUs on each machine).
  • --nnodes=2: Total number of nodes (machines).
  • --node_rank=0: Rank of the current node. You would run this command on each node, changing the node_rank accordingly (e.g., 0 on the first machine, 1 on the second machine, etc.).
  • --master_addr="192.168.1.2": IP address of the master node.
  • --master_port=1234: Port to communicate with the master node.

Great, now you can launch your own distributed training!

Troubleshooting

  • Ensure all the machines can communicate with each other, e.g. with ping <MACHINE_IP> or nc <IP_ADDRESS_OF_FIRST_NODE> <PORT>.
  • If you are using NCCL as the backend, ensure that each machine uses the correct network interface. First, check each machine's network interfaces using:

    ifconfig

    It will output something like:

docker0: flags=4163<UP,BROADCAST,RUNNING,MULTICAST>  mtu 1500
        inet 172.17.0.1  netmask 255.255.0.0  broadcast 172.17.255.255
        ...

ens5: flags=4163<UP,BROADCAST,RUNNING,MULTICAST>  mtu 9001
        inet 192.168.1.2  netmask 255.255.224.0  broadcast 172.31.95.255
        ...

lo: flags=73<UP,LOOPBACK,RUNNING>  mtu 65536
        inet 127.0.0.1  netmask 255.0.0.0
        ...

veth526c8fe: flags=4163<UP,BROADCAST,RUNNING,MULTICAST>  mtu 1500
        inet6 fe80::44c:7bff:fe80:f02b  prefixlen 64  scopeid 0x20<link>
        ...

Select the interface with your IP address, which here is ens5. Now set the network interface name for this machine via:

export NCCL_SOCKET_IFNAME=ens5

Repeat the same process for all the machines.