FullyShardedDataParallel
- class torch.distributed.fsdp.FullyShardedDataParallel(module, process_group=None, sharding_strategy=None, cpu_offload=None, auto_wrap_policy=None, backward_prefetch=BackwardPrefetch.BACKWARD_PRE, mixed_precision=None, ignored_modules=None, param_init_fn=None, device_id=None, sync_module_states=False, forward_prefetch=False, limit_all_gathers=False, use_orig_params=False, ignored_parameters=None)
A wrapper for sharding Module parameters across data parallel workers. This is inspired by Xu et al. as well as ZeRO Stage 3 from DeepSpeed. FullyShardedDataParallel is commonly shortened to FSDP.
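As a rough mental model of what sharding a parameter means, the plain-Python sketch below splits one flattened parameter into equal per-rank chunks (zero-padding the tail so the length divides evenly by the world size) and reassembles it with a simulated all-gather. It assumes Python lists stand in for tensors, and `shard_flat_param` and `all_gather_unpad` are hypothetical helpers, not FSDP internals; the point is only that each rank stores about 1/world_size of the elements between uses.

```python
from typing import List


def shard_flat_param(flat: List[float], world_size: int, rank: int) -> List[float]:
    """Return rank's equal-sized chunk of flat, zero-padding the tail.

    Hypothetical helper: a plain list stands in for a flattened parameter.
    """
    chunk = (len(flat) + world_size - 1) // world_size  # ceil division
    # Pad so the total length is exactly chunk * world_size.
    padded = flat + [0.0] * (chunk * world_size - len(flat))
    return padded[rank * chunk:(rank + 1) * chunk]


def all_gather_unpad(shards: List[List[float]], numel: int) -> List[float]:
    """Concatenate every rank's shard (a simulated all-gather), then drop padding."""
    full = [x for shard in shards for x in shard]
    return full[:numel]


flat = [float(i) for i in range(10)]  # 10 elements over 4 ranks: 3 per rank
shards = [shard_flat_param(flat, world_size=4, rank=r) for r in range(4)]
print(shards[-1])  # [9.0, 0.0, 0.0]: one real element plus two padding zeros
assert all_gather_unpad(shards, numel=len(flat)) == flat
```

With 10 elements on 4 ranks, each rank stores 3 values and the last rank's chunk is mostly padding, which is why the gathered result is trimmed back to the original element count.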
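Example:
>>> import torch
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> # Assumes the default process group is already initialized (e.g. via
>>> # torch.distributed.init_process_group) and that device_id, my_module,
>>> # and the input x are defined by the caller.
>>> torch.cuda.set_device(device_id)
>>> sharded_module = FSDP(my_module)
>>> optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)
>>> x = sharded_module(x, y=3, z=torch.Tensor([1]))
>>> loss = x.sum()
>>> loss.backward()
>>> optim.step()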
Warning
The optimizer must be initialized after the module has been wrapped, since FSDP will shard parameters in-place and this will break any previously initialized optimizers.
Warning
If the destination CUDA device has ID dev_id, either (1) module should already be placed on that device, (2) the device should be set using torch.cuda.set_device(dev_id), or (3) dev_id should be passed into the device_id constructor argument. This FSDP instance’s compute device will be that destination device. For (1) and (3), the FSDP initialization always occurs on GPU. For (2), the FSDP initialization happens on module’s current device, which may be CPU.
Warning
FSDP currently does not support gradient accumulation outside no_sync() when using CPU offloading. Trying to do so yields incorrect results since FSDP will use the newly reduced gradient instead of accumulating with any existing gradient.
Warning
Changing the original parameter variable names after construction will lead to undefined behavior.
Warning
Passing sync_module_states=True requires module to be on GPU, or the device_id argument to be used to specify a CUDA device that FSDP will move module to. This is because sync_module_states=True requires GPU communication.
Warning
As of PyTorch 1.12, FSDP only offers limited support for shared parameters (for example, setting one Linear layer’s weight to another’s). In particular, modules that share parameters must be wrapped as part of the same FSDP unit. If enhanced shared parameter support is needed for your use case, please ping https://github.com/pytorch/pytorch/issues/77724
Note
Inputs to the FSDP forward function are moved to the compute device (the same device the FSDP module is on) before running forward, so the user does not have to manually move inputs from CPU to GPU.
- Parameters:
module (nn.Module) – This is the module to be wrapped with FSDP.
process_group (Optional[Union[ProcessGroup, Tuple[ProcessGroup, ProcessGroup]]]) – This is the process group used for collective communications and the one over which the model is sharded. For hybrid sharding strategies such as ShardingStrategy.HYBRID_SHARD, users can pass in a tuple of process groups representing the groups to shard and replicate across, respectively.
sharding_strategy (Optional[ShardingStrategy]) – This configures the sharding strategy used by FSDP, which may trade off memory saving and communication overhead. See ShardingStrategy for details. (Default: FULL_SHARD)
cpu_offload (Optional[CPUOffload]) – This configures CPU offloading. If this is set to None, then no CPU offloading happens. See CPUOffload for details. (Default: None)
auto_wrap_policy (Optional[Union[Callable[[nn.Module, bool, int], bool], _FSDPPolicy]]) –
This is either None, an _FSDPPolicy, or a callable of a fixed signature. If it is None, then module is wrapped with only a top-level FSDP instance without any nested wrapping. If it is an _FSDPPolicy, then the wrapping follows the given policy; ModuleWrapPolicy in torch.distributed.fsdp.wrap.py is an example. If it is a callable, then it should take in three arguments module: nn.Module, recurse: bool, and nonwrapped_numel: int and should return a bool specifying whether the passed-in module should be wrapped if recurse=False or if the traversal should continue down the subtree if recurse=True. Additional custom arguments may be added to the callable. The size_based_auto_wrap_policy in torch.distributed.fsdp.wrap.py gives an example callable that wraps a module if the parameters in its subtree exceed 100M numel. A good practice is to print the model after wrapping and adjust as needed.
Example:
>>> import functools
>>> from torch import nn
>>> def custom_auto_wrap_policy(
>>>     module: nn.Module,
>>>     recurse: bool,
>>>     nonwrapped_numel: int,
>>>     # Additional custom arguments
>>>     min_num_params: int = int(1e8),
>>> ) -> bool:
>>>     return nonwrapped_numel >= min_num_params
>>> # Configure a custom `min_num_params`
>>> my_auto_wrap_policy = functools.partial(custom_auto_wrap_policy, min_num_params=int(1e5))
backward_prefetch (Optional[BackwardPrefetch]) – This configures explicit backward prefetching of all-gathers. See BackwardPrefetch for details. (Default: BACKWARD_PRE)
mixed_precision (Optional[MixedPrecision]) – This configures native mixed precision for FSDP. If this is set to None, then no mixed precision is used. Otherwise, parameter, buffer, and gradient reduction dtypes can be set. See MixedPrecision for details. (Default: None)
ignored_modules (Optional[Iterable[torch.nn.Module]]) – Modules whose own parameters and child modules’ parameters and buffers are ignored by this instance. None of the modules directly in ignored_modules should be FullyShardedDataParallel instances, and any child modules that are already-constructed FullyShardedDataParallel instances will not be ignored if they are nested under this instance. This argument may be used to avoid sharding specific parameters at module granularity when using an auto_wrap_policy or if parameters’ sharding is not managed by FSDP. (Default: None)
param_init_fn (Optional[Callable[[nn.Module], None]]) –
A Callable[torch.nn.Module] -> None that specifies how modules that are currently on the meta device should be initialized onto an actual device. Note that as of v1.12, we detect modules on the meta device via an is_meta check and, if param_init_fn is not specified, apply a default initialization that calls the reset_parameters method on the passed-in nn.Module; otherwise, we run param_init_fn to initialize the passed-in nn.Module. In particular, this means that if is_meta=True for any module parameters for modules that will be wrapped with FSDP and param_init_fn is not specified, we assume your module properly implements a reset_parameters() and will throw errors if not. Additionally, we offer support for modules initialized with torchdistX’s (https://github.com/pytorch/torchdistX) deferred_init API. In this case, deferred modules would be initialized by a default initialization function that calls torchdistX’s materialize_module, or by the passed-in param_init_fn if it is not None. The same Callable is applied to initialize all meta modules. Note that this initialization function is applied before doing any FSDP sharding logic.
Example:
>>> from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
>>> # MyModule is a placeholder for a user-defined module accepting a `device` argument
>>> module = MyModule(device="meta")
>>> def my_init_fn(module):
>>>     # responsible for initializing a module, such as with reset_parameters
>>>     ...
>>> fsdp_model = FSDP(module, param_init_fn=my_init_fn, auto_wrap_policy=size_based_auto_wrap_policy)
>>> print(next(fsdp_model.parameters()).device)  # current CUDA device
>>> # With torchdistX
>>> from torchdistx import deferred_init
>>> module = deferred_init.deferred_init(MyModule, device="cuda")
>>> # Will initialize via deferred_init.materialize_module().
>>> fsdp_model = FSDP(module, auto_wrap_policy=size_based_auto_wrap_policy)
device_id (Optional[Union[int, torch.device]]) – An int or torch.device describing the CUDA device the FSDP module should be moved to, which determines where initialization such as sharding takes place. If this argument is not specified and module is on CPU, we issue a warning mentioning that this argument can be specified for faster initialization. If specified, resulting FSDP instances will reside on this device, including moving ignored modules’ parameters if needed. Note that if device_id is specified but module is already on a different CUDA device, an error will be thrown. (Default: None)
sync_module_states (bool) – If True, each individually wrapped FSDP unit will broadcast module parameters from rank 0 to ensure they are the same across all ranks after initialization. This helps ensure model parameters are the same across ranks before starting training, but adds communication overhead to __init__, as at least one broadcast is triggered per individually wrapped FSDP unit. This can also help load checkpoints taken by state_dict and to be loaded by load_state_dict in a memory-efficient way. See documentation for FullStateDictConfig for an example of this. (Default: False)
forward_prefetch (bool) – If True, then FSDP explicitly prefetches the next upcoming all-gather while executing in the forward pass. This may improve communication and computation overlap for CPU-bound workloads. This should only be used for static graph models since the forward order is fixed based on the first iteration’s execution. (Default: False)
limit_all_gathers (bool) – If False, then FSDP allows the CPU thread to schedule all-gathers without any extra synchronization. If True, then FSDP explicitly synchronizes the CPU thread to prevent too many in-flight all-gathers. This bool only affects the sharded strategies that schedule all-gathers. Enabling this can help lower the number of CUDA malloc retries. (Default: False)
ignored_parameters (Optional[Iterable[torch.nn.Parameter]]) – Ignored parameters will not be managed by this FSDP instance, meaning that these parameters will not be flattened and sharded by FSDP, and their gradients will not be synchronized. With this newly added argument, ignored_modules could be deprecated soon. For backward compatibility, both ignored_parameters and ignored_modules are kept for now, but FSDP only allows one of them to be specified as not None.
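To make the recurse / nonwrapped_numel contract of a callable auto_wrap_policy concrete, here is a plain-Python simulation of a bottom-up wrapping traversal over a toy module tree. It assumes a simplified model: ask the policy with recurse=True whether to descend, handle children first, then ask with recurse=False about the numel still left unwrapped, and leave the root to the top-level FSDP instance. `ToyModule` and `toy_recursive_wrap` are hypothetical stand-ins, not FSDP internals.

```python
from dataclasses import dataclass, field
from functools import partial
from typing import List


@dataclass
class ToyModule:
    """Hypothetical stand-in for nn.Module: a name, its own numel, and children."""
    name: str
    own_numel: int
    children: List["ToyModule"] = field(default_factory=list)

    def total_numel(self) -> int:
        return self.own_numel + sum(c.total_numel() for c in self.children)


def size_policy(module, recurse, nonwrapped_numel, min_num_params=int(1e8)):
    # Same rule as custom_auto_wrap_policy above: act only on large subtrees.
    return nonwrapped_numel >= min_num_params


def toy_recursive_wrap(module, policy, wrapped, is_root=True):
    """Append names of modules that would get their own FSDP unit to `wrapped`.

    Returns how many parameters in `module`'s subtree ended up wrapped.
    """
    nonwrapped = module.total_numel()
    if not policy(module, recurse=True, nonwrapped_numel=nonwrapped):
        return 0  # subtree too small: do not descend into it
    wrapped_below = sum(
        toy_recursive_wrap(child, policy, wrapped, is_root=False)
        for child in module.children
    )
    remainder = nonwrapped - wrapped_below
    # The root is left to the top-level FSDP instance, which owns the remainder.
    if not is_root and policy(module, recurse=False, nonwrapped_numel=remainder):
        wrapped.append(module.name)
        return nonwrapped
    return wrapped_below


def build_model() -> ToyModule:
    block = ToyModule("block", 0, [ToyModule("lin1", 80), ToyModule("lin2", 80)])
    return ToyModule("root", 10, [block, ToyModule("head", 30)])


for threshold in (100, 50):
    wrapped: List[str] = []
    toy_recursive_wrap(build_model(), partial(size_policy, min_num_params=threshold), wrapped)
    print(threshold, wrapped)  # 100 -> ['block'], 50 -> ['lin1', 'lin2']
```

Lowering the threshold moves the FSDP unit boundary from the block down to its two linear layers, which is the kind of change that printing the model after wrapping is meant to reveal.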
- apply(fn)
Applies fn recursively to every submodule (as returned by .children()) as well as self. Typical use includes initializing the parameters of a model (see also torch.nn.init).
Compared to torch.nn.Module.apply, this version additionally gathers the full parameters before applying fn. It should not be called from within another summon_full_params context.
- Parameters:
fn (Module -> None) – function to be applied to each submodule
- Returns:
self
- Return type:
- Module
- clip_grad_norm_(max_norm, norm_type=2.0)
Clips the gradient norm of all parameters. The norm is computed over all parameters’ gradients as viewed as a single vector, and the gradients are modified in-place.
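- Parameters:
max_norm (float or int) – max norm of the gradients
norm_type (float or int) – type of the used p-norm. Can be 'inf' for infinity norm.
- Returns:
Total norm of the parameters (viewed as a single vector).
- Return type:
Tensor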
Note
If every FSDP instance uses NO_SHARD, meaning that no gradients are sharded across ranks, then you may directly use torch.nn.utils.clip_grad_norm_().
Note
If at least some FSDP instance uses a sharded strategy (i.e. one other than NO_SHARD), then you should use this method instead of torch.nn.utils.clip_grad_norm_() since this method handles the fact that gradients are sharded across ranks.
Note
The total norm returned will have the “largest” dtype across all parameters/gradients as defined by PyTorch’s type promotion semantics. For example, if all parameters/gradients use a low-precision dtype, then the returned norm’s dtype will be that low-precision dtype, but if there exists at least one parameter/gradient using FP32, then the returned norm’s dtype will be FP32.
Warning
This needs to be called on all ranks since it uses collective communications.
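Why this method (rather than torch.nn.utils.clip_grad_norm_()) is needed with sharded strategies, and why every rank must call it, can be seen in a plain-Python simulation: each rank holds only a shard of the gradients, so it computes a partial sum of |g|^p, the partial sums are combined across ranks (a simulated all-reduce: SUM, or MAX for the infinity norm), and only then is the global norm known. This sketch assumes lists in place of tensors; `sharded_clip_grad_norm` is a hypothetical helper, not FSDP's code.

```python
import math
from typing import List


def sharded_clip_grad_norm(rank_grads: List[List[float]], max_norm: float,
                           norm_type: float = 2.0) -> float:
    """Clip every rank's gradient shard in place using the *global* norm.

    rank_grads[r] is the flat gradient shard held by rank r (hypothetical layout).
    """
    if math.isinf(norm_type):
        # Infinity norm: each rank takes its local max |g|; the all-reduce uses MAX.
        local_maxes = [max((abs(g) for g in shard), default=0.0) for shard in rank_grads]
        total_norm = max(local_maxes, default=0.0)
    else:
        # Each rank computes sum(|g|^p) over its shard; the all-reduce uses SUM.
        partial_sums = [sum(abs(g) ** norm_type for g in shard) for shard in rank_grads]
        total_norm = sum(partial_sums) ** (1.0 / norm_type)
    clip_coef = min(max_norm / (total_norm + 1e-6), 1.0)
    for shard in rank_grads:
        for i in range(len(shard)):
            shard[i] *= clip_coef
    return total_norm


grads = [[3.0], [4.0]]  # two ranks, one gradient element each
print([math.sqrt(sum(g * g for g in shard)) for shard in grads])  # local norms 3.0 and 4.0
print(sharded_clip_grad_norm(grads, max_norm=1.0))  # but the global norm is 5.0
print(grads)  # roughly [[0.6], [0.8]] after clipping
```

Clipping each shard by its own local norm would scale the two ranks differently, which is the inconsistency the collective in this method avoids.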
- static flatten_sharded_optim_state_dict(sharded_optim_state_dict, model, optim)
The API is similar to shard_full_optim_state_dict(). The only difference is that the input sharded_optim_state_dict should be returned from sharded_optim_state_dict(). Therefore, there will be all-gather calls on each rank to gather ShardedTensors.
- Parameters:
sharded_optim_state_dict (Dict[str, Any]) – Optimizer state dict corresponding to the unflattened parameters and holding the sharded optimizer state.
model (torch.nn.Module) – Refer to shard_full_optim_state_dict().
optim (torch.optim.Optimizer) – Optimizer for model’s parameters.
- Returns:
Refer to shard_full_optim_state_dict().
- Return type:
Dict[str, Any]
- forward(*args, **kwargs)
Runs the forward pass for the wrapped module, inserting FSDP-specific pre- and post-forward sharding logic.
- Return type:
- static fsdp_modules(module, root_only=False)
Returns all nested FSDP instances, possibly including module itself and only including FSDP root modules if root_only=True.
- Parameters:
module (torch.nn.Module) – Root module, which may or may not be an FSDP module.
root_only (bool) – Whether to return only FSDP root modules. (Default: False)
- Returns:
FSDP modules that are nested in the input module.
- Return type:
List[FullyShardedDataParallel]
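What counts as a root for root_only=True can be illustrated on a toy tree. The sketch assumes that a root FSDP module is an FSDP instance with no FSDP ancestor inside the given module tree, and collects FSDP nodes in pre-order. `Node` and `toy_fsdp_modules` are hypothetical stand-ins; the real method walks nn.Module trees.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Node:
    """Hypothetical stand-in for an nn.Module; is_fsdp marks an FSDP wrapper."""
    name: str
    is_fsdp: bool = False
    children: List["Node"] = field(default_factory=list)


def toy_fsdp_modules(module: Node, root_only: bool = False) -> List[str]:
    """Names of FSDP nodes in pre-order; with root_only, skip nodes nested in another FSDP node."""
    found: List[str] = []

    def visit(node: Node, inside_fsdp: bool) -> None:
        if node.is_fsdp and not (root_only and inside_fsdp):
            found.append(node.name)
        for child in node.children:
            visit(child, inside_fsdp or node.is_fsdp)

    visit(module, inside_fsdp=False)
    return found


# A non-FSDP container holding two FSDP units, one of which has a nested FSDP child.
model = Node("wrapper", children=[
    Node("fsdp_a", True, [Node("fsdp_a.inner", True), Node("plain")]),
    Node("fsdp_b", True),
])
print(toy_fsdp_modules(model))                  # ['fsdp_a', 'fsdp_a.inner', 'fsdp_b']
print(toy_fsdp_modules(model, root_only=True))  # ['fsdp_a', 'fsdp_b']
```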
- static full_optim_state_dict(model, optim, optim_input=None, rank0_only=True, group=None)
Consolidates the full optimizer state on rank 0 and returns it as a dict following the convention of torch.optim.Optimizer.state_dict(), i.e. with keys "state" and "param_groups". The flattened parameters in FSDP modules contained in model are mapped back to their unflattened parameters.
Warning
This needs to be called on all ranks since it uses collective communications. However, if rank0_only=True, then the state dict is only populated on rank 0, and all other ranks return an empty dict.
Warning
Unlike torch.optim.Optimizer.state_dict(), this method uses full parameter names as keys instead of parameter IDs.
Note
Like in torch.optim.Optimizer.state_dict(), the tensors contained in the optimizer state dict are not cloned, so there may be aliasing surprises. For best practices, consider saving the returned optimizer state dict immediately, e.g. using torch.save().
- Parameters:
model (torch.nn.Module) – Root module (which may or may not be a FullyShardedDataParallel instance) whose parameters were passed into the optimizer optim.
optim (torch.optim.Optimizer) – Optimizer for model’s parameters.
optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]) – Input passed into the optimizer optim representing either a list of parameter groups or an iterable of parameters; if None, then this method assumes the input was model.parameters(). This argument is deprecated, and there is no need to pass it in anymore. (Default: None)
rank0_only (bool) – If True, saves the populated dict only on rank 0; if False, saves it on all ranks. (Default: True)
group (dist.ProcessGroup) – Model’s process group or None if using the default process group. (Default: None)
- Returns:
A dict containing the optimizer state for model’s original unflattened parameters and including keys “state” and “param_groups” following the convention of torch.optim.Optimizer.state_dict(). If rank0_only=True, then nonzero ranks return an empty dict.
- Return type:
Dict[str, Any]
- static load_optim_state_dict_pre_hook(model, optim, optim_state_dict, group=None)
This hook is intended to be used by torch.distributed.NamedOptimizer. The functionality is identical to optim_state_dict_to_load() except for the different arguments.
- Parameters:
model (torch.nn.Module) – Root module (which may or may not be a FullyShardedDataParallel instance) whose parameters were passed into the optimizer optim.
optim (torch.optim.Optimizer) – Optimizer for model’s parameters.
optim_state_dict (Dict[str, Any]) – The optimizer states to be loaded.
group (dist.ProcessGroup) – Model’s process group across which parameters are sharded or None if using the default process group. (Default: None)
- Return type:
- named_buffers(*args, **kwargs)
Overrides named_buffers() to intercept buffer names and remove all occurrences of the FSDP-specific flattened buffer prefix when inside the summon_full_params() context manager.
- named_parameters(*args, **kwargs)
Overrides named_parameters() to intercept parameter names and remove all occurrences of the FSDP-specific flattened parameter prefix when inside the summon_full_params() context manager.
- no_sync()
A context manager to disable gradient synchronizations across FSDP instances. Within this context, gradients will be accumulated in module variables, which will later be synchronized in the first forward-backward pass after exiting the context. This should only be used on the root FSDP instance and will recursively apply to all children FSDP instances.
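The trade-off behind no_sync() can be sketched with plain-Python numbers standing in for gradients. Assuming gradients are averaged across ranks, reducing after every micro-batch and then accumulating gives the same final gradient as accumulating locally and reducing once, because averaging is linear; the second schedule simply issues one reduction instead of one per micro-batch. The helper names are hypothetical, and this models only the arithmetic, not FSDP's memory layout.

```python
from typing import List, Tuple


def mean_across_ranks(values: List[float]) -> float:
    """Simulated gradient reduction with averaging across ranks."""
    return sum(values) / len(values)


def sync_every_step(grads: List[List[float]]) -> Tuple[float, int]:
    """grads[step][rank]: reduce after every micro-batch, then accumulate."""
    total, num_reductions = 0.0, 0
    for per_rank in grads:
        total += mean_across_ranks(per_rank)
        num_reductions += 1
    return total, num_reductions


def accumulate_then_sync(grads: List[List[float]]) -> Tuple[float, int]:
    """Accumulate locally (as inside no_sync()), then reduce once at the end."""
    num_ranks = len(grads[0])
    local = [sum(step[r] for step in grads) for r in range(num_ranks)]
    return mean_across_ranks(local), 1


# 4 micro-batches on 2 ranks (values chosen to be exact in binary floating point)
micro_grads = [[1.0, 3.0], [2.0, 2.0], [0.5, 1.5], [4.0, 0.0]]
print(sync_every_step(micro_grads))       # (7.0, 4)
print(accumulate_then_sync(micro_grads))  # (7.0, 1)
```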
Note
This likely results in higher memory usage because FSDP will accumulate the full model gradients (instead of gradient shards) until the eventual sync.
Note
When used with CPU offloading, the gradients will not be offloaded to CPU when inside the context manager. Instead, they will only be offloaded right after the eventual sync.
- Return type:
- static optim_state_dict(model, optim, group=None)
Returns the state dict of optim for the model that is (partially) sharded by FSDP. The state may be sharded, consolidated, or consolidated on rank 0 only depending on the state_dict_type set by set_state_dict_type() or state_dict_type().
Example:
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> from torch.distributed.fsdp import StateDictType
>>> from torch.distributed.fsdp import FullStateDictConfig
>>> from torch.distributed.fsdp import FullOptimStateDictConfig
>>> # `model, optim = ...`, save_a_checkpoint, and load_a_checkpoint stand in for user code
>>> # Save a checkpoint
>>> model, optim = ...
>>> FSDP.set_state_dict_type(
>>>     model,
>>>     StateDictType.FULL_STATE_DICT,
>>>     FullStateDictConfig(rank0_only=False),
>>>     FullOptimStateDictConfig(rank0_only=False),
>>> )
>>> state_dict = model.state_dict()
>>> optim_state_dict = FSDP.optim_state_dict(model, optim)
>>> save_a_checkpoint(state_dict, optim_state_dict)
>>> # Load a checkpoint
>>> model, optim = ...
>>> state_dict, optim_state_dict = load_a_checkpoint()
>>> FSDP.set_state_dict_type(
>>>     model,
>>>     StateDictType.FULL_STATE_DICT,
>>>     FullStateDictConfig(rank0_only=False),
>>>     FullOptimStateDictConfig(rank0_only=False),
>>> )
>>> model.load_state_dict(state_dict)
>>> optim_state_dict = FSDP.optim_state_dict_to_load(
>>>     optim_state_dict, model, optim
>>> )
>>> optim.load_state_dict(optim_state_dict)
- Parameters:
model (torch.nn.Module) – Root module (which may or may not be a FullyShardedDataParallel instance) whose parameters were passed into the optimizer optim.
optim (torch.optim.Optimizer) – Optimizer for model’s parameters.
group (dist.ProcessGroup) – Model’s process group across which parameters are sharded or None if using the default process group. (Default: None)
- Returns:
A dict containing the optimizer state for model. The sharding of the optimizer state is based on state_dict_type.
- Return type:
Dict[str, Any]
- static optim_state_dict_post_hook(model, optim, optim_state_dict, group=None)
This hook is intended to be used by torch.distributed.NamedOptimizer. The functionality is identical to optim_state_dict() except for the different arguments.
- Parameters:
model (torch.nn.Module) – Root module (which may or may not be a FullyShardedDataParallel instance) whose parameters were passed into the optimizer optim.
optim (torch.optim.Optimizer) – Optimizer for model’s parameters.
optim_state_dict (Dict[str, Any]) – The optim_state_dict to be converted. The value is typically returned by NamedOptimizer.state_dict().
group (dist.ProcessGroup) – Model’s process group across which parameters are sharded or None if using the default process group. (Default: None)
- Returns:
A dict containing the optimizer state for model. The sharding of the optimizer state is based on state_dict_type.
- Return type:
Dict[str, Any]
- static optim_state_dict_to_load(optim_state_dict, model, optim, is_named_optimizer=False, group=None)
Given a saved optim_state_dict, converts it to an optimizer state dict that can be loaded into optim, the optimizer for model, where model is (partially) sharded by FullyShardedDataParallel.
Example:
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> from torch.distributed.fsdp import StateDictType
>>> from torch.distributed.fsdp import FullStateDictConfig
>>> from torch.distributed.fsdp import FullOptimStateDictConfig
>>> # `model, optim = ...`, save_a_checkpoint, and load_a_checkpoint stand in for user code
>>> # Save a checkpoint
>>> model, optim = ...
>>> FSDP.set_state_dict_type(
>>>     model,
>>>     StateDictType.FULL_STATE_DICT,
>>>     FullStateDictConfig(rank0_only=False),
>>>     FullOptimStateDictConfig(rank0_only=False),
>>> )
>>> state_dict = model.state_dict()
>>> optim_state_dict = FSDP.optim_state_dict(model, optim)
>>> save_a_checkpoint(state_dict, optim_state_dict)
>>> # Load a checkpoint
>>> model, optim = ...
>>> state_dict, optim_state_dict = load_a_checkpoint()
>>> FSDP.set_state_dict_type(
>>>     model,
>>>     StateDictType.FULL_STATE_DICT,
>>>     FullStateDictConfig(rank0_only=False),
>>>     FullOptimStateDictConfig(rank0_only=False),
>>> )
>>> model.load_state_dict(state_dict)
>>> optim_state_dict = FSDP.optim_state_dict_to_load(
>>>     optim_state_dict, model, optim
>>> )
>>> optim.load_state_dict(optim_state_dict)
- Parameters:
optim_state_dict (Dict[str, Any]) – The optimizer states to be loaded.
model (torch.nn.Module) – Root module (which may or may not be a FullyShardedDataParallel instance) whose parameters were passed into the optimizer optim.
optim (torch.optim.Optimizer) – Optimizer for model’s parameters.
is_named_optimizer (bool) – Whether this optimizer is a NamedOptimizer or KeyedOptimizer. Only set to True if optim is TorchRec’s KeyedOptimizer or torch.distributed’s NamedOptimizer. (Default: False)
group (dist.ProcessGroup) – Model’s process group across which parameters are sharded or None if using the default process group. (Default: None)
- Return type:
- register_comm_hook(state, hook)
Registers a communication hook, which lets users specify how FSDP aggregates gradients across multiple workers. This hook can be used to implement several algorithms like GossipGrad and gradient compression, which involve different communication strategies for parameter syncs while training with FullyShardedDataParallel.
Warning
FSDP communication hook should be registered before running an initial forward pass and only once.
- Parameters:
state (object) –
Passed to the hook to maintain any state information during the training process. Examples include error feedback in gradient compression, peers to communicate with next in GossipGrad, etc. It is locally stored by each worker and shared by all the gradient tensors on the worker.
hook (Callable) – Callable, which has one of the following signatures:
1) hook: Callable[torch.Tensor] -> None: This function takes in a Python tensor, which represents the full, flattened, unsharded gradient with respect to all variables corresponding to the model this FSDP unit is wrapping (that are not wrapped by other FSDP sub-units). It then performs all necessary processing and returns None.
2) hook: Callable[torch.Tensor, torch.Tensor] -> None: This function takes in two Python tensors. The first represents the full, flattened, unsharded gradient with respect to all variables corresponding to the model this FSDP unit is wrapping (that are not wrapped by other FSDP sub-units). The second represents a pre-sized tensor to store a chunk of a sharded gradient after reduction.
In both cases, the callable performs all necessary processing and returns None. Callables with signature 1 are expected to handle gradient communication for the NO_SHARD case. Callables with signature 2 are expected to handle gradient communication for sharded cases.
- static rekey_optim_state_dict(optim_state_dict, optim_state_key_type, model, optim_input=None, optim=None)
Re-keys the optimizer state dict optim_state_dict to use the key type optim_state_key_type. This can be used to achieve compatibility between optimizer state dicts from models with FSDP instances and ones without.
To re-key an FSDP full optimizer state dict (i.e. from full_optim_state_dict()) to use parameter IDs and be loadable to a non-wrapped model:
>>> wrapped_model, wrapped_optim = ...
>>> full_osd = FSDP.full_optim_state_dict(wrapped_model, wrapped_optim)
>>> nonwrapped_model, nonwrapped_optim = ...
>>> rekeyed_osd = FSDP.rekey_optim_state_dict(full_osd, OptimStateKeyType.PARAM_ID, nonwrapped_model)
>>> nonwrapped_optim.load_state_dict(rekeyed_osd)
To re-key a normal optimizer state dict from a non-wrapped model to be loadable to a wrapped model:
>>> nonwrapped_model, nonwrapped_optim = ...
>>> osd = nonwrapped_optim.state_dict()
>>> rekeyed_osd = FSDP.rekey_optim_state_dict(osd, OptimStateKeyType.PARAM_NAME, nonwrapped_model)
>>> wrapped_model, wrapped_optim = ...
>>> sharded_osd = FSDP.shard_full_optim_state_dict(rekeyed_osd, wrapped_model)
>>> wrapped_optim.load_state_dict(sharded_osd)
- Returns:
The optimizer state dict re-keyed using the parameter keys specified by optim_state_key_type.
- Return type:
Dict[str, Any]
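The difference between the two key types can be seen on a plain torch.optim-style state dict, where "state" and each group's "params" list refer to parameters by integer ID, whereas full_optim_state_dict() keys by parameter name. The sketch below re-keys such a dict to names and back, assuming param_names[i] is the name of the parameter with ID i. `rekey_to_names` and `rekey_to_ids` are hypothetical helpers, not FSDP's implementation, and lists stand in for state tensors.

```python
import copy
from typing import Any, Dict, List


def rekey_to_names(osd: Dict[str, Any], param_names: List[str]) -> Dict[str, Any]:
    """Replace integer parameter IDs with names, assuming param_names[i] has ID i."""
    out = copy.deepcopy(osd)
    out["state"] = {param_names[pid]: s for pid, s in out["state"].items()}
    for group in out["param_groups"]:
        group["params"] = [param_names[pid] for pid in group["params"]]
    return out


def rekey_to_ids(osd: Dict[str, Any], param_names: List[str]) -> Dict[str, Any]:
    """Inverse mapping: names back to the integer IDs a plain optimizer expects."""
    name_to_id = {name: i for i, name in enumerate(param_names)}
    out = copy.deepcopy(osd)
    out["state"] = {name_to_id[name]: s for name, s in out["state"].items()}
    for group in out["param_groups"]:
        group["params"] = [name_to_id[name] for name in group["params"]]
    return out


# Shaped like torch.optim.Optimizer.state_dict(); the values are illustrative.
id_keyed = {
    "state": {0: {"step": 3, "exp_avg": [0.1, 0.2]}, 1: {"step": 3, "exp_avg": [0.5]}},
    "param_groups": [{"lr": 1e-3, "params": [0, 1]}],
}
names = ["encoder.weight", "encoder.bias"]
name_keyed = rekey_to_names(id_keyed, names)
print(name_keyed["param_groups"][0]["params"])  # ['encoder.weight', 'encoder.bias']
assert rekey_to_ids(name_keyed, names) == id_keyed  # round trip restores the ID keys
```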
- static scatter_full_optim_state_dict(full_optim_state_dict, model, optim_input=None, optim=None, group=None)
Scatters the full optimizer state dict from rank 0 to all other ranks, returning the sharded optimizer state dict on each rank. The return value is the same as shard_full_optim_state_dict(), and on rank 0, the first argument should be the return value of full_optim_state_dict().
Example:
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> model, optim = ...
>>> full_osd = FSDP.full_optim_state_dict(model, optim)  # only non-empty on rank 0
>>> # Define new model with possibly different world size
>>> new_model, new_optim, new_group = ...
>>> sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, new_model, group=new_group)
>>> new_optim.load_state_dict(sharded_osd)
Note
Both shard_full_optim_state_dict() and scatter_full_optim_state_dict() may be used to get the sharded optimizer state dict to load. Assuming that the full optimizer state dict resides in CPU memory, the former requires each rank to have the full dict in CPU memory, where each rank individually shards the dict without any communication, while the latter requires only rank 0 to have the full dict in CPU memory, where rank 0 moves each shard to GPU memory (for NCCL) and communicates it to ranks appropriately. Hence, the former has higher aggregate CPU memory cost, while the latter has higher communication cost.
- Parameters:
full_optim_state_dict (Optional[Dict[str, Any]]) – Optimizer state dict corresponding to the unflattened parameters and holding the full non-sharded optimizer state if on rank 0; the argument is ignored on nonzero ranks.
model (torch.nn.Module) – Root module (which may or may not be a FullyShardedDataParallel instance) whose parameters correspond to the optimizer state in full_optim_state_dict.
optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]) – Input passed into the optimizer representing either a list of parameter groups or an iterable of parameters; if None, then this method assumes the input was model.parameters(). This argument is deprecated, and there is no need to pass it in anymore. (Default: None)
optim (Optional[torch.optim.Optimizer]) – Optimizer that will load the state dict returned by this method. This is the preferred argument to use over optim_input. (Default: None)
group (dist.ProcessGroup) – Model’s process group or None if using the default process group. (Default: None)
- Returns:
The full optimizer state dict now remapped to flattened parameters instead of unflattened parameters and restricted to only include this rank’s part of the optimizer state.
- Return type:
Dict[str, Any]
- static set_state_dict_type(module, state_dict_type, state_dict_config=None, optim_state_dict_config=None)
Sets the state_dict_type and the corresponding (optional) configurations of all the descendant FSDP modules of the target module. The target module does not have to be an FSDP module. If the target module is an FSDP module, its state_dict_type will also be changed.
Note
This API should be called for only the top-level (root) module.
Note
This API enables users to transparently use the conventional state_dict API to take model checkpoints in cases where the root FSDP module is wrapped by another nn.Module. For example, the following will ensure state_dict is called on all non-FSDP instances, while dispatching into the sharded_state_dict implementation for FSDP:
Example:
>>> model = DDP(FSDP(...))
>>> FSDP.set_state_dict_type(
>>>     model,
>>>     StateDictType.SHARDED_STATE_DICT,
>>>     state_dict_config=ShardedStateDictConfig(offload_to_cpu=True),
>>>     optim_state_dict_config=OptimStateDictConfig(offload_to_cpu=True),
>>> )
>>> param_state_dict = model.state_dict()
>>> optim_state_dict = FSDP.optim_state_dict(model, optim)
- Parameters:
module (torch.nn.Module) – Root module.
state_dict_type (StateDictType) – the desired state_dict_type to set.
state_dict_config (Optional[StateDictConfig]) – the configuration for the target state_dict_type.
- Returns:
A StateDictSettings that includes the previous state_dict type and configuration for the module.
- Return type:
StateDictSettings
- static shard_full_optim_state_dict(full_optim_state_dict, model, optim_input=None, optim=None)
Shards the full optimizer state dict full_optim_state_dict by remapping the state to flattened parameters instead of unflattened parameters and restricting to only this rank’s part of the optimizer state. The first argument should be the return value of full_optim_state_dict().
Example:
>>> import torch
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> # PATH is a placeholder for a checkpoint file path
>>> model, optim = ...
>>> full_osd = FSDP.full_optim_state_dict(model, optim)
>>> torch.save(full_osd, PATH)
>>> # Define new model with possibly different world size
>>> new_model, new_optim = ...
>>> full_osd = torch.load(PATH)
>>> sharded_osd = FSDP.shard_full_optim_state_dict(full_osd, new_model)
>>> new_optim.load_state_dict(sharded_osd)
Note
Both shard_full_optim_state_dict() and scatter_full_optim_state_dict() may be used to get the sharded optimizer state dict to load. Assuming that the full optimizer state dict resides in CPU memory, the former requires each rank to have the full dict in CPU memory, where each rank individually shards the dict without any communication, while the latter requires only rank 0 to have the full dict in CPU memory, where rank 0 moves each shard to GPU memory (for NCCL) and communicates it to ranks appropriately. Hence, the former has higher aggregate CPU memory cost, while the latter has higher communication cost.
- Parameters:
full_optim_state_dict (Dict[str, Any]) – Optimizer state dict corresponding to the unflattened parameters and holding the full non-sharded optimizer state.
model (torch.nn.Module) – Root module (which may or may not be a FullyShardedDataParallel instance) whose parameters correspond to the optimizer state in full_optim_state_dict.
optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]) – Input passed into the optimizer representing either a list of parameter groups or an iterable of parameters; if None, then this method assumes the input was model.parameters(). This argument is deprecated, and there is no need to pass it in anymore. (Default: None)
optim (Optional[torch.optim.Optimizer]) – Optimizer that will load the state dict returned by this method. This is the preferred argument to use over optim_input. (Default: None)
- Returns:
The full optimizer state dict now remapped to flattened parameters instead of unflattened parameters and restricted to only include this rank’s part of the optimizer state.
- Return type:
Dict[str, Any]
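To make the memory/communication trade-off from the note on shard_full_optim_state_dict() and scatter_full_optim_state_dict() concrete, the following pure-Python cost model compares the two approaches. It is a back-of-the-envelope sketch, not part of FSDP: LoadCost, shard_full_cost, and scatter_full_cost are hypothetical names, sizes are in arbitrary units (for example GiB), and the state is assumed to split evenly across ranks with no padding.

```python
from dataclasses import dataclass


@dataclass
class LoadCost:
    aggregate_cpu_units: int  # CPU memory for the full dict, summed over all ranks
    units_sent_by_rank0: int  # data rank 0 must send to the other ranks


def shard_full_cost(full_size: int, world_size: int) -> LoadCost:
    # Every rank holds the full dict in CPU memory and shards it locally,
    # so there is no communication, but world_size copies exist.
    return LoadCost(aggregate_cpu_units=full_size * world_size, units_sent_by_rank0=0)


def scatter_full_cost(full_size: int, world_size: int) -> LoadCost:
    # Only rank 0 holds the full dict; it sends every other rank that rank's
    # shard. Assumes the state divides evenly across ranks (no padding).
    shard = full_size // world_size
    return LoadCost(aggregate_cpu_units=full_size, units_sent_by_rank0=shard * (world_size - 1))


print(shard_full_cost(full_size=8, world_size=8))
# LoadCost(aggregate_cpu_units=64, units_sent_by_rank0=0)
print(scatter_full_cost(full_size=8, world_size=8))
# LoadCost(aggregate_cpu_units=8, units_sent_by_rank0=7)
```

With 8 ranks, the shard-locally approach needs eight times the CPU memory of the scatter approach, while the scatter approach moves seven eighths of the state out of rank 0.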
- static sharded_optim_state_dict(model, optim, group=None)[source]¶
The API is similar to full_optim_state_dict(), but this API chunks all non-zero-dimension states into ShardedTensor objects to save memory. This API should only be used when the model state_dict is obtained inside the context manager with state_dict_type(SHARDED_STATE_DICT):.
For the detailed usage, refer to full_optim_state_dict().
Warning
The returned state dict contains ShardedTensor objects and cannot be used directly by the regular optim.load_state_dict.
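The chunking performed by sharded_optim_state_dict() can be pictured with the toy model below. It is a pure-Python illustration under stated assumptions, not FSDP's implementation: lists stand in for tensors with at least one dimension, non-list values stand in for zero-dimension tensors (such as Adam's step counter), and toy_chunk_optim_state is a hypothetical helper. It does not model ShardedTensor metadata or FSDP's flattened-parameter layout.

```python
from typing import Any, Dict


def toy_chunk_optim_state(state: Dict[str, Any], rank: int, world_size: int) -> Dict[str, Any]:
    """Toy model: keep zero-dimension states whole, chunk 1-D states by rank.

    Python lists stand in for tensors with at least one dimension; anything else
    stands in for a zero-dimension tensor such as Adam's "step" counter.
    """
    local: Dict[str, Any] = {}
    for key, value in state.items():
        if isinstance(value, list):
            chunk = -(-len(value) // world_size)  # ceiling division, like torch.chunk sizes
            local[key] = value[rank * chunk:(rank + 1) * chunk]
        else:
            local[key] = value  # zero-dimension state stays whole on every rank
    return local


adam_state = {"step": 10, "exp_avg": [0.1, 0.2, 0.3, 0.4, 0.5], "exp_avg_sq": [1, 2, 3, 4, 5]}
print(toy_chunk_optim_state(adam_state, rank=0, world_size=2))
# {'step': 10, 'exp_avg': [0.1, 0.2, 0.3], 'exp_avg_sq': [1, 2, 3]}
print(toy_chunk_optim_state(adam_state, rank=1, world_size=2))
# {'step': 10, 'exp_avg': [0.4, 0.5], 'exp_avg_sq': [4, 5]}
```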
- static state_dict_type(module, state_dict_type, state_dict_config=None, optim_state_dict_config=None)[source]¶
A context manager to set the state_dict_type of all the descendant FSDP modules of the target module. This context manager provides the same functionality as set_state_dict_type(). Read the documentation of set_state_dict_type() for details.
Example:
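>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType
>>> from torch.nn.parallel import DistributedDataParallel as DDP
>>> model = DDP(FSDP(...))
>>> with FSDP.state_dict_type(
...     model,
...     StateDictType.SHARDED_STATE_DICT,
... ):
...     checkpoint = model.state_dict()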
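The following pure-Python sketch models what this context manager does to a module tree: it sets the type on every FSDP descendant and, as assumed here, restores the previous setting when the context exits. ToyModule and toy_state_dict_type are hypothetical stand-ins that run without PyTorch; with real modules, use FSDP.state_dict_type() as in the example above.

```python
from contextlib import contextmanager
from typing import Iterator, List, Sequence


class ToyModule:
    """Hypothetical stand-in for an nn.Module tree; is_fsdp marks FSDP instances."""

    def __init__(self, is_fsdp: bool, children: Sequence["ToyModule"] = ()):
        self.is_fsdp = is_fsdp
        self.children: List["ToyModule"] = list(children)
        self.state_dict_type = "FULL_STATE_DICT"  # toy default setting

    def modules(self) -> Iterator["ToyModule"]:
        yield self
        for child in self.children:
            yield from child.modules()


@contextmanager
def toy_state_dict_type(root: ToyModule, new_type: str) -> Iterator[None]:
    # Set the type on every FSDP module in the tree (including root if it is FSDP).
    fsdp_modules = [m for m in root.modules() if m.is_fsdp]
    previous = [m.state_dict_type for m in fsdp_modules]
    for m in fsdp_modules:
        m.state_dict_type = new_type
    try:
        yield
    finally:
        # Restore the previous settings when the context exits.
        for m, prev in zip(fsdp_modules, previous):
            m.state_dict_type = prev


inner = ToyModule(is_fsdp=True)
root = ToyModule(is_fsdp=False, children=[inner])  # shaped like DDP(FSDP(...))
with toy_state_dict_type(root, "SHARDED_STATE_DICT"):
    print(inner.state_dict_type)  # SHARDED_STATE_DICT
print(inner.state_dict_type)  # FULL_STATE_DICT
```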
- Parameters:
module (torch.nn.Module) – Root module.
state_dict_type (StateDictType) – the desired state_dict_type to set.
state_dict_config (Optional[StateDictConfig]) – the configuration for the target state_dict_type.
- Return type:
- static summon_full_params(module, recurse=True, writeback=True, rank0_only=False, offload_to_cpu=False, with_grads=False)[source]¶
A context manager to expose full params for FSDP instances. This can be useful after forward/backward to obtain the params for additional processing or checking. It can take a non-FSDP module and will summon full params for all contained FSDP modules as well as their children, depending on the recurse argument.
Note
This can be used on inner FSDP instances.
Note
This cannot be used within a forward or backward pass, nor can forward and backward be started from within this context.
Note
Parameters will revert to their local shards after the context manager exits; storage behavior is the same as in forward.
Note
The full parameters can be modified, but only the portion corresponding to the local param shard will persist after the context manager exits (unless writeback=False, in which case changes will be discarded). In the case where FSDP does not shard the parameters, currently only when world_size == 1 or with the NO_SHARD config, the modification is persisted regardless of writeback.
Note
This method works on modules which are not FSDP themselves but may contain multiple independent FSDP units. In that case, the given arguments will apply to all contained FSDP units.
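The writeback semantics described in these notes can be modeled for a single parameter split into per-rank shards. This is a hypothetical pure-Python sketch (toy_summon_full_params is not an FSDP function): lists stand in for the flattened parameter, the all-gather is a list concatenation, and the rank0_only=True with writeback=True restriction from the warnings that follow is modeled as a ValueError whose type is illustrative only. Non-zero ranks seeing only their shards under rank0_only=True is not modeled.

```python
from typing import Callable, List


def toy_summon_full_params(
    shards: List[List[float]],
    rank: int,
    modify: Callable[[List[float]], None],
    writeback: bool = True,
    rank0_only: bool = False,
) -> List[float]:
    """Toy model of one rank: returns its local shard after the context exits."""
    if rank0_only and writeback:
        # Mirrors the documented restriction (the exception type here is illustrative).
        raise ValueError("rank0_only=True with writeback=True is not supported")
    full = [x for shard in shards for x in shard]  # stands in for the all-gather
    modify(full)  # user code inside the context sees the full parameter
    if len(shards) == 1:
        return full  # unsharded case: changes persist regardless of writeback
    if not writeback:
        return list(shards[rank])  # changes are discarded
    start = sum(len(s) for s in shards[:rank])
    return full[start:start + len(shards[rank])]  # only the local slice persists


def double(param: List[float]) -> None:
    param[:] = [2 * x for x in param]


print(toy_summon_full_params([[1.0, 2.0], [3.0, 4.0]], rank=1, modify=double))  # [6.0, 8.0]
print(toy_summon_full_params([[1.0, 2.0], [3.0, 4.0]], rank=1, modify=double, writeback=False))  # [3.0, 4.0]
print(toy_summon_full_params([[1.0, 2.0]], rank=0, modify=double, writeback=False))  # [2.0, 4.0]
```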
Warning
Note that rank0_only=True in conjunction with writeback=True is not currently supported and will raise an error. This is because model parameter shapes would be different across ranks within the context, and writing to them can lead to inconsistency across ranks when the context is exited.
Warning
Note that offload_to_cpu=True combined with rank0_only=False will result in full parameters being redundantly copied to CPU memory for GPUs that reside on the same machine, which risks CPU OOM. It is recommended to use offload_to_cpu with rank0_only=True.
- Parameters:
recurse (bool, Optional) – recursively summon all params for nested FSDP instances (default: True).
writeback (bool, Optional) – if False, modifications to params are discarded after the context manager exits; disabling this can be slightly more efficient (default: True).
rank0_only (bool, Optional) – if True, full parameters are materialized on only global rank 0. This means that within the context, only rank 0 will have full parameters and the other ranks will have sharded parameters. Note that setting rank0_only=True with writeback=True is not supported, as model parameter shapes will be different across ranks within the context, and writing to them can lead to inconsistency across ranks when the context is exited.
offload_to_cpu (bool, Optional) – If True, full parameters are offloaded to CPU. Note that this offloading currently only occurs if the parameter is sharded (that is, not when world_size = 1 or with the NO_SHARD config). It is recommended to use offload_to_cpu with rank0_only=True to avoid redundant copies of model parameters being offloaded to the same CPU memory.
with_grads (bool, Optional) – If True, gradients are also unsharded with the parameters. Currently, this is only supported when passing use_orig_params=True to the FSDP constructor and offload_to_cpu=False to this method. (Default: False)
- Return type:
- class torch.distributed.fsdp.BackwardPrefetch(value)[source]¶
This configures explicit backward prefetching, which can improve throughput but may slightly increase peak memory usage.
For the NCCL backend, any collectives, even if issued in different streams, contend for the same per-device NCCL stream, which is why the relative order in which the collectives are issued matters for overlapping. The different backward prefetching settings correspond to different orderings.
BACKWARD_PRE: This prefetches the next set of parameters before the current set of parameters’ gradient computation. This improves backward pass throughput by overlapping communication (next all-gather) and computation (current gradient computation).
BACKWARD_POST: This prefetches the next set of parameters after the current set of parameters’ gradient computation. This may improve backward pass throughput by overlapping communication (current reduce-scatter) and computation (next gradient computation). Specifically, the next all-gather is reordered to be before the current reduce-scatter.
Note
If the increase in peak memory usage from prefetching is an issue, you may consider passing limit_all_gathers=True to the FSDP constructor, which may help reduce peak memory usage in some cases.
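A simplified model of the backward issue order for these settings is sketched below, using one all-gather (AG), one gradient computation (grad), and one reduce-scatter (RS) per FSDP unit, with units visited in reverse order during backward. backward_issue_order is a hypothetical helper that only lists the issue order on one rank; it ignores streams, rate limiting, and how FSDP actually determines which unit comes next.

```python
from typing import List, Optional


def backward_issue_order(num_units: int, prefetch: Optional[str]) -> List[str]:
    """Toy issue order of backward work; prefetch is None, "BACKWARD_PRE", or "BACKWARD_POST"."""
    order: List[str] = []
    gathered = set()
    for i in reversed(range(num_units)):  # backward visits units in reverse order
        if i not in gathered:
            order.append(f"AG{i}")  # unshard the current unit if it was not prefetched
            gathered.add(i)
        nxt = i - 1  # next unit in backward order
        if prefetch == "BACKWARD_PRE" and nxt >= 0:
            order.append(f"AG{nxt}")  # prefetch before the current gradient computation
            gathered.add(nxt)
        order.append(f"grad{i}")
        if prefetch == "BACKWARD_POST" and nxt >= 0:
            order.append(f"AG{nxt}")  # next all-gather issued before the current reduce-scatter
            gathered.add(nxt)
        order.append(f"RS{i}")
    return order


print(backward_issue_order(2, None))             # ['AG1', 'grad1', 'RS1', 'AG0', 'grad0', 'RS0']
print(backward_issue_order(2, "BACKWARD_PRE"))   # ['AG1', 'AG0', 'grad1', 'RS1', 'grad0', 'RS0']
print(backward_issue_order(2, "BACKWARD_POST"))  # ['AG1', 'grad1', 'AG0', 'RS1', 'grad0', 'RS0']
```

In the BACKWARD_PRE order, AG0 is in flight while grad1 runs; in the BACKWARD_POST order, AG0 is issued ahead of RS1, matching the reordering described above.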
- class torch.distributed.fsdp.ShardingStrategy(value)[source]¶
This specifies the sharding strategy to be used for distributed training by FullyShardedDataParallel.
FULL_SHARD: Parameters, gradients, and optimizer states are sharded. For the parameters, this strategy unshards (via all-gather) before the forward, reshards after the forward, unshards before the backward computation, and reshards after the backward computation. For gradients, it synchronizes and shards them (via reduce-scatter) after the backward computation. The sharded optimizer states are updated locally per rank.
SHARD_GRAD_OP: Gradients and optimizer states are sharded during computation, and additionally, parameters are sharded outside computation. For the parameters, this strategy unshards before the forward, does not reshard them after the forward, and only reshards them after the backward computation. The sharded optimizer states are updated locally per rank. Inside no_sync(), the parameters are not resharded after the backward computation.
NO_SHARD: Parameters, gradients, and optimizer states are not sharded but instead replicated across ranks, similar to PyTorch’s DistributedDataParallel API. For gradients, this strategy synchronizes them (via all-reduce) after the backward computation. The unsharded optimizer states are updated locally per rank.
HYBRID_SHARD: Apply FULL_SHARD within a node, and replicate parameters across nodes. This results in reduced communication volume, as expensive all-gathers and reduce-scatters are only done within a node, which can be more performant for medium-sized models.
_HYBRID_SHARD_ZERO2: Apply SHARD_GRAD_OP within a node, and replicate parameters across nodes. This is like HYBRID_SHARD, except this may provide even higher throughput since the unsharded parameters are not freed after the forward pass, saving the all-gathers in the pre-backward.
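For HYBRID_SHARD and _HYBRID_SHARD_ZERO2, the following sketch shows one way to lay out the two kinds of rank groups implied by sharding within a node and replicating across nodes. It assumes ranks are numbered node by node with the same number of GPUs per node; hybrid_shard_groups is a hypothetical helper, not FSDP's group-construction code.

```python
from typing import List, Tuple


def hybrid_shard_groups(num_nodes: int, gpus_per_node: int) -> Tuple[List[List[int]], List[List[int]]]:
    """Toy layout: returns (shard groups, replicate groups) of global ranks."""
    # Ranks on the same node form one shard group, so all-gathers and
    # reduce-scatters stay within a node.
    shard_groups = [
        [node * gpus_per_node + local for local in range(gpus_per_node)]
        for node in range(num_nodes)
    ]
    # Ranks with the same local index on different nodes hold the same shard (replicas).
    replicate_groups = [
        [node * gpus_per_node + local for node in range(num_nodes)]
        for local in range(gpus_per_node)
    ]
    return shard_groups, replicate_groups


shard, replicate = hybrid_shard_groups(num_nodes=2, gpus_per_node=4)
print(shard)      # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(replicate)  # [[0, 4], [1, 5], [2, 6], [3, 7]]
```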
- class torch.distributed.fsdp.MixedPrecision(param_dtype=None, reduce_dtype=None, buffer_dtype=None, keep_low_precision_grads=False, cast_forward_inputs=False, cast_root_forward_inputs=True)[source]¶
This configures FSDP-native mixed precision training.
- Variables:
param_dtype (torch.dtype) – This specifies the dtype for model parameters, inputs (when cast_forward_inputs or cast_root_forward_inputs is set to True), and therefore the dtype for computation. However, outside the forward and backward passes, parameters are in full precision. Model checkpointing always happens in full precision.
reduce_dtype (torch.dtype) – This specifies the dtype for gradient reduction, which is permitted to differ from param_dtype.
buffer_dtype (torch.dtype) – This specifies the dtype for buffers. FSDP does not shard buffers, casts them to buffer_dtype in the first forward pass, and keeps them in that dtype thereafter. Model checkpointing always happens in full precision.
keep_low_precision_grads (bool) – This specifies whether to keep gradients in low precision after the backward pass instead of upcasting them back to the full parameter precision. This may be set to True to save memory if using custom optimizers that can perform the optimizer step in reduce_dtype. (Default: False)
cast_forward_inputs (bool) – Cast floating point tensors in the forward arguments and keyword arguments to param_dtype. (Default: False)
cast_root_forward_inputs (bool) – Cast floating point tensors in the forward arguments and keyword arguments to param_dtype for the root FSDP instance. It takes precedence over cast_forward_inputs for the root FSDP instance. (Default: True)
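The defaulting and precedence rules for these fields (including the reduce_dtype fallback stated in a later note) can be summarized in a few lines of pure Python. ToyMixedPrecision, effective_reduce_dtype, and casts_inputs are hypothetical names, dtypes are plain strings so the sketch runs without PyTorch, and it assumes that no input cast happens when param_dtype is None.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyMixedPrecision:
    """Pure-Python mirror of the documented MixedPrecision fields (dtypes as strings)."""
    param_dtype: Optional[str] = None
    reduce_dtype: Optional[str] = None
    buffer_dtype: Optional[str] = None
    keep_low_precision_grads: bool = False
    cast_forward_inputs: bool = False
    cast_root_forward_inputs: bool = True


def effective_reduce_dtype(mp: ToyMixedPrecision, orig_dtype: str) -> str:
    # reduce_dtype if given, else param_dtype if given, else the original dtype.
    return mp.reduce_dtype or mp.param_dtype or orig_dtype


def casts_inputs(mp: ToyMixedPrecision, is_root: bool) -> bool:
    # The root uses cast_root_forward_inputs (it takes precedence there);
    # non-root instances ignore it and use cast_forward_inputs.
    flag = mp.cast_root_forward_inputs if is_root else mp.cast_forward_inputs
    return flag and mp.param_dtype is not None  # assumption: nothing to cast to otherwise


mp = ToyMixedPrecision(param_dtype="bfloat16")
print(effective_reduce_dtype(mp, "float32"))  # bfloat16
print(casts_inputs(mp, is_root=True), casts_inputs(mp, is_root=False))  # True False
```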
Note
This API is experimental and subject to change.
Note
Only floating point tensors are cast to their specified dtypes.
Note
In summon_full_params, parameters are forced to full precision, but buffers are not.
Note
state_dict checkpoints parameters and buffers in full precision. For buffers, this is only supported for StateDictType.FULL_STATE_DICT.
Note
Each low precision dtype must be specified explicitly. For example, MixedPrecision(reduce_dtype=torch.float16) only specifies the reduction dtype to be low precision, and FSDP will not cast parameters or buffers.
Note
If a reduce_dtype is not specified, then gradient reduction happens in param_dtype if specified or the original parameter dtype otherwise.
Note
If the user passes a model with BatchNorm modules and an auto_wrap_policy to the FSDP constructor, then FSDP will disable mixed precision for BatchNorm modules by wrapping them separately in their own FSDP instance with mixed precision disabled. This is due to some missing low precision BatchNorm kernels. If the user does not use an auto_wrap_policy, then the user must take care not to use mixed precision for FSDP instances containing BatchNorm modules.
Note
MixedPrecision has cast_root_forward_inputs=True and cast_forward_inputs=False by default. For the root FSDP instance, its cast_root_forward_inputs takes precedence over its cast_forward_inputs. For non-root FSDP instances, their cast_root_forward_inputs values are ignored. The default setting is sufficient for the typical case where each FSDP instance has the same MixedPrecision configuration and only needs to cast inputs to the param_dtype at the beginning of the model’s forward pass.
Note
For nested FSDP instances with different MixedPrecision configurations, we recommend setting individual cast_forward_inputs values to configure whether inputs are cast before each instance’s forward. In such a case, since the casts happen before each FSDP instance’s forward, a parent FSDP instance should run its non-FSDP submodules before its FSDP submodules to avoid the activation dtype being changed due to a different MixedPrecision configuration.
Example:
>>> import torch
>>> import torch.nn as nn
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision
>>> model = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3))
>>> model[1] = FSDP(
...     model[1],
...     mixed_precision=MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=True),
... )
>>> model = FSDP(
...     model,
...     mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, cast_forward_inputs=True),
... )
The above shows a working example. On the other hand, if model[1] were replaced with model[0], meaning that the submodule using a different MixedPrecision ran its forward first, then model[1] would incorrectly see float16 activations instead of bfloat16 ones.
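The ordering problem in the nested example can be traced with a toy model that records the activation dtype each submodule of the root receives. activation_dtypes is a hypothetical helper; it simplifies by assuming each submodule's output stays in the dtype it computed in, and it treats a nested FSDP instance with cast_forward_inputs=True as casting its inputs to its own param_dtype.

```python
from typing import List, Optional, Tuple

# Each entry is (name, own_param_dtype): None for a plain submodule that computes
# in the parent FSDP instance's dtype, or a dtype string for a nested FSDP
# instance with cast_forward_inputs=True.
Layer = Tuple[str, Optional[str]]


def activation_dtypes(layers: List[Layer], root_param_dtype: str) -> List[Tuple[str, str]]:
    """Toy trace of the activation dtype each submodule receives, in forward order."""
    dtype = root_param_dtype  # the root instance casts the model inputs first
    seen: List[Tuple[str, str]] = []
    for name, own_dtype in layers:
        if own_dtype is not None:
            dtype = own_dtype  # nested FSDP instance casts its inputs
        seen.append((name, dtype))
        # Simplification: the output keeps the dtype the submodule computed in.
    return seen


# Working order: the plain submodule runs before the nested FSDP instance.
print(activation_dtypes([("model[0]", None), ("model[1]", "float16")], "bfloat16"))
# [('model[0]', 'bfloat16'), ('model[1]', 'float16')]
# Problematic order: the nested FSDP instance runs first.
print(activation_dtypes([("model[0]", "float16"), ("model[1]", None)], "bfloat16"))
# [('model[0]', 'float16'), ('model[1]', 'float16')]  <- model[1] expected bfloat16
```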
- class torch.distributed.fsdp.CPUOffload(offload_params=False)[source]¶
This configures CPU offloading.
- Variables:
offload_params (bool) – This specifies whether to offload parameters to CPU when not involved in computation. If enabled, this implicitly offloads gradients to CPU as well. This is to support the optimizer step, which requires parameters and gradients to be on the same device.