HOWTO: PyTorch Distributed Data Parallel (DDP)

PyTorch Distributed Data Parallel (DDP) is used to speed up model training by parallelizing the training data across multiple identical model instances.

 

If your model fits on a single GPU but a large training set is making training slow, you can use DDP and request more GPUs to increase training speed.  The entire model is replicated on each GPU, and each training process receives a different subset of the training data.  During the backward pass, gradients from each device are synchronized (all-reduced) across devices, so every replica applies the same update and the model stays identical on all devices.
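For orientation, the sketch below shows the overall shape of a complete DDP training script launched with torchrun.  It is a minimal, illustrative example: the toy linear model, random dataset, and hyperparameters are placeholders, not part of the minGPT example discussed later.

import os

import torch
import torch.nn as nn
from torch.distributed import init_process_group, destroy_process_group
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

def main():
    # torchrun sets LOCAL_RANK, RANK, and WORLD_SIZE for every process it launches
    init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    # placeholder model and data; replace with your own
    model = DDP(nn.Linear(20, 1).to(local_rank), device_ids=[local_rank])
    dataset = TensorDataset(torch.randn(1024, 20), torch.randn(1024, 1))
    loader = DataLoader(dataset, batch_size=32, shuffle=False,
                        sampler=DistributedSampler(dataset))
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    for epoch in range(5):
        loader.sampler.set_epoch(epoch)  # give each epoch a different shuffle
        for source, targets in loader:
            source, targets = source.to(local_rank), targets.to(local_rank)
            optimizer.zero_grad()
            loss = nn.functional.mse_loss(model(source), targets)
            loss.backward()              # gradients are all-reduced across ranks here
            optimizer.step()

    destroy_process_group()

if __name__ == "__main__":
    main()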

 

For a complete overview with video tutorial and examples, see https://pytorch.org/tutorials/beginner/ddp_series_intro.html

 

Environment Setup

For running DDP at OSC, we recommend using a base PyTorch environment (in progress) or cloning a base PyTorch environment (in progress) and adding your project’s specific packages to it.

 

There are six main differences between a DDP run and a typical single-process (single-GPU) run.  The following code examples are taken from https://github.com/pytorch/examples/tree/main/distributed/minGPT-ddp:

1. DDP Setup Function

DDP setup creates a process group and sets the local CUDA device.  This function is called near the start of main().

from torch.distributed import init_process_group

def ddp_setup():
    init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
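Called without an explicit rank or world size, init_process_group falls back to the environment variables that torchrun exports for every process it launches.  A quick sanity check of those variables (the function name is just an illustration):

import os

def report_rank():
    rank = int(os.environ["RANK"])              # global rank, 0 .. WORLD_SIZE-1
    local_rank = int(os.environ["LOCAL_RANK"])  # rank within this node
    world_size = int(os.environ["WORLD_SIZE"])  # total number of processes
    print(f"process {rank} of {world_size} (local rank {local_rank})")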

2. Trainer wraps the model in DDP

from torch.nn.parallel import DistributedDataParallel as DDP

class Trainer:
    def __init__(self, trainer_config: TrainerConfig, model, optimizer,
                 train_dataset, test_dataset=None):
        ...
        self.model = model.to(self.local_rank)  # move the model to this process's GPU first
        self.model = DDP(self.model, device_ids=[self.local_rank])
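Once wrapped, the model is called exactly as before and gradient synchronization happens automatically inside backward(); the training step itself does not change.  A simplified sketch of a training step (it assumes, as in minGPT, that the model returns a (logits, loss) pair when given targets):

class Trainer:
    ...
    def _run_batch(self, source, targets):
        self.optimizer.zero_grad()
        _, loss = self.model(source, targets)  # DDP forwards the call to the wrapped module
        loss.backward()                        # gradients are all-reduced across ranks here
        self.optimizer.step()
        return loss.item()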

3. Use DistributedSampler to load data (and set shuffle=False)

from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler

class Trainer:
    ...
    def _prepare_dataloader(self, dataset: Dataset):
        return DataLoader(
            dataset,
            batch_size=self.config.batch_size,
            pin_memory=True,
            shuffle=False,
            num_workers=self.config.data_loader_workers,
            sampler=DistributedSampler(dataset)
        )
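Setting shuffle=False is not a mistake: the DistributedSampler owns the shuffling so that each rank sees a disjoint shard of the data.  To get a different shuffle on every epoch, the sampler must be told the current epoch at the start of each one, along these lines:

class Trainer:
    ...
    def _run_epoch(self, epoch: int, dataloader: DataLoader, train: bool = True):
        dataloader.sampler.set_epoch(epoch)  # reshuffle the shards for this epoch
        for source, targets in dataloader:
            ...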

4. Destroy the process group when done

from torch.distributed import destroy_process_group

def main():
    ...
    trainer.train()
    destroy_process_group()
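Putting setup and teardown together, main() has roughly the following shape; the helper for building the model, optimizer, and dataset is a placeholder for however your project constructs them:

def main():
    ddp_setup()                      # create the process group, select the GPU
    model, optimizer, train_set = build_training_objects()  # placeholder helper
    trainer = Trainer(trainer_config, model, optimizer, train_set)
    trainer.train()                  # every rank runs the training loop
    destroy_process_group()          # clean up before the process exits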

5. Only save checkpoints where local_rank == 0

class Trainer:
    ...
    def train(self):
        for epoch in range(self.epochs_run, self.config.max_epochs):
            epoch += 1
            self._run_epoch(epoch, self.train_loader, train=True)
            if self.local_rank == 0 and epoch % self.save_every == 0:
                self._save_snapshot(epoch)
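Because the model is wrapped in DDP, the checkpoint should be taken from the underlying module.  A minimal sketch of the save helper (the snapshot layout and the self.snapshot_path attribute are assumptions, not the exact minGPT code):

class Trainer:
    ...
    def _save_snapshot(self, epoch: int):
        snapshot = {
            "MODEL_STATE": self.model.module.state_dict(),  # unwrap the DDP container
            "EPOCHS_RUN": epoch,
        }
        torch.save(snapshot, self.snapshot_path)  # self.snapshot_path is an assumed attribute
        print(f"Epoch {epoch} | snapshot saved to {self.snapshot_path}")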

6. Global vs. local rank tracked separately

class Trainer:
    def __init__(self, trainer_config: TrainerConfig, model, optimizer,
                 train_dataset, test_dataset=None):
        self.config = trainer_config
        # set torchrun variables
        self.local_rank = int(os.environ["LOCAL_RANK"])
        self.global_rank = int(os.environ["RANK"])
        ...
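local_rank identifies the GPU within a single node and is what device placement and the DDP wrapper use; global_rank is unique across the whole job.  In the two-node, one-GPU-per-node job shown below, every process has local_rank 0 but a different global_rank.  The helpers here are hypothetical, just to illustrate how the two ranks are typically used:

class Trainer:
    ...
    def _log(self, message: str):
        # tag output with the job-wide rank so lines from different nodes can be told apart
        print(f"[GPU{self.global_rank}] {message}")

    def _report_once(self, message: str):
        # work that should happen exactly once per job gates on global rank 0
        if self.global_rank == 0:
            print(message)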

 

Example Slurm Job Script Using srun and torchrun

#!/bin/bash
#SBATCH --job-name=multinode-example-minGPT
#SBATCH --nodes=2
#SBATCH --ntasks=2
#SBATCH --gpus-per-task=1
#SBATCH --cpus-per-task=4

nodes_array=( $(scontrol show hostnames "$SLURM_JOB_NODELIST") )
head_node=${nodes_array[0]}
head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)

echo Node IP: $head_node_ip
export LOGLEVEL=INFO

ml miniconda3/24.1.2-py310
conda activate minGPT-ddp

srun torchrun \
    --nnodes 2 \
    --nproc_per_node 1 \
    --rdzv_id $RANDOM \
    --rdzv_backend c10d \
    --rdzv_endpoint $head_node_ip:29500 \
    /path/to/examples/distributed/minGPT-ddp/mingpt/main.py