181 lines
5.5 KiB
Python
181 lines
5.5 KiB
Python
from typing import Any, Dict, Optional, Union
|
|
|
|
import pytorch_lightning as pl
|
|
from torch.nn import Module
|
|
|
|
|
|
class FSDPStrategy(pl.strategies.FSDPStrategy):
|
|
@property
|
|
def model(self) -> Optional[Module]:
|
|
"""Returns the potentially wrapped LightningModule."""
|
|
return self._model
|
|
|
|
@model.setter
|
|
def model(self, new_model: Optional[Module]) -> None:
|
|
self._model = new_model
|
|
|
|
def lightning_module_state_dict(self) -> Dict[str, Any]:
|
|
"""Returns model state."""
|
|
if self.model is None:
|
|
assert self.lightning_module is not None
|
|
return self.lightning_module.state_dict()
|
|
else:
|
|
prefix = "_forward_module."
|
|
state_dict = self.model.state_dict()
|
|
state_dict = {k[len(prefix) :]: v for k, v in state_dict.items()}
|
|
return state_dict
|
|
|
|
@classmethod
|
|
def register_strategies(cls, strategy_registry: Dict) -> None:
|
|
if not pl.strategies.fsdp._fsdp_available:
|
|
return
|
|
strategy_registry.register(
|
|
"fsdp",
|
|
cls,
|
|
description="Fully Sharded Data Parallel (FSDP) training",
|
|
override=True,
|
|
)
|
|
cls._registered_strategies.append("fsdp")
|
|
|
|
strategy_registry.register(
|
|
"fsdp_cpu_offload",
|
|
cls,
|
|
description="Fully Sharded Data Parallel (FSDP) training with Full Sharding and CPU Offloading",
|
|
cpu_offload=True,
|
|
override=True,
|
|
)
|
|
cls._registered_strategies.append("fsdp_cpu_offload")
|
|
|
|
|
|
class DeepSpeedStrategy(pl.strategies.DeepSpeedStrategy):
|
|
def _create_default_config(
|
|
self,
|
|
zero_optimization: bool,
|
|
zero_allow_untested_optimizer: bool,
|
|
logging_batch_size_per_gpu: Union[str, int],
|
|
partition_activations: bool,
|
|
cpu_checkpointing: bool,
|
|
contiguous_memory_optimization: bool,
|
|
synchronize_checkpoint_boundary: bool,
|
|
offload_optimizer: bool,
|
|
offload_parameters: bool,
|
|
nvme_path: str,
|
|
offload_params_device: str,
|
|
params_buffer_count: int,
|
|
params_buffer_size: int,
|
|
max_in_cpu: int,
|
|
offload_optimizer_device: str,
|
|
optimizer_buffer_count: int,
|
|
pin_memory: bool,
|
|
block_size: int,
|
|
queue_depth: int,
|
|
single_submit: bool,
|
|
overlap_events: bool,
|
|
thread_count: int,
|
|
**zero_kwargs: Any,
|
|
) -> Dict:
|
|
cfg = super()._create_default_config(
|
|
zero_optimization,
|
|
zero_allow_untested_optimizer,
|
|
logging_batch_size_per_gpu,
|
|
partition_activations,
|
|
cpu_checkpointing,
|
|
contiguous_memory_optimization,
|
|
synchronize_checkpoint_boundary,
|
|
offload_optimizer,
|
|
offload_parameters,
|
|
nvme_path,
|
|
offload_params_device,
|
|
params_buffer_count,
|
|
params_buffer_size,
|
|
max_in_cpu,
|
|
offload_optimizer_device,
|
|
optimizer_buffer_count,
|
|
pin_memory,
|
|
block_size,
|
|
queue_depth,
|
|
single_submit,
|
|
overlap_events,
|
|
thread_count,
|
|
**zero_kwargs,
|
|
)
|
|
if zero_optimization:
|
|
if offload_parameters:
|
|
cfg = {
|
|
"zero_force_ds_cpu_optimizer": False,
|
|
**cfg,
|
|
}
|
|
return cfg
|
|
|
|
@classmethod
|
|
def register_strategies(cls, strategy_registry: Dict) -> None:
|
|
strategy_registry.register(
|
|
"deepspeed",
|
|
cls,
|
|
description="Default DeepSpeed Strategy",
|
|
override=True,
|
|
)
|
|
strategy_registry.register(
|
|
"deepspeed_stage_1",
|
|
cls,
|
|
description="DeepSpeed with ZeRO Stage 1 enabled",
|
|
stage=1,
|
|
override=True,
|
|
)
|
|
strategy_registry.register(
|
|
"deepspeed_stage_2",
|
|
cls,
|
|
description="DeepSpeed with ZeRO Stage 2 enabled",
|
|
stage=2,
|
|
override=True,
|
|
)
|
|
strategy_registry.register(
|
|
"deepspeed_stage_2_offload",
|
|
cls,
|
|
description="DeepSpeed ZeRO Stage 2 and CPU Offload",
|
|
stage=2,
|
|
offload_optimizer=True,
|
|
override=True,
|
|
)
|
|
strategy_registry.register(
|
|
"deepspeed_stage_3",
|
|
cls,
|
|
description="DeepSpeed ZeRO Stage 3",
|
|
stage=3,
|
|
override=True,
|
|
)
|
|
strategy_registry.register(
|
|
"deepspeed_stage_3_offload",
|
|
cls,
|
|
description="DeepSpeed ZeRO Stage 3 and CPU Offload",
|
|
stage=3,
|
|
offload_optimizer=True,
|
|
offload_parameters=True,
|
|
override=True,
|
|
)
|
|
strategy_registry.register(
|
|
"deepspeed_stage_3_offload_nvme",
|
|
cls,
|
|
description="DeepSpeed ZeRO Stage 3 and NVMe Offload",
|
|
stage=3,
|
|
offload_optimizer=True,
|
|
offload_parameters=True,
|
|
remote_device="nvme",
|
|
offload_params_device="nvme",
|
|
offload_optimizer_device="nvme",
|
|
override=True,
|
|
)
|
|
|
|
|
|
def apply_fsdp_strategy_patch():
|
|
FSDPStrategy.register_strategies(pl.strategies.StrategyRegistry)
|
|
|
|
|
|
def apply_deepspeed_strategy_patch():
|
|
DeepSpeedStrategy.register_strategies(pl.strategies.StrategyRegistry)
|
|
|
|
|
|
def apply_all_patches():
|
|
apply_fsdp_strategy_patch()
|
|
apply_deepspeed_strategy_patch()
|