Distributed Data Parallel
Warning
The implementation of torch.nn.parallel.DistributedDataParallel
evolves over time. This design note is written based on the state as of v1.4.
torch.nn.parallel.DistributedDataParallel
(DDP) transparently performs
distributed data parallel training. This page describes how it works and reveals
implementation details.
Example
Let us start with a simple torch.nn.parallel.DistributedDataParallel
example. This example uses a torch.nn.Linear
as the local model, wraps
it with DDP, and then runs one forward pass, one backward pass, and an optimizer
step on the DDP model. After that, parameters on the local model will be
updated, and all models on different processes should be exactly the same.
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
import os
from torch.nn.parallel import DistributedDataParallel as DDP


def example(rank, world_size):
    # create default process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    # create local model
    model = nn.Linear(10, 10).to(rank)
    # construct DDP model
    ddp_model = DDP(model, device_ids=[rank])
    # define loss function and optimizer
    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    # forward pass
    outputs = ddp_model(torch.randn(20, 10).to(rank))
    labels = torch.randn(20, 10).to(rank)
    # backward pass
    loss_fn(outputs, labels).backward()
    # update parameters
    optimizer.step()


def main():
    world_size = 2
    mp.spawn(example,
             args=(world_size,),
             nprocs=world_size,
             join=True)


if __name__ == "__main__":
    # Environment variables which need to be
    # set when using c10d's default "env"
    # initialization mode.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"
    main()
DDP works with TorchDynamo. When used with TorchDynamo, apply the DDP model wrapper
before compiling the model, so that TorchDynamo can apply DDPOptimizer
(graph-break optimizations) based on DDP bucket sizes. (See TorchDynamo DDPOptimizer for more information.)
ddp_model = DDP(model, device_ids=[rank])
ddp_model = torch.compile(ddp_model)
Internal Design
This section explains how torch.nn.parallel.DistributedDataParallel works under the hood by walking through every step of one iteration.
Prerequisite: DDP relies on c10d ProcessGroup for communications. Hence, applications must create ProcessGroup instances before constructing DDP.

Construction: The DDP constructor takes a reference to the local module, and broadcasts state_dict() from the process with rank 0 to all other processes in the group to make sure that all model replicas start from the exact same state. Then, each DDP process creates a local Reducer, which later takes care of gradient synchronization during the backward pass. To improve communication efficiency, the Reducer organizes parameter gradients into buckets, and reduces one bucket at a time. Bucket size can be configured by setting the bucket_cap_mb argument in the DDP constructor. The mapping from parameter gradients to buckets is determined at construction time, based on the bucket size limit and parameter sizes. Model parameters are allocated into buckets in (roughly) the reverse order of Model.parameters() from the given model. The reverse order is used because DDP expects gradients to become ready during the backward pass in approximately that order. The figure below shows an example. Note that grad0 and grad1 are in bucket1, and the other two gradients are in bucket0. Of course, this assumption might not always be true, and when that happens it could hurt DDP backward speed as the Reducer cannot kick off the communication at the earliest possible time. Besides bucketing, the Reducer also registers autograd hooks during construction, one hook per parameter. These hooks will be triggered during the backward pass when the gradient becomes ready.

Forward Pass: DDP takes the input and passes it to the local model, and then analyzes the output from the local model if find_unused_parameters is set to True. This mode allows running backward on a subgraph of the model, and DDP finds out which parameters are involved in the backward pass by traversing the autograd graph from the model output and marking all unused parameters as ready for reduction. During the backward pass, the Reducer only waits for unready parameters, but it still reduces all buckets. Marking a parameter gradient as ready does not currently help DDP skip buckets, but it prevents DDP from waiting forever for absent gradients during the backward pass. Note that traversing the autograd graph introduces extra overhead, so applications should only set find_unused_parameters to True when necessary.

Backward Pass: The backward() function is directly invoked on the loss Tensor, which is out of DDP’s control, so DDP uses autograd hooks registered at construction time to trigger gradient synchronization. When one gradient becomes ready, its corresponding DDP hook on that grad accumulator fires, and DDP then marks that parameter gradient as ready for reduction. When gradients in one bucket are all ready, the Reducer kicks off an asynchronous allreduce on that bucket to calculate the mean of gradients across all processes. When all buckets are ready, the Reducer blocks waiting for all allreduce operations to finish. When this is done, averaged gradients are written to the param.grad field of all parameters. So after the backward pass, the grad field of the same parameter across different DDP processes should be the same.

Optimizer Step: From the optimizer’s perspective, it is optimizing a local model. Model replicas on all DDP processes can keep in sync because they all start from the same state and they have the same averaged gradients in every iteration.
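The reverse-order bucketing described under Construction can be illustrated with a small pure-Python sketch. This is a toy model under stated assumptions, not the Reducer's actual code: the function name assign_buckets, the byte sizes, and the greedy "close the bucket when the cap would be exceeded" rule are all hypothetical choices made for illustration.

```python
def assign_buckets(param_sizes, cap_bytes):
    """Greedily group parameter indices into buckets, walking in reverse order.

    param_sizes: gradient sizes in bytes, in Model.parameters() order.
    cap_bytes:   hypothetical per-bucket size limit (the role bucket_cap_mb plays).
    """
    buckets = []
    current, current_bytes = [], 0
    for index in reversed(range(len(param_sizes))):
        size = param_sizes[index]
        # Close the current bucket if adding this gradient would exceed the cap.
        if current and current_bytes + size > cap_bytes:
            buckets.append(current)
            current, current_bytes = [], 0
        current.append(index)
        current_bytes += size
    if current:
        buckets.append(current)
    return buckets


# Four equally sized gradients, two per bucket.
buckets = assign_buckets([400, 400, 400, 400], cap_bytes=800)
print(buckets)  # [[3, 2], [1, 0]]
```

With these made-up sizes, the last two parameters land in bucket 0 and the first two (grad0 and grad1) in bucket 1, which matches the layout described in the text.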
Note
DDP requires Reducer instances on all processes to invoke allreduce in exactly the same order, which is done by always running allreduce in the bucket index order instead of actual bucket ready order. Mismatched allreduce order across processes can lead to wrong results or DDP backward hang.
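The interaction between per-parameter hooks and the bucket-index launch order can be sketched as follows. This is a simplified, single-process simulation and not the real Reducer: the class ToyReducer, its fields, and the launched list are hypothetical names, and "launching" here only records the bucket index instead of starting an allreduce.

```python
class ToyReducer:
    """Tracks ready gradients and launches buckets strictly in index order."""

    def __init__(self, buckets):
        self.buckets = buckets
        # Number of gradients each bucket is still waiting for.
        self.pending = [len(bucket) for bucket in buckets]
        self.param_to_bucket = {
            param: index for index, bucket in enumerate(buckets) for param in bucket
        }
        self.next_bucket = 0   # next bucket index allowed to launch
        self.launched = []     # order in which buckets were "allreduced"

    def autograd_hook(self, param_index):
        """Called when the gradient of one parameter becomes ready."""
        self.pending[self.param_to_bucket[param_index]] -= 1
        # Launch every consecutive ready bucket, never skipping ahead.
        while (self.next_bucket < len(self.buckets)
               and self.pending[self.next_bucket] == 0):
            self.launched.append(self.next_bucket)
            self.next_bucket += 1


reducer = ToyReducer([[3, 2], [1, 0]])
# Gradients arrive in an unexpected order: bucket 1 fills up first.
for param_index in [1, 0, 3, 2]:
    reducer.autograd_hook(param_index)
print(reducer.launched)  # [0, 1]
```

Even though bucket 1 becomes ready first in this run, it is held back until bucket 0 has launched, so every process would issue its collectives in the same order. The cost is the delayed communication that the Construction step warns about.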
Implementation
Below are pointers to the DDP implementation components. The stacked graph shows the structure of the code.
ProcessGroup
ProcessGroup.hpp: contains the abstract API of all process group implementations. The c10d library provides 3 implementations out of the box, namely, ProcessGroupGloo, ProcessGroupNCCL, and ProcessGroupMPI. DistributedDataParallel uses ProcessGroup::broadcast() to send model states from the process with rank 0 to others during initialization and ProcessGroup::allreduce() to sum gradients.

Store.hpp: assists the rendezvous service for process group instances to find each other.
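To make the roles of broadcast and allreduce concrete, here is a minimal in-memory sketch that treats each rank as a plain Python list. It does not use the c10d API: the helpers broadcast_from_rank0, allreduce_sum, and sgd_step are hypothetical stand-ins, and dividing the summed gradients by the world size is shown as one way to obtain the mean.

```python
def broadcast_from_rank0(replicas):
    """Every rank receives a copy of rank 0's state."""
    return [list(replicas[0]) for _ in replicas]


def allreduce_sum(per_rank_values):
    """Every rank receives the element-wise sum over all ranks."""
    total = [sum(column) for column in zip(*per_rank_values)]
    return [list(total) for _ in per_rank_values]


def sgd_step(weights, grads, lr):
    """Plain SGD update applied locally on one rank."""
    return [w - lr * g for w, g in zip(weights, grads)]


# Two ranks that start from different weights.
replicas = [[1.0, 2.0], [9.0, 9.0]]
replicas = broadcast_from_rank0(replicas)      # construction: identical start

# Each rank computes different local gradients on its own data.
local_grads = [[1.0, 2.0], [3.0, 6.0]]
world_size = len(local_grads)
averaged = [
    [g / world_size for g in grads] for grads in allreduce_sum(local_grads)
]

# Each rank runs its optimizer locally, yet the replicas stay identical.
replicas = [sgd_step(w, g, lr=0.5) for w, g in zip(replicas, averaged)]
print(replicas)  # [[0.0, 0.0], [0.0, 0.0]]
```

Because both ranks start from the same weights and apply the same averaged gradients, their local optimizer steps produce the same result without any further communication.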
DistributedDataParallel
distributed.py: is the Python entry point for DDP. It implements the initialization steps and the forward function for the nn.parallel.DistributedDataParallel module which call into C++ libraries. Its _sync_param function performs intra-process parameter synchronization when one DDP process works on multiple devices, and it also broadcasts model buffers from the process with rank 0 to all other processes. The inter-process parameter synchronization happens in Reducer.cpp.

comm.h: implements the coalesced broadcast helper function which is invoked to broadcast model states during initialization and synchronize model buffers before the forward pass.
reducer.h: provides the core implementation for gradient synchronization in the backward pass. It has three entry point functions:

Reducer: The constructor is called in distributed.py and registers Reducer::autograd_hook() to gradient accumulators.

autograd_hook(): This function is invoked by the autograd engine when a gradient becomes ready.

prepare_for_backward(): This function is called at the end of the DDP forward pass in distributed.py. It traverses the autograd graph to find unused parameters when find_unused_parameters is set to True in the DDP constructor.
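The graph traversal that prepare_for_backward() performs when find_unused_parameters is True can be approximated with a toy graph. This sketch is not the real autograd data structure: the graph is a hypothetical dictionary mapping each node name to the nodes it depends on, and find_unused_parameters here is an illustrative function, not a PyTorch API.

```python
def find_unused_parameters(output_node, graph, all_params):
    """Return parameters that are not reachable from the model output.

    graph: hypothetical mapping from a node to the nodes it was computed from.
    """
    seen, used = set(), set()
    stack = [output_node]
    while stack:
        node = stack.pop()
        if node in seen:
            continue
        seen.add(node)
        if node in all_params:
            used.add(node)          # this parameter will receive a gradient
        stack.extend(graph.get(node, []))
    return sorted(set(all_params) - used)


# The output depends only on the first linear layer in this iteration.
graph = {
    "loss": ["linear1_out"],
    "linear1_out": ["w1", "b1"],
}
unused = find_unused_parameters("loss", graph, ["w1", "b1", "w2", "b2"])
print(unused)  # ['b2', 'w2']
```

In this toy setting, w2 and b2 would be marked as ready up front, so the backward pass would not wait for gradients that never arrive.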
TorchDynamo DDPOptimizer
DDP’s performance advantage comes from overlapping allreduce collectives with computations during the backward pass. AotAutograd prevents this overlap when used with TorchDynamo to compile a whole forward and whole backward graph, because allreduce ops are launched by autograd hooks after the whole optimized backward computation finishes.
TorchDynamo’s DDPOptimizer helps by breaking the forward graph at the logical boundaries of DDP’s allreduce buckets during the backward pass. Note: the goal is to break the graph during the backward pass, and the simplest implementation is to break the forward graphs and then call AotAutograd and compilation on each section. This allows DDP’s allreduce hooks to fire between sections of the backward pass, and schedule communications to overlap with compute.
See this blog post for a more in-depth explanation and experimental results, or read the docs and code at torch/_dynamo/optimizations/distributed.py.
To debug DDPOptimizer, set torch._dynamo.config.log_level to DEBUG (for full graph dumps) or INFO (for basic info about bucket boundaries). To disable DDPOptimizer, set torch._dynamo.config.optimize_ddp=False. DDP and TorchDynamo should still work correctly without DDPOptimizer, but with performance degradation.
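The idea of breaking the forward graph at bucket boundaries can be sketched in plain Python. This is only an illustration of the concept and makes no claim about how DDPOptimizer is implemented: the function split_forward_at_bucket_boundaries, the layer names, and the byte sizes are hypothetical, and the rule "cut once the accumulated parameter bytes reach the cap" is an assumption.

```python
def split_forward_at_bucket_boundaries(layers, cap_bytes):
    """Split forward-ordered layers into sections aligned with gradient buckets.

    layers: list of (name, parameter_bytes) in forward execution order.
    Walks the layers in reverse (the order gradients become ready) and
    closes a section once its parameter bytes reach cap_bytes.
    """
    sections = []
    current, current_bytes = [], 0
    for name, nbytes in reversed(layers):
        current.append(name)
        current_bytes += nbytes
        if current_bytes >= cap_bytes:
            sections.append(list(reversed(current)))
            current, current_bytes = [], 0
    if current:
        sections.append(list(reversed(current)))
    sections.reverse()  # report sections in forward order
    return sections


layers = [("embed", 300), ("block1", 500), ("block2", 500), ("head", 200)]
sections = split_forward_at_bucket_boundaries(layers, cap_bytes=600)
print(sections)  # [['embed', 'block1'], ['block2', 'head']]
```

If each section were compiled separately, the backward pass of the last section could finish and trigger its allreduce while the backward pass of the earlier section is still running, which is the overlap the text describes.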