PyTorch Distributed Data Parallel (DDP) speeds up model training by splitting the training data across multiple identical copies of the model, each running in its own process.
If your model fits on a single GPU but a large training set is making training slow, you can use DDP and request more GPUs to reduce training time. The entire model is replicated on each GPU, and each training process receives a different subset of the training data. During the backward pass, gradients are averaged (all-reduced) across devices, so every copy of the model stays identical.
For a complete overview with video tutorial and examples, see https://pytorch.org/tutorials/beginner/ddp_series_intro.html
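Before getting into environment setup and the individual pieces below, the following minimal sketch shows the overall DDP pattern in one self-contained script. The toy model, dataset, and hyperparameters are placeholders rather than part of the minGPT-ddp example; the script would be launched with torchrun.

# Minimal DDP sketch with a toy model and random data (placeholders).
# Launch with: torchrun --nproc_per_node=<num_gpus> train.py
import os
import torch
from torch.distributed import init_process_group, destroy_process_group
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

def main():
    init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    dataset = TensorDataset(torch.randn(1024, 20), torch.randn(1024, 1))
    model = DDP(torch.nn.Linear(20, 1).to(local_rank), device_ids=[local_rank])
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    sampler = DistributedSampler(dataset)
    loader = DataLoader(dataset, batch_size=32, shuffle=False, sampler=sampler)

    for epoch in range(5):
        sampler.set_epoch(epoch)           # reshuffle across epochs
        for x, y in loader:
            x, y = x.to(local_rank), y.to(local_rank)
            loss = torch.nn.functional.mse_loss(model(x), y)
            optimizer.zero_grad()
            loss.backward()                # DDP all-reduces gradients here
            optimizer.step()

    destroy_process_group()

if __name__ == "__main__":
    main()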
Environment Setup
To run DDP at OSC, we recommend using a base PyTorch environment, or cloning a base PyTorch environment and adding your project-specific packages to it.
There are six main differences between a DDP run and a standard, non-distributed run. The following code examples are taken from https://github.com/pytorch/examples/tree/main/distributed/minGPT-ddp:
DDP Setup Function
DDP setup creates a process group and sets the local CUDA device. This function is called near the start of main().
import os
import torch
from torch.distributed import init_process_group

def ddp_setup():
    init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
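torchrun sets LOCAL_RANK (along with RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT) for every process it launches, so no extra configuration is needed here. As a hedged sketch, and not part of the minGPT-ddp example, the same setup can fall back to the gloo backend for CPU-only debugging:

import os
import torch
from torch.distributed import init_process_group

def ddp_setup():
    # Sketch only: fall back to the CPU "gloo" backend when no GPU is available.
    if torch.cuda.is_available():
        init_process_group(backend="nccl")
        torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    else:
        init_process_group(backend="gloo")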
Trainer wraps model in DDP
from torch.nn.parallel import DistributedDataParallel as DDP

class Trainer:
    def __init__(self, trainer_config: TrainerConfig, model, optimizer, train_dataset, test_dataset=None):
        ...
        self.model = DDP(self.model, device_ids=[self.local_rank])
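The elided constructor lines must place the model on this process's GPU before it is wrapped; DDP with device_ids expects the module to already live on that device. A minimal sketch of that step (variable names are illustrative, not taken from the example):

# Inside Trainer.__init__ (illustrative): move the model to the local GPU
# before wrapping it in DDP.
self.local_rank = int(os.environ["LOCAL_RANK"])
self.model = model.to(self.local_rank)
self.model = DDP(self.model, device_ids=[self.local_rank])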
Use DistributedSampler to load data (and set shuffle=False)
from torch.utils.data.distributed import DistributedSampler

class Trainer:
    ...
    def _prepare_dataloader(self, dataset: Dataset):
        return DataLoader(
            dataset,
            batch_size=self.config.batch_size,
            pin_memory=True,
            shuffle=False,
            num_workers=self.config.data_loader_workers,
            sampler=DistributedSampler(dataset)
        )
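shuffle=False is required on the DataLoader because the DistributedSampler, not the DataLoader, decides which samples each process sees. To get a fresh shuffle every epoch, call set_epoch on the sampler at the start of each epoch; a short sketch (loop variables are illustrative):

# Illustrative training loop: re-seed the DistributedSampler each epoch
# so the data is reshuffled consistently across all processes.
for epoch in range(max_epochs):
    train_loader.sampler.set_epoch(epoch)
    for batch in train_loader:
        ...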
Destroy process group when done
def main():
    ...
    trainer.train()
    destroy_process_group()
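To guarantee the process group is torn down even when training raises an exception, main() can be structured with try/finally; a minimal sketch (the Trainer arguments are placeholders):

def main():
    ddp_setup()
    try:
        trainer = Trainer(config, model, optimizer, train_dataset)  # placeholders
        trainer.train()
    finally:
        # Always clean up the process group, even if train() fails.
        destroy_process_group()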
Only save checkpoints where local_rank=0
class Trainer:
    ...
    def train(self):
        for epoch in range(self.epochs_run, self.config.max_epochs):
            epoch += 1
            self._run_epoch(epoch, self.train_loader, train=True)
            if self.local_rank == 0 and epoch % self.save_every == 0:
                self._save_snapshot(epoch)
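Because the model is wrapped in DDP, checkpoint code typically saves the underlying module's state_dict so the snapshot can later be loaded into an unwrapped model. A hedged sketch of what a _save_snapshot method might look like (the path and dictionary keys are illustrative):

def _save_snapshot(self, epoch):
    # self.model.module unwraps the DDP wrapper so the checkpoint can be
    # loaded outside of a distributed run. Path and keys are illustrative.
    snapshot = {
        "MODEL_STATE": self.model.module.state_dict(),
        "EPOCHS_RUN": epoch,
    }
    torch.save(snapshot, "snapshot.pt")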
Global vs local rank tracked separately
class Trainer:
    def __init__(self, trainer_config: TrainerConfig, model, optimizer, train_dataset, test_dataset=None):
        self.config = trainer_config
        # set torchrun variables
        self.local_rank = int(os.environ["LOCAL_RANK"])
        self.global_rank = int(os.environ["RANK"])
        ...
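local_rank identifies the GPU within a single node, while global_rank (RANK) is unique across the whole job, running from 0 to WORLD_SIZE - 1. A short sketch of how the two are typically used inside the Trainer (not taken from the example):

# Illustrative: global rank for once-per-job work, local rank for device placement.
world_size = int(os.environ["WORLD_SIZE"])   # also set by torchrun
if self.global_rank == 0:
    print(f"Training with {world_size} processes")
self.model = self.model.to(self.local_rank)  # each process uses its own GPU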
Example Slurm Job Script Using srun and torchrun
#!/bin/bash
#SBATCH --job-name=multinode-example-minGPT
#SBATCH --nodes=2
#SBATCH --ntasks=2
#SBATCH --gpus-per-task=1
#SBATCH --cpus-per-task=4

nodes=( $( scontrol show hostnames $SLURM_JOB_NODELIST ) )
nodes_array=($nodes)
head_node=${nodes_array[0]}
head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)

echo Node IP: $head_node_ip
export LOGLEVEL=INFO

ml miniconda3/24.1.2-py310
conda activate minGPT-ddp

srun torchrun \
    --nnodes 2 \
    --nproc_per_node 1 \
    --rdzv_id $RANDOM \
    --rdzv_backend c10d \
    --rdzv_endpoint $head_node_ip:29500 \
    /path/to/examples/distributed/minGPT-ddp/mingpt/main.py