gpt-pretrain/lit_patches.py

52 lines
1.6 KiB
Python
Raw Normal View History

from typing import Any, Dict, Optional
import pytorch_lightning as pl
from torch.nn import Module
class FSDPStrategy(pl.strategies.FSDPStrategy):
@property
def model(self) -> Optional[Module]:
"""Returns the potentially wrapped LightningModule."""
return self._model
@model.setter
def model(self, new_model: Optional[Module]) -> None:
self._model = new_model
def lightning_module_state_dict(self) -> Dict[str, Any]:
"""Returns model state."""
if self.model is None:
assert self.lightning_module is not None
return self.lightning_module.state_dict()
else:
prefix = "_forward_module."
state_dict = self.model.state_dict()
state_dict = {k[len(prefix) :]: v for k, v in state_dict.items()}
return state_dict
@classmethod
def register_strategies(cls, strategy_registry: Dict) -> None:
if not pl.strategies.fsdp._fsdp_available:
return
strategy_registry.register(
"fsdp",
cls,
description="Fully Sharded Data Parallel (FSDP) training",
override=True,
)
cls._registered_strategies.append("fsdp")
strategy_registry.register(
"fsdp_cpu_offload",
cls,
description="Fully Sharded Data Parallel (FSDP) training with Full Sharding and CPU Offloading",
cpu_offload=True,
override=True,
)
cls._registered_strategies.append("fsdp_cpu_offload")
def apply_all_patches():
FSDPStrategy.register_strategies(pl.strategies.StrategyRegistry)