PyTorch Fully Sharded Data Parallel (FSDP) is used to speed up model training by parallelizing the training data as well as sharding model parameters, optimizer states, and gradients across multiple PyTorch processes.
If your model does not fit on a single GPU, you can use FSDP and request more GPUs to reduce the memory footprint on each GPU. The model parameters are sharded across the GPUs, and each training process receives a different subset of the training data. Gradients from each device are synchronized across devices during the backward pass, keeping the model consistent across all devices. For example, fully sharding a model across four GPUs leaves each GPU persistently holding roughly a quarter of the parameters, gradients, and optimizer states, although full layers are still gathered temporarily during the forward and backward passes.
For a complete overview with examples, see https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html
Environment Setup
For running FSDP at OSC, we recommend using a base PyTorch environment or cloning a base PyTorch environment and adding your project’s specific packages to it.
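The job script at the end of this page assumes such an environment: it loads the miniconda3 module and activates a conda environment named fsdp that contains PyTorch and the example's dependencies.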
There are six main differences between FSDP and single-machine runs:
FSDP Setup Function
FSDP setup creates a process group and sets the local device. This function is called toward the start of main.
import os
import torch
from torch.distributed import init_process_group

def fsdp_setup():
    # Create the default process group and pin this process to its local GPU
    init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
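To show where fsdp_setup fits, here is a minimal sketch of a main function that ties together the pieces described on this page; the Trainer arguments are assumed to be built earlier in the script, and train() is a hypothetical method name.

def main():
    fsdp_setup()                               # create the process group and select the local GPU
    trainer = Trainer(trainer_config, model,   # objects assumed to be constructed earlier
                      optimizer, train_dataset)
    trainer.train()                            # hypothetical training entry point
    cleanup()                                  # tear down the process group (see below)

if __name__ == "__main__":
    main()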
Trainer wraps model in FSDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

class Trainer:
    def __init__(self, trainer_config: TrainerConfig, model, optimizer, train_dataset, test_dataset=None):
        ...
        model = FSDP(model,
                     auto_wrap_policy=t5_auto_wrap_policy,
                     mixed_precision=mixed_precision_policy,
                     sharding_strategy=fsdp_config.sharding_strategy,
                     device_id=torch.cuda.current_device(),
                     limit_all_gathers=fsdp_config.limit_all_gathers)
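The t5_auto_wrap_policy and mixed_precision_policy referenced above are defined elsewhere in the example. A minimal sketch of how such policies can be built for a T5 model, assuming the Hugging Face transformers implementation of T5, is:

import functools
import torch
from torch.distributed.fsdp import MixedPrecision
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.t5.modeling_t5 import T5Block

# Wrap each T5Block as its own FSDP unit
t5_auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={T5Block},
)

# Keep parameters, gradient reduction, and buffers in bfloat16
mixed_precision_policy = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)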
Use DistributedSampler to load data
from torch.utils.data.distributed import DistributedSampler

sampler1 = DistributedSampler(dataset1, rank=rank, num_replicas=world_size, shuffle=True)
train_kwargs = {'batch_size': train_config.batch_size_training, 'sampler': sampler1}
cuda_kwargs = {'num_workers': train_config.num_workers_dataloader,
               'pin_memory': True,
               'shuffle': False}
train_kwargs.update(cuda_kwargs)
train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
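One detail worth noting: DistributedSampler only reshuffles when it is told the current epoch, so the training loop should call set_epoch at the start of each epoch. A sketch, with assumed loop and config names:

for epoch in range(train_config.num_epochs):   # num_epochs is an assumed config field
    sampler1.set_epoch(epoch)                  # gives each epoch a different shuffle across ranks
    for batch in train_loader:
        ...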
Destroy the process group after training/validation and any post-processing have completed
import torch.distributed as dist

def cleanup():
    dist.destroy_process_group()
...
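cleanup() should run once per process, after every rank has finished its collective work; destroying the process group while other ranks are still issuing collectives can cause them to hang. One optional safeguard, not part of the example above and shown here only as a sketch, is to synchronize ranks before tearing down:

def cleanup():
    dist.barrier()                  # wait for all ranks to reach this point
    dist.destroy_process_group()    # then tear down the default process group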
Only save checkpoints where local_rank=0
if fsdp_config.fsdp_activation_checkpointing and local_rank == 0:
    policies.apply_fsdp_checkpointing(model)
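For saving the model weights themselves, FSDP provides a full-state-dict mode that gathers the sharded parameters and materializes them only on rank 0. A minimal sketch is below; the file name is a placeholder, and global_rank refers to the torchrun RANK value discussed in the next section.

import torch
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    StateDictType,
    FullStateDictConfig,
)

save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
    cpu_state = model.state_dict()               # full weights, present only on rank 0

if global_rank == 0:
    torch.save(cpu_state, "model_checkpoint.pt")  # placeholder file name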
Global vs local rank tracked separately
class Trainer:
    def __init__(self, trainer_config: TrainerConfig, model, optimizer, train_dataset, test_dataset=None):
        self.config = trainer_config
        # set torchrun variables
        self.local_rank = int(os.environ["LOCAL_RANK"])
        self.global_rank = int(os.environ["RANK"])
        ...
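torchrun sets RANK to a process's unique index across the whole job and LOCAL_RANK to its index within a single node. A typical use of the two inside the training loop, sketched here with assumed variable names, is:

# LOCAL_RANK selects the GPU on this node; RANK identifies the process job-wide
batch = batch.to(self.local_rank)              # move data to this process's GPU
if self.global_rank == 0:
    print(f"epoch {epoch}: loss {loss:.4f}")   # log from a single process only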
Example Slurm Job Script Using srun and torchrun
#!/bin/bash
#SBATCH --job-name=fsdp-t5-multinode
#SBATCH --nodes=2
#SBATCH --ntasks-per-node=1
#SBATCH --gpus-per-task=4
#SBATCH --cpus-per-task=96

# Determine the head node and its IP address for the rendezvous endpoint
nodes=( $( scontrol show hostnames $SLURM_JOB_NODELIST ) )
nodes_array=($nodes)
head_node=${nodes_array[0]}
head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)

echo Node IP: $head_node_ip
export LOGLEVEL=INFO

# Activate the conda environment that contains PyTorch and the project packages
ml miniconda3/24.1.2-py310
conda activate fsdp

# Launch one torchrun instance per node; workers rendezvous at the head node
srun torchrun \
--nnodes 2 \
--nproc_per_node 1 \
--rdzv_id $RANDOM \
--rdzv_backend c10d \
--rdzv_endpoint $head_node_ip:29500 \
/path/to/examples/distributed/T5-fsdp/fsdp_t5.py
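Submit the script with sbatch. Each node runs one torchrun instance, and the workers rendezvous through the c10d backend at port 29500 on the head node. Note that --nproc_per_node controls how many training processes (FSDP ranks) torchrun starts per node; if you request several GPUs per task, as with --gpus-per-task=4 above, you would typically raise --nproc_per_node to match so that every allocated GPU is used.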