DidiZhu committed
Commit 20ccb6c · verified · 1 Parent(s): 51e3e10

Delete files _fully_shard.py with huggingface_hub

Files changed (1)
  1. _fully_shard.py +0 -672
_fully_shard.py DELETED
@@ -1,672 +0,0 @@
1
- # mypy: allow-untyped-decorators
2
- # mypy: allow-untyped-defs
3
-
4
- from __future__ import annotations
5
-
6
- import functools
7
- from typing import (
8
- Any,
9
- Callable,
10
- cast,
11
- NoReturn,
12
- Optional,
13
- overload,
14
- TYPE_CHECKING,
15
- Union,
16
- )
17
- from typing_extensions import deprecated
18
-
19
- import torch
20
- import torch.nn as nn
21
- from torch.distributed._composable import contract
22
- from torch.distributed.utils import _get_root_modules
23
-
24
- from ._fsdp_api import MixedPrecisionPolicy, OffloadPolicy
25
- from ._fsdp_common import FSDPMeshInfo, HSDPMeshInfo
26
- from ._fsdp_init import (
27
- _get_device_from_mesh,
28
- _get_managed_modules,
29
- _get_managed_states,
30
- _get_post_forward_mesh_info,
31
- _init_default_fully_shard_mesh,
32
- _move_states_to_device,
33
- )
34
- from ._fsdp_param_group import FSDPParamGroup
35
- from ._fsdp_state import _get_module_fsdp_state, FSDPState
36
-
37
-
38
- if TYPE_CHECKING:
39
- from collections.abc import Iterable
40
-
41
- from torch.distributed.tensor import DeviceMesh, Shard
42
-
43
- __all__ = [
44
- "fully_shard",
45
- "FSDPModule",
46
- "UnshardHandle",
47
- "register_fsdp_forward_method",
48
- ]
49
-
50
-
51
- cls_to_fsdp_cls: dict[type, type] = {}
52
-
53
-
54
- @overload
55
- def fully_shard(
56
- module: nn.Module,
57
- *,
58
- mesh: Optional[DeviceMesh] = ...,
59
- reshard_after_forward: Union[bool, int] = ...,
60
- shard_placement_fn: Optional[Callable[[nn.Parameter], Optional[Shard]]] = ...,
61
- mp_policy: MixedPrecisionPolicy = ...,
62
- offload_policy: OffloadPolicy = ...,
63
- ignored_params: Optional[set[nn.Parameter]] = ...,
64
- ) -> FSDPModule: ...
65
-
66
-
67
- @overload
68
- def fully_shard(
69
- module: list[nn.Module],
70
- *,
71
- mesh: Optional[DeviceMesh] = ...,
72
- reshard_after_forward: Union[bool, int] = ...,
73
- shard_placement_fn: Optional[Callable[[nn.Parameter], Optional[Shard]]] = ...,
74
- mp_policy: MixedPrecisionPolicy = ...,
75
- offload_policy: OffloadPolicy = ...,
76
- ignored_params: Optional[set[nn.Parameter]] = ...,
77
- ) -> list[FSDPModule]: ...
78
-
79
-
80
- # The decorator adds a state object to `module` that can be accessed via
81
- # `fully_shard.state(module)`. The state object and module are 1:1.
82
- # [1] Python runtime decorator does not play well with static type checking
83
- # so suppressing some type checks to support type overloads
84
- # such that caller can still get correct return types based on input type
85
- @contract(state_cls=FSDPState) # type: ignore[misc] # see [1]
86
- def fully_shard(
87
- module,
88
- *,
89
- mesh: Optional[DeviceMesh] = None,
90
- reshard_after_forward: Optional[Union[bool, int]] = None,
91
- shard_placement_fn: Optional[Callable[[nn.Parameter], Optional[Shard]]] = None,
92
- mp_policy: MixedPrecisionPolicy = MixedPrecisionPolicy(),
93
- offload_policy: OffloadPolicy = OffloadPolicy(),
94
- ignored_params: Optional[set[nn.Parameter]] = None,
95
- ):
96
- """
97
- Apply fully sharded data parallelism (FSDP) to ``module``, where FSDP
98
- shards module parameters, gradients, and optimizer states across data
99
- parallel workers to save memory at the cost of communication.
100
-
101
- At initialization, FSDP shards the module's parameters across the data
102
- parallel workers given by ``mesh``. Before forward, FSDP all-gathers the
103
- sharded parameters across the data-parallel workers to get the unsharded
104
- parameters for forward computation. If ``reshard_after_forward`` is
105
- ``True``, then FSDP frees the unsharded parameters after forward and
106
- re-all-gathers them in backward before gradient computation. After gradient
107
- computation, FSDP frees the unsharded parameters and reduce-scatters the
108
- unsharded gradients across data-parallel workers.
109
-
110
- This implementation represents the sharded parameters as :class:`DTensor` s
111
- sharded on dim-0, while the unsharded parameters will be like the original
112
- parameters on ``module`` (e.g. :class:`torch.Tensor` if originally
113
- :class:`torch.Tensor`). A module
114
- `forward pre-hook <https://pytorch.org/docs/main/generated/torch.nn.Module.html#torch.nn.Module.register_forward_pre_hook>`_
115
- on ``module`` all-gathers the parameters, and a module
116
- `forward hook <https://pytorch.org/docs/main/generated/torch.nn.Module.html#torch.nn.Module.register_forward_hook>`_
117
- on ``module`` frees them (if needed). Similar backward hooks all-gather
118
- parameters and later free parameters and reduce-scatter gradients.
119
-
120
- Since grouping multiple tensors together for one collective is critical for
121
- communication efficiency, this implementation makes this grouping first
122
- class. Calling :meth:`fully_shard` on ``module`` constructs one group that
123
- includes the parameters in ``module.parameters()`` except those already
124
- assigned to a group from an earlier call on a submodule. This means that
125
- :meth:`fully_shard` should be called bottom-up on your model. Each group's
126
- parameters are all-gathered in one collective, and its gradients are
127
- reduce-scattered in one collective. Partitioning the model into multiple
128
- groups ("layer by layer") allows for peak memory savings and communication/computation
129
- overlap. Users generally should *not* call :meth:`fully_shard` only on the
130
- topmost root module.
131
-
132
- Args:
133
- module (Union[nn.Module, List[nn.Module]]): The module or modules to
134
- shard with FSDP and group together for communication.
135
- mesh (Optional[DeviceMesh]): This data parallel mesh defines the
136
- sharding and device. If 1D, then parameters are fully sharded
137
- across the 1D mesh (FSDP) with ``(Shard(0),)`` placement. If 2D,
138
- then parameters are sharded across the 1st dim and replicated
139
- across the 0th dim (HSDP) with ``(Replicate(), Shard(0))``
140
- placement. The mesh's device type gives the device type used for
141
- communication; if a CUDA or CUDA-like device type, then we use the
142
- current device.
143
- reshard_after_forward (Optional[Union[bool, int]]): This controls the parameter
144
- behavior after forward and can trade off memory and communication:
145
-
146
- - If ``True``, then this reshards parameters after forward and
147
- re-all-gathers in backward.
148
- - If ``False``, then this keeps the unsharded parameters in memory
149
- after forward and avoids the all-gather in backward. For best performance,
150
- we usually set ``False`` for the root module, because the root module
151
- is typically required immediately when the backward pass begins.
152
- - If ``None``, it is set to ``True`` for non-root modules and ``False``
153
- for root modules.
154
- - If an ``int``, then this represents the world size to reshard to
155
- after forward. It should be a non-trivial divisor of the ``mesh``
156
- shard dim size (i.e. excluding 1 and the dim size itself). A
157
- choice may be the intra-node size (e.g. ``torch.cuda.device_count()``).
158
- This allows the all-gather in backward to be over a smaller world
159
- size at the cost of higher memory usage than setting to ``True``.
160
- - After forward, the parameters registered to the module depend on
161
- this: The registered parameters are the sharded parameters if
162
- ``True``; unsharded parameters if ``False``; and the parameters
163
- resharded to the smaller mesh otherwise. To modify the parameters
164
- between forward and backward, the registered parameters must be
165
- the sharded parameters. For ``False`` or an ``int``, this can be
166
- done by manually resharding via :meth:`reshard`.
167
- shard_placement_fn (Optional[Callable[[nn.Parameter], Optional[Shard]]]):
168
- This callable can be used to override the sharding placement for a
169
- parameter to shard a parameter on a dimension other than dim-0. If
170
- this callable returns a :class:`Shard` placement (not ``None``),
171
- then FSDP will shard according to that placement (e.g. ``Shard(1)``).
172
- If sharding on a nonzero dim, we currently require even sharding,
173
- i.e. the tensor dim size on that dim must be divisible by the FSDP
174
- shard mesh size.
175
- mp_policy (MixedPrecisionPolicy): This controls the mixed precision
176
- policy, which offers parameter/reduction mixed precision for this
177
- module. See :class:`MixedPrecisionPolicy` for details.
178
- offload_policy (OffloadPolicy): This controls the offloading policy,
179
- which offers parameter/gradient/optimizer state offloading. See
180
- :class:`OffloadPolicy` and its subclasses for details.
181
- ignored_params (Optional[Set[nn.Parameter]]): The set of parameters to be
182
- ignored by FSDP. They will not be sharded, nor moved to the device
183
- during init, nor have their gradients reduced in backward.
184
-
185
- Returns:
186
- FSDPModule: The module with FSDP applied (in-place).
187
- """
188
- torch._C._log_api_usage_once("torch.distributed.fsdp.fully_shard")
189
- if isinstance(module, (nn.ModuleList, nn.ModuleDict)):
190
- raise ValueError(
191
- f"fully_shard does not support containers that do not implement forward: {module}"
192
- )
193
- mesh = mesh or _init_default_fully_shard_mesh()
194
- if mesh.ndim not in (1, 2):
195
- raise ValueError(f"fully_shard expects a 1D or 2D DeviceMesh but got {mesh}")
196
- elif mesh.ndim == 1:
197
- mesh_info = FSDPMeshInfo(mesh, shard_mesh_dim=0)
198
- else:
199
- if mesh.mesh_dim_names is None:
200
- raise AssertionError(
201
- "Please init the 2D mesh for HSDP with mesh_dim_names specified"
202
- )
203
- mesh_info = HSDPMeshInfo(mesh, shard_mesh_dim=1, replicate_mesh_dim=0)
204
- device = _get_device_from_mesh(mesh)
205
- auto_reshard_after_forward = reshard_after_forward is None
206
- # If the user does not provide ``reshard_after_forward``, we set it to True.
207
- # During lazy_init, we identify which module is the root and override its value to False
208
- post_forward_mesh_info = _get_post_forward_mesh_info(
209
- reshard_after_forward if not auto_reshard_after_forward else True, # type: ignore[arg-type]
210
- mesh_info,
211
- )
212
-
213
- arg_module = module
214
- modules = (
215
- (module,) if isinstance(module, nn.Module) else tuple(_get_root_modules(module))
216
- )
217
- state = fully_shard.state(modules[0]) # type: ignore[attr-defined] # see [1]
218
- state.init(modules, device, mp_policy, auto_reshard_after_forward)
219
-
220
- managed_modules = _get_managed_modules(modules, ignored_params)
221
- params, buffers = _get_managed_states(managed_modules, ignored_params)
222
-
223
- _move_states_to_device(params, buffers, device)
224
- if params:
225
- state._fsdp_param_group = FSDPParamGroup(
226
- params,
227
- modules,
228
- mesh_info,
229
- post_forward_mesh_info,
230
- device,
231
- shard_placement_fn,
232
- mp_policy,
233
- offload_policy,
234
- )
235
-
236
- # For Dynamo
237
- for managed_module in managed_modules:
238
- managed_module._is_fsdp_managed_module = True # type: ignore[assignment]
239
- managed_module._fsdp_use_orig_params = True # type: ignore[assignment]
240
-
241
- # Place FSDP leftmost for highest priority in the method resolution order
242
- for module in modules:
243
- cls = module.__class__
244
- new_cls = cls_to_fsdp_cls.get(cls, None)
245
- if not new_cls:
246
- dct = {"__deepcopy__": _unimplemented_deepcopy}
247
- new_cls = type(f"FSDP{cls.__name__}", (FSDPModule, cls), dct)
248
- cls_to_fsdp_cls[cls] = new_cls
249
- module.__class__ = new_cls
250
- return arg_module
251
-
252
-
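
The docstring above stresses calling ``fully_shard`` bottom-up so that each call forms its own all-gather/reduce-scatter group. A minimal usage sketch, assuming a ``torchrun`` launch with CUDA devices and a toy two-block model (the model, dimensions, and mesh shape are illustrative, and the imports assume a PyTorch build that re-exports these names from ``torch.distributed.fsdp``):

    import os

    import torch
    import torch.nn as nn
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard


    class Block(nn.Module):
        def __init__(self, dim: int = 1024) -> None:
            super().__init__()
            self.ff = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x + self.ff(x)


    world_size = int(os.environ["WORLD_SIZE"])  # set by torchrun
    mesh = init_device_mesh("cuda", (world_size,))  # 1D mesh -> plain FSDP, Shard(0) placement
    mp = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)

    model = nn.Sequential(Block(), Block())
    # Bottom-up: shard each block first so it becomes its own communication group...
    for block in model:
        fully_shard(block, mesh=mesh, mp_policy=mp)
    # ...then the root call picks up any parameters not already assigned to a group.
    fully_shard(model, mesh=mesh, mp_policy=mp)

After these calls each block's parameters are dim-0 sharded DTensors, and the modules are the dynamically created ``FSDP<...>`` subclasses produced by the class swap at the end of ``fully_shard``.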
253
- def _unimplemented_deepcopy(*args: Any, **kwargs: Any) -> NoReturn:
254
- raise AssertionError(
255
- "FSDP does not support deepcopy. Please use state dict for serialization."
256
- )
257
-
258
-
259
- class FSDPModule:
260
- def __new__(cls, *args, **kwargs):
261
- """
262
- Override ``__new__`` to remove the FSDP class and directly construct
263
- the original class for cases like indexing into a container module.
264
- """
265
- # Use index 2 since 0 is the dynamically constructed `FSDP<...>` class
266
- # and index 1 is the `FSDPModule` class itself
267
- orig_cls = cls.__mro__[2]
268
- self = orig_cls.__new__(orig_cls, *args, **kwargs)
269
- self.__init__(*args, **kwargs)
270
- return self
271
-
272
- def reshard(self) -> None:
273
- """
274
- Reshards the module's parameters, freeing the unsharded parameters if
275
- they are allocated and registering the sharded parameters to the
276
- module. This method is *not* recursive.
277
- """
278
- state = self._get_fsdp_state()
279
- if fsdp_param_group := state._fsdp_param_group:
280
- fsdp_param_group.reshard()
281
-
282
- def unshard(self, async_op: bool = False) -> Optional[UnshardHandle]:
283
- """
284
- Unshards the module's parameters by allocating memory and all-gathering
285
- the parameters. This method is *not* recursive. The unshard follows the
286
- :class:`MixedPrecisionPolicy`, so it will all-gather following
287
- ``param_dtype`` if set.
288
-
289
- Args:
290
- async_op (bool): If ``True``, then returns a :class:`UnshardHandle`
291
- that has a :meth:`wait` method to wait on the unshard op. If
292
- ``False``, then returns ``None`` and waits on the handle inside
293
- this function.
294
-
295
- .. note:: If ``async_op=True``, then FSDP will wait on the pending
296
- unshard in the module's pre-forward for the user. The user only
297
- needs to call :meth:`wait` explicitly if the wait should happen
298
- before pre-forward.
299
- """
300
- state = self._get_fsdp_state()
301
- fsdp_param_group = state._fsdp_param_group
302
- if fsdp_param_group is not None:
303
- fsdp_param_group.lazy_init()
304
- fsdp_param_group.unshard(async_op=async_op)
305
- handle = _UnshardHandleImpl(fsdp_param_group)
306
- if async_op:
307
- return handle
308
- handle.wait()
309
- return None
310
-
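
A sketch of how the returned handle can drive manual prefetching, assuming ``layers`` is a list of submodules that already went through ``fully_shard`` (the function and variable names are illustrative, and the import assumes ``FSDPModule`` is re-exported from ``torch.distributed.fsdp``):

    import torch
    import torch.nn as nn
    from torch.distributed.fsdp import FSDPModule


    def forward_with_manual_prefetch(layers: list[nn.Module], x: torch.Tensor) -> torch.Tensor:
        # All-gather layer i+1 while layer i computes. Pre-forward also waits on the
        # pending unshard, so the explicit wait() is only needed to wait earlier.
        handle = layers[0].unshard(async_op=True) if isinstance(layers[0], FSDPModule) else None
        for i, layer in enumerate(layers):
            if handle is not None:
                handle.wait()
            nxt = layers[i + 1] if i + 1 < len(layers) else None
            handle = nxt.unshard(async_op=True) if isinstance(nxt, FSDPModule) else None
            x = layer(x)
        return x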
311
- def set_is_last_backward(self, is_last_backward: bool) -> None:
312
- """
313
- Sets whether the next backward is the last one. On the last backward,
314
- FSDP waits on pending gradient reduction and clears internal data
315
- structures for backward prefetching. This can be useful for
316
- microbatching.
317
- """
318
- state = self._get_fsdp_state()
319
- state._state_ctx.is_last_backward = is_last_backward
320
-
321
- def set_requires_gradient_sync(
322
- self, requires_gradient_sync: bool, *, recurse: bool = True
323
- ) -> None:
324
- """
325
- Sets if the module should sync gradients. This can be used to implement
326
- gradient accumulation *without communication*. For HSDP, this controls
327
- both reduce-scatter and all-reduce together. This is the equivalent of
328
- `no_sync` in FSDP1.
329
-
330
- Args:
331
- requires_gradient_sync (bool): Whether to reduce gradients for the
332
- module's parameters.
333
- recurse (bool): Whether to set for all FSDP submodules or just the
334
- passed-in module.
335
- """
336
- self_module = cast(nn.Module, self)
337
- modules = list(self_module.modules()) if recurse else [self_module]
338
- for module in modules:
339
- if isinstance(module, FSDPModule):
340
- state = module._get_fsdp_state()
341
- if fsdp_param_group := state._fsdp_param_group:
342
- fsdp_param_group.reduce_grads = requires_gradient_sync
343
- fsdp_param_group.all_reduce_grads = requires_gradient_sync
344
-
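
A sketch of gradient accumulation without communication built on this method, assuming ``model`` is the root module returned by ``fully_shard`` and ``microbatches`` is a list of input tensors (all names here are illustrative):

    import torch
    import torch.nn as nn


    def accumulation_step(model: nn.Module, optimizer: torch.optim.Optimizer,
                          microbatches: list[torch.Tensor]) -> None:
        # `model` is assumed to be an FSDPModule after fully_shard, so it exposes
        # set_requires_gradient_sync; gradients are only reduced on the last microbatch.
        for i, batch in enumerate(microbatches):
            model.set_requires_gradient_sync(i == len(microbatches) - 1)
            loss = model(batch).sum()
            loss.backward()
        optimizer.step()
        optimizer.zero_grad()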
345
- def set_requires_all_reduce(
346
- self, requires_all_reduce: bool, *, recurse: bool = True
347
- ) -> None:
348
- """
349
- Sets if the module should all-reduce gradients. This can be used to
350
- implement gradient accumulation with only reduce-scatter but not
351
- all-reduce for HSDP.
352
- """
353
- self_module = cast(nn.Module, self)
354
- modules = list(self_module.modules()) if recurse else [self_module]
355
- for module in modules:
356
- if isinstance(module, FSDPModule):
357
- state = module._get_fsdp_state()
358
- if fsdp_param_group := state._fsdp_param_group:
359
- fsdp_param_group.all_reduce_grads = requires_all_reduce
360
-
361
- def set_reshard_after_forward(
362
- self, reshard_after_forward: bool, recurse: bool = True
363
- ) -> None:
364
- """
365
- Sets if the module should reshard parameters after forward. This can be
366
- used to change the ``reshard_after_forward`` FSDP arg at runtime. For
367
- example, this can be used to set the FSDP root module's value to
368
- ``True`` (since it is otherwise specially set to ``False``), or it can
369
- set an FSDP module's value to ``False`` for running evals and set back
370
- to ``True`` for training.
371
-
372
- Args:
373
- reshard_after_forward (bool): Whether to reshard parameters after
374
- forward.
375
- recurse (bool): Whether to set for all FSDP submodules or just the
376
- passed-in module.
377
- """
378
- if not isinstance(reshard_after_forward, bool):
379
- raise ValueError(
380
- f"reshard_after_forward should be a bool, got {type(reshard_after_forward)}"
381
- )
382
- self_module = cast(nn.Module, self)
383
- modules = list(self_module.modules()) if recurse else [self_module]
384
- for module in modules:
385
- if isinstance(module, FSDPModule):
386
- state = module._get_fsdp_state()
387
- state._auto_reshard_after_forward = False
388
- if fsdp_param_group := state._fsdp_param_group:
389
- fsdp_param_group.post_forward_mesh_info = (
390
- _get_post_forward_mesh_info(
391
- reshard_after_forward, fsdp_param_group.mesh_info
392
- )
393
- )
394
-
395
- def set_reshard_after_backward(
396
- self, reshard_after_backward: bool, *, recurse: bool = True
397
- ) -> None:
398
- """
399
- Sets if the module should reshard parameters after backward. This can
400
- be used during gradient accumulation to trade off higher memory for
401
- reduced communication since the unsharded parameters do not need to be
402
- re-all-gathered before the next forward.
403
-
404
- Args:
405
- reshard_after_backward (bool): Whether to reshard parameters after
406
- backward.
407
- recurse (bool): Whether to set for all FSDP submodules or just the
408
- passed-in module.
409
- """
410
- self_module = cast(nn.Module, self)
411
- modules = list(self_module.modules()) if recurse else [self_module]
412
- for module in modules:
413
- if isinstance(module, FSDPModule):
414
- state = module._get_fsdp_state()
415
- if fsdp_param_group := state._fsdp_param_group:
416
- fsdp_param_group.reshard_after_backward = reshard_after_backward
417
-
418
- def set_modules_to_forward_prefetch(self, modules: list[FSDPModule]) -> None:
419
- """
420
- Sets the FSDP modules for which this FSDP module should explicitly
421
- prefetch all-gathers in forward. The prefetching runs after this
422
- module's all-gather copy-out.
423
-
424
- Passing a singleton list containing the next FSDP module gives the same
425
- all-gather overlap behavior as the default overlap behavior, except the
426
- prefetched all-gather is issued earlier from the CPU. Passing a list
427
- with at least length two is required for more aggressive overlap and
428
- will use more reserved memory.
429
-
430
- Args:
431
- modules (List[FSDPModule]): FSDP modules to prefetch.
432
- """
433
- _assert_all_fsdp_modules(modules)
434
- self._get_fsdp_state()._states_to_forward_prefetch = [
435
- module._get_fsdp_state() for module in modules
436
- ]
437
-
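
A sketch of explicit forward prefetching across transformer-style layers, assuming ``layers`` is a list of FSDP-wrapped blocks (the helper name and the prefetch depth of 2 are illustrative):

    from torch.distributed.fsdp import FSDPModule


    def set_forward_prefetch(layers: list[FSDPModule], num_to_prefetch: int = 2) -> None:
        # Prefetch `num_to_prefetch` layers ahead: more overlap than the default
        # single-module lookahead, at the cost of more live unsharded memory.
        for i, layer in enumerate(layers):
            to_prefetch = layers[i + 1 : i + 1 + num_to_prefetch]
            if to_prefetch:
                layer.set_modules_to_forward_prefetch(to_prefetch)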
438
- def set_modules_to_backward_prefetch(self, modules: list[FSDPModule]) -> None:
439
- """
440
- Sets the FSDP modules for which this FSDP module should explicitly
441
- prefetch all-gathers in backward. This overrides the default backward
442
- prefetching implementation that prefetches the next FSDP module based on
443
- the reverse post-forward order.
444
-
445
- Passing a singleton list containing the previous FSDP module gives the
446
- same all-gather overlap behavior as the default overlap behavior.
447
- Passing a list with at least length two is required for more aggressive
448
- overlap and will use more reserved memory.
449
-
450
- Args:
451
- modules (List[FSDPModule]): FSDP modules to prefetch.
452
- """
453
- _assert_all_fsdp_modules(modules)
454
- self._get_fsdp_state()._states_to_backward_prefetch = [
455
- module._get_fsdp_state() for module in modules
456
- ]
457
-
458
- def set_all_reduce_hook(
459
- self,
460
- hook: Callable[[torch.Tensor], None],
461
- *,
462
- stream: Optional[torch.cuda.Stream] = None,
463
- ):
464
- """
465
- Args:
466
- hook (Callable[[torch.Tensor], None]): User-defined all-reduce hook
467
- with expected signature ``hook(reduce_output: torch.Tensor) -> None``
468
- where ``reduce_output`` is the reduce-scatter output if only
469
- using FSDP or the all-reduce output if using native HSDP.
470
- stream (Optional[torch.cuda.Stream]): Stream to run the all-reduce
471
- hook in. This should only be set if not using native HSDP. If
472
- using native HSDP, the hook will run in the internally defined
473
- all-reduce stream used by the native HSDP all-reduce.
474
- """
475
- state = self._get_fsdp_state()
476
- if (fsdp_param_group := state._fsdp_param_group) is not None:
477
- fsdp_param_group._all_reduce_hook = hook
478
- if stream is not None:
479
- if fsdp_param_group._is_hsdp:
480
- raise ValueError("stream cannot be set when using native HSDP")
481
- fsdp_param_group._all_reduce_hook_stream = stream
482
-
483
- def set_post_optim_event(self, event: torch.Event) -> None:
484
- """
485
- Sets a post-optimizer-step event for the root FSDP module to wait the
486
- all-gather streams on.
487
-
488
- By default, the root FSDP module waits the all-gather streams on the
489
- current stream to ensure that the optimizer step has finished before
490
- all-gathering. However, this may introduce false dependencies if
491
- there is unrelated computation after the optimizer step. This API
492
- allows the user to provide their own event to wait on. After the root
493
- waits on the event, the event is discarded, so this API should be
494
- called with a new event each iteration.
495
-
496
- Args:
497
- event (torch.Event): Event recorded after the optimizer step
498
- to wait all-gather streams on.
499
- """
500
- self._get_fsdp_state()._state_ctx.post_optim_event = event
501
-
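
A sketch of the intended call pattern, assuming a CUDA training loop where unrelated work is issued on the current stream after the optimizer step (the helper name is illustrative):

    import torch
    from torch.distributed.fsdp import FSDPModule


    def finish_optimizer_step(root_module: FSDPModule, optimizer: torch.optim.Optimizer) -> None:
        optimizer.step()
        optimizer.zero_grad()
        # Record on the current stream right after the step so the next iteration's
        # all-gathers wait only on the optimizer work, not on later unrelated kernels.
        root_module.set_post_optim_event(torch.cuda.current_stream().record_event())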
502
- @deprecated("Use `set_gradient_divide_factor` instead")
503
- def set_reduce_scatter_divide_factor(self, factor: float) -> None:
504
- """Use :py:meth:`set_gradient_divide_factor` instead"""
505
- self.set_gradient_divide_factor(factor)
506
-
507
- def set_gradient_divide_factor(self, factor: float) -> None:
508
- """
509
- Sets a custom divide factor for the gradient reduction. This might use
510
- a custom reduce op using NCCL's PreMulSum, which allows multiplying by
511
- the factor before reduction.
512
-
513
- Args:
514
- factor (float): Custom divide factor.
515
- """
516
- state = self._get_fsdp_state()
517
- if (fsdp_param_group := state._fsdp_param_group) is not None:
518
- fsdp_param_group.gradient_divide_factor = factor
519
-
520
- def set_force_sum_reduction_for_comms(self, enable: bool) -> None:
521
- """
522
- Sets whether to require the low-level collective communication
523
- primitives to exclusively use "sum"-type reductions, even if it comes
524
- at the cost of separate additional pre- or post-scaling operations.
525
- This is needed for example because NCCL currently supports zero-copy
526
- transfers only for this kind of collective.
527
-
528
- NB: for MTIA devices, this is always implicitly enabled.
529
-
530
- NB: if `set_all_reduce_hook` is used under FSDP setup, the caller needs
531
- to ensure that the custom all-reduce across FSDP units follows this strategy
532
- as well, as FSDP can no longer automatically handle that.
533
-
534
- Args:
535
- enable (bool): Whether to only ever use ReduceOp.SUM for comms.
536
- """
537
- state = self._get_fsdp_state()
538
- if (fsdp_param_group := state._fsdp_param_group) is not None:
539
- fsdp_param_group.force_sum_reduction_for_comms = enable
540
-
541
- def set_unshard_in_backward(self, unshard_in_backward: bool) -> None:
542
- """
543
- Sets whether the FSDP module's parameters need to be unsharded in
544
- backward. This can be used in expert cases when the user knows that all
545
- parameters in this FSDP module's parameter group are not needed for
546
- backward computation (e.g. embedding).
547
- """
548
- state = self._get_fsdp_state()
549
- if (fsdp_param_group := state._fsdp_param_group) is not None:
550
- fsdp_param_group.unshard_in_backward = unshard_in_backward
551
-
552
- def set_allocate_memory_from_process_group_for_comm(self, enable: bool) -> None:
553
- """
554
- Sets whether the temporary staging buffers used to send and receive data
555
- over collective communications should be allocated using the custom
556
- optimized allocator provided by the ProcessGroup itself (if any). This
557
- might allow the ProcessGroup to be more efficient. For example, when
558
- using NCCL, this enables it to leverage zero-copy transfers over SHARP
559
- (for NVLink and/or InfiniBand).
560
-
561
- Args:
562
- enable (bool): Whether to turn on ProcessGroup allocation.
563
- """
564
- state = self._get_fsdp_state()
565
- if (fsdp_param_group := state._fsdp_param_group) is not None:
566
- fsdp_param_group.allocate_memory_from_process_group = enable
567
-
568
- def _set_unshard_async_op(self, async_op: bool):
569
- """
570
- Sets whether to use ``async_op=True`` or ``False`` for the pre-forward
571
- and pre-backward unshard op. This defaults to ``False`` but can be set
572
- to ``True`` with this method.
573
-
574
- Setting this to ``True`` allows the all-gather allocations to happen in
575
- the default stream, avoiding inter-stream memory fragmentation.
576
- However, you must use explicit prefetching (e.g. via :meth:`unshard`)
577
- in forward to still get overlap, and the pre-all-gather ops like dtype
578
- casting and copy-in will not overlap with compute.
579
- """
580
- self_module = cast(nn.Module, self)
581
- for module in self_module.modules():
582
- if isinstance(module, FSDPModule):
583
- state = module._get_fsdp_state()
584
- if fsdp_param_group := state._fsdp_param_group:
585
- fsdp_param_group.unshard_async_op = async_op
586
-
587
- def _get_fsdp_state(self) -> FSDPState:
588
- if (state := _get_module_fsdp_state(cast(nn.Module, self))) is None:
589
- raise AssertionError(f"No FSDP state found on {self}")
590
- return state
591
-
592
- def _apply(self, *args: Any, **kwargs: Any) -> Any:
593
- # Reshard to ensure that sharded parameters are registered
594
- self.reshard()
595
- ret = super()._apply(*args, **kwargs) # type: ignore[misc]
596
- state = self._get_fsdp_state()
597
- if not (fsdp_param_group := state._fsdp_param_group):
598
- return ret
599
- # TODO: Remove this padding logic once DTensor pads the local tensor:
600
- # https://github.com/pytorch/pytorch/issues/113045
601
- with torch.no_grad():
602
- for fsdp_param in fsdp_param_group.fsdp_params:
603
- fsdp_param.reset_sharded_param()
604
- return ret
605
-
606
-
607
- class UnshardHandle:
608
- """
609
- A handle to wait on a :meth:`FSDPModule.unshard` op.
610
- """
611
-
612
- def wait(self) -> None:
613
- """
614
- Waits on the unshard op. This ensures that the current stream can use
615
- the unsharded parameters, which are now registered to the module.
616
- """
617
- return
618
-
619
-
620
- class _UnshardHandleImpl(UnshardHandle):
621
- def __init__(self, fsdp_param_group: Optional[FSDPParamGroup]):
622
- self._fsdp_param_group = fsdp_param_group
623
-
624
- def wait(self):
625
- if self._fsdp_param_group is not None:
626
- self._fsdp_param_group.wait_for_unshard()
627
- # Avoid keeping a reference
628
- self._fsdp_param_group = None
629
-
630
-
631
- def register_fsdp_forward_method(module: nn.Module, method_name: str) -> None:
632
- """
633
- Registers a method on ``module`` to be considered a forward method for
634
- FSDP.
635
-
636
- FSDP all-gathers parameters pre-forward and optionally frees parameters
637
- post-forward (depending on ``reshard_after_forward``). FSDP only knows to
638
- do this for :meth:`nn.Module.forward` by default. This function patches a
639
- user-specified method to run the pre/post-forward hooks before/after the
640
- method, respectively. If ``module`` is not an :class:`FSDPModule`, then
641
- this is a no-op.
642
-
643
- Args:
644
- module (nn.Module): Module to register the forward method on.
645
- method_name (str): Name of the forward method.
646
- """
647
- if not isinstance(module, FSDPModule):
648
- # No-op so that callers can keep this call whether or not FSDP is applied
649
- return
650
- if not hasattr(module, method_name):
651
- raise ValueError(f"{type(module)} does not have a method {method_name}")
652
- orig_method = getattr(module, method_name)
653
-
654
- @functools.wraps(orig_method)
655
- def wrapped_method(self, *args, **kwargs):
656
- fsdp_state = self._get_fsdp_state()
657
- args, kwargs = fsdp_state._pre_forward(self, args, kwargs)
658
- out = orig_method(*args, **kwargs)
659
- return fsdp_state._post_forward(self, args, out)
660
-
661
- # Use `__get__` to make `wrapped_method` an instance method
662
- setattr(
663
- module,
664
- method_name,
665
- wrapped_method.__get__(module, type(module)), # type:ignore[attr-defined]
666
- )
667
-
668
-
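
A sketch of the typical use case, assuming a distributed setup is already initialized and an encoder whose ``encode`` method is called directly instead of ``forward`` (the module and method names are illustrative; the imports assume a build that re-exports these names from ``torch.distributed.fsdp``):

    import torch
    import torch.nn as nn
    from torch.distributed.fsdp import fully_shard, register_fsdp_forward_method


    class Encoder(nn.Module):
        def __init__(self, dim: int = 16) -> None:
            super().__init__()
            self.proj = nn.Linear(dim, dim)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.proj(x)

        def encode(self, x: torch.Tensor) -> torch.Tensor:
            # Called directly by user code, so FSDP's forward hooks would not fire by default.
            return self.proj(x).relu()


    encoder = Encoder()
    fully_shard(encoder)  # uses the default mesh; requires the process group to be set up
    register_fsdp_forward_method(encoder, "encode")  # now encode() also all-gathers/reshards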
669
- def _assert_all_fsdp_modules(modules: Iterable[Any]) -> None:
670
- for module in modules:
671
- if not isinstance(module, FSDPModule):
672
- raise ValueError(f"Expects FSDPModule but got {type(module)}: {module}")