# gpt-pretrain/lit_patches.py

from typing import Any, Dict, Optional, Union

import pytorch_lightning as pl
from torch.nn import Module


class FSDPStrategy(pl.strategies.FSDPStrategy):
    """Patched FSDP strategy: exposes the wrapped module via ``model`` and
    strips the FSDP wrapper prefix when collecting the module state dict."""

    @property
    def model(self) -> Optional[Module]:
        """Returns the potentially wrapped LightningModule."""
        return self._model

    @model.setter
    def model(self, new_model: Optional[Module]) -> None:
        self._model = new_model

    def lightning_module_state_dict(self) -> Dict[str, Any]:
        """Returns model state."""
        if self.model is None:
            assert self.lightning_module is not None
            return self.lightning_module.state_dict()
        else:
            # FSDP wraps the LightningModule in a ``_forward_module`` attribute,
            # so strip that prefix to recover the original parameter names.
            prefix = "_forward_module."
            state_dict = self.model.state_dict()
            state_dict = {k[len(prefix) :]: v for k, v in state_dict.items()}
            return state_dict

    @classmethod
    def register_strategies(cls, strategy_registry: Dict) -> None:
        if not pl.strategies.fsdp._fsdp_available:
            return
        # Re-register the built-in names so they resolve to this patched class.
        strategy_registry.register(
            "fsdp",
            cls,
            description="Fully Sharded Data Parallel (FSDP) training",
            override=True,
        )
        cls._registered_strategies.append("fsdp")
        strategy_registry.register(
            "fsdp_cpu_offload",
            cls,
            description="Fully Sharded Data Parallel (FSDP) training with Full Sharding and CPU Offloading",
            cpu_offload=True,
            override=True,
        )
        cls._registered_strategies.append("fsdp_cpu_offload")
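

# Hypothetical illustration (assumption, not part of the original file): the key
# rewriting in ``FSDPStrategy.lightning_module_state_dict`` above amounts to a
# plain rename of the wrapper-prefixed keys, e.g.
#
#     {"_forward_module.layer.weight": w}  ->  {"layer.weight": w}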


class DeepSpeedStrategy(pl.strategies.DeepSpeedStrategy):
    """Patched DeepSpeed strategy that injects ``zero_force_ds_cpu_optimizer: False``
    into the generated config when ZeRO parameter offloading is enabled."""

    def _create_default_config(
        self,
        zero_optimization: bool,
        zero_allow_untested_optimizer: bool,
        logging_batch_size_per_gpu: Union[str, int],
        partition_activations: bool,
        cpu_checkpointing: bool,
        contiguous_memory_optimization: bool,
        synchronize_checkpoint_boundary: bool,
        offload_optimizer: bool,
        offload_parameters: bool,
        nvme_path: str,
        offload_params_device: str,
        params_buffer_count: int,
        params_buffer_size: int,
        max_in_cpu: int,
        offload_optimizer_device: str,
        optimizer_buffer_count: int,
        pin_memory: bool,
        block_size: int,
        queue_depth: int,
        single_submit: bool,
        overlap_events: bool,
        thread_count: int,
        **zero_kwargs: Any,
    ) -> Dict:
        cfg = super()._create_default_config(
            zero_optimization,
            zero_allow_untested_optimizer,
            logging_batch_size_per_gpu,
            partition_activations,
            cpu_checkpointing,
            contiguous_memory_optimization,
            synchronize_checkpoint_boundary,
            offload_optimizer,
            offload_parameters,
            nvme_path,
            offload_params_device,
            params_buffer_count,
            params_buffer_size,
            max_in_cpu,
            offload_optimizer_device,
            optimizer_buffer_count,
            pin_memory,
            block_size,
            queue_depth,
            single_submit,
            overlap_events,
            thread_count,
            **zero_kwargs,
        )
        if zero_optimization and offload_parameters:
            # Do not force DeepSpeed's own CPU optimizer when parameters are
            # offloaded; keep whatever optimizer the user configured.
            cfg = {
                "zero_force_ds_cpu_optimizer": False,
                **cfg,
            }
        return cfg

    @classmethod
    def register_strategies(cls, strategy_registry: Dict) -> None:
        strategy_registry.register(
            "deepspeed",
            cls,
            description="Default DeepSpeed Strategy",
            override=True,
        )
        strategy_registry.register(
            "deepspeed_stage_1",
            cls,
            description="DeepSpeed with ZeRO Stage 1 enabled",
            stage=1,
            override=True,
        )
        strategy_registry.register(
            "deepspeed_stage_2",
            cls,
            description="DeepSpeed with ZeRO Stage 2 enabled",
            stage=2,
            override=True,
        )
        strategy_registry.register(
            "deepspeed_stage_2_offload",
            cls,
            description="DeepSpeed ZeRO Stage 2 and CPU Offload",
            stage=2,
            offload_optimizer=True,
            override=True,
        )
        strategy_registry.register(
            "deepspeed_stage_3",
            cls,
            description="DeepSpeed ZeRO Stage 3",
            stage=3,
            override=True,
        )
        strategy_registry.register(
            "deepspeed_stage_3_offload",
            cls,
            description="DeepSpeed ZeRO Stage 3 and CPU Offload",
            stage=3,
            offload_optimizer=True,
            offload_parameters=True,
            override=True,
        )
        strategy_registry.register(
            "deepspeed_stage_3_offload_nvme",
            cls,
            description="DeepSpeed ZeRO Stage 3 and NVMe Offload",
            stage=3,
            offload_optimizer=True,
            offload_parameters=True,
            remote_device="nvme",
            offload_params_device="nvme",
            offload_optimizer_device="nvme",
            override=True,
        )


def apply_fsdp_strategy_patch():
    FSDPStrategy.register_strategies(pl.strategies.StrategyRegistry)


def apply_deepspeed_strategy_patch():
    DeepSpeedStrategy.register_strategies(pl.strategies.StrategyRegistry)


def apply_all_patches():
    apply_fsdp_strategy_patch()
    apply_deepspeed_strategy_patch()
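

# Hypothetical usage sketch (assumption, not part of the original module): the
# patches only touch Lightning's strategy registry, so they are expected to be
# applied once, before a Trainer resolves a strategy by name.
if __name__ == "__main__":
    apply_all_patches()
    # Assumption: the registry behaves like a mapping of strategy names, so
    # iterating it lists the entries that now resolve to the patched classes.
    print(sorted(pl.strategies.StrategyRegistry))
    # A Trainer would then pick up the patched classes via the string name, e.g.:
    #   trainer = pl.Trainer(strategy="deepspeed_stage_3_offload", ...)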