HOWTO: PyTorch Fully Sharded Data Parallel (FSDP)

PyTorch Fully Sharded Data Parallel (FSDP) speeds up model training by parallelizing the training data while sharding model parameters, optimizer states, and gradients across multiple PyTorch processes.


If your model does not fit on a single GPU, you can use FSDP and request more GPUs to reduce the memory footprint on each GPU.  The model parameters, gradients, and optimizer states are sharded across the GPUs, and each training process receives a different subset of the training data.  Updates from each device are synchronized across devices after each step, so every device ends up with the same model.


For a complete overview with examples, see https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html


Environment Setup

To run FSDP at OSC, we recommend using a base PyTorch environment, or cloning a base PyTorch environment and adding your project's specific packages to the clone.


There are six main differences between an FSDP run and a single-GPU run:

FSDP Setup Function

FSDP setup creates a process group and sets the local device.  This function is called toward the start of main.

import os
import torch
from torch.distributed import init_process_group

def fsdp_setup():
    init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
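
Nothing else needs to be configured by hand here: torchrun exports MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE, and LOCAL_RANK for every worker process, and init_process_group picks them up through its default env:// rendezvous.  The more explicit, equivalent form below is only a sketch to show where those settings come from:

# Equivalent to init_process_group(backend="nccl"): the env:// rendezvous reads
# MASTER_ADDR and MASTER_PORT, and rank/world_size come from the variables
# that torchrun sets for each worker.
init_process_group(
    backend="nccl",
    init_method="env://",
    rank=int(os.environ["RANK"]),
    world_size=int(os.environ["WORLD_SIZE"]),
)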

Trainer wraps model in FSDP

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

class Trainer:
    def __init__(self, trainer_config: TrainerConfig, model, optimizer,
                 train_dataset, test_dataset=None):
        ...
        model = FSDP(model,
            auto_wrap_policy=t5_auto_wrap_policy,
            mixed_precision=mixed_precision_policy,
            sharding_strategy=fsdp_config.sharding_strategy,
            device_id=torch.cuda.current_device(),
            limit_all_gathers=fsdp_config.limit_all_gathers)
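
The auto-wrap and mixed-precision policies passed to FSDP above come from the T5 example's configuration.  The sketch below shows one way to build them, assuming a Hugging Face T5 model; the T5Block class and the bf16 dtypes are illustrative choices, and sharding_strategy corresponds to fsdp_config.sharding_strategy in the snippet above:

import functools
import torch
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.t5.modeling_t5 import T5Block

# Wrap each T5 transformer block as its own FSDP unit so parameters are
# gathered and freed block by block instead of all at once.
t5_auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={T5Block},
)

# Optional mixed precision: parameters, gradient reduction, and buffers in bf16.
mixed_precision_policy = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)

# FULL_SHARD shards parameters, gradients, and optimizer states across all ranks.
sharding_strategy = ShardingStrategy.FULL_SHARD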

Use DistributedSampler to load data

from torch.utils.data.distributed import DistributedSampler

# rank and world_size come from dist.get_rank()/dist.get_world_size() (or the RANK/WORLD_SIZE env vars)
sampler1 = DistributedSampler(dataset1, rank=rank, num_replicas=world_size, shuffle=True)
train_kwargs = {'batch_size': train_config.batch_size_training, 'sampler': sampler1}
cuda_kwargs = {'num_workers': train_config.num_workers_dataloader,
               'pin_memory': True,
               'shuffle': False}
train_kwargs.update(cuda_kwargs)
train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
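
One detail that is easy to miss with DistributedSampler: call set_epoch at the start of every epoch so that each epoch shuffles the data differently; otherwise every epoch sees the same ordering.  A minimal sketch, assuming num_epochs is defined by your training configuration:

for epoch in range(num_epochs):
    sampler1.set_epoch(epoch)   # reseed the sampler's shuffle for this epoch
    for batch in train_loader:
        ...                     # forward / backward / optimizer step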

Destroy the process group after training/validation and any post-processing have completed

import torch.distributed as dist

def cleanup():
    dist.destroy_process_group()
...
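
Putting fsdp_setup and cleanup together, a minimal skeleton of main() could look like the sketch below; the nn.Linear model and AdamW optimizer are placeholders standing in for your real model and training loop:

import os
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def main():
    fsdp_setup()                               # create the process group, pin the local GPU
    local_rank = int(os.environ["LOCAL_RANK"])
    model = FSDP(torch.nn.Linear(1024, 1024).to(local_rank))   # placeholder model
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    # ... training, validation, checkpoint saving, post-processing ...
    dist.barrier()                             # wait until every rank has finished
    cleanup()                                  # then destroy the process group

if __name__ == "__main__":
    main()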

Only save checkpoints where local_rank=0

if fsdp_config.fsdp_activation_checkpointing and local_rank == 0:
    policies.apply_fsdp_checkpointing(model)
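
The snippet above applies FSDP activation checkpointing on local rank 0.  For saving the model weights themselves, the FSDP tutorial's approach is to gather a full state dict and write it from a single rank.  A minimal sketch, assuming model is the FSDP-wrapped model and global_rank has been read from the RANK environment variable; the file name is a placeholder:

import torch
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    StateDictType,
    FullStateDictConfig,
)

# Gather the sharded parameters into a single CPU state dict on rank 0 only.
save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
    cpu_state = model.state_dict()

if global_rank == 0:
    torch.save(cpu_state, "t5_checkpoint.pt")   # placeholder file name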

Global vs local rank tracked separately

class Trainer:
    def __init__(self, trainer_config: TrainerConfig, model, optimizer,
                 train_dataset, test_dataset=None):
        self.config = trainer_config
        # set torchrun variables
        self.local_rank = int(os.environ["LOCAL_RANK"])
        self.global_rank = int(os.environ["RANK"])
        ...
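
The two ranks serve different purposes: local_rank identifies this process's GPU on its own node, while global_rank (RANK) is unique across the whole job and is typically what you check when exactly one process should log or save.  A short illustration; the variable names and print message are just examples:

import os
import torch

local_rank = int(os.environ["LOCAL_RANK"])    # GPU index on this node
global_rank = int(os.environ["RANK"])         # unique process index across all nodes
world_size = int(os.environ["WORLD_SIZE"])    # total number of processes in the job

device = torch.device(f"cuda:{local_rank}")   # each process drives its own GPU

# Only one process in the entire job should print, log, or write files.
if global_rank == 0:
    print(f"Training with {world_size} processes")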


Example Slurm Job Script Using srun and torchrun

The script below requests two nodes with four GPUs each.  srun launches one torchrun instance per node, and each torchrun instance starts one worker process per GPU.

#!/bin/bash
#SBATCH --job-name=fsdp-t5-multinode
#SBATCH --nodes=2
#SBATCH --ntasks-per-node=1
#SBATCH --gpus-per-task=4
#SBATCH --cpus-per-task=96

nodes=( $( scontrol show hostnames $SLURM_JOB_NODELIST ) )
head_node=${nodes[0]}
head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)

echo Node IP: $head_node_ip
export LOGLEVEL=INFO

ml miniconda3/24.1.2-py310
conda activate fsdp

srun torchrun \
--nnodes 2 \
--nproc_per_node 4 \
--rdzv_id $RANDOM \
--rdzv_backend c10d \
--rdzv_endpoint $head_node_ip:29500 \
/path/to/examples/distributed/T5-fsdp/fsdp_t5.py