From 70ff2acaf000f0d5bfe2bb7830a0a55da0c8a7f4 Mon Sep 17 00:00:00 2001
From: Yiqing-Zhou
Date: Sun, 7 May 2023 16:51:57 +0800
Subject: [PATCH] [fix] add patch to fix FSDPStrategy checkpoint issue

---
 lit_patches.py | 51 ++++++++++++++++++++++++++++++++++++++++++++++++++
 lit_train.py   |  2 ++
 2 files changed, 53 insertions(+)
 create mode 100644 lit_patches.py

diff --git a/lit_patches.py b/lit_patches.py
new file mode 100644
index 0000000..6f9d1f1
--- /dev/null
+++ b/lit_patches.py
@@ -0,0 +1,51 @@
+from typing import Any, Dict, Optional
+
+import pytorch_lightning as pl
+from torch.nn import Module
+
+
+class FSDPStrategy(pl.strategies.FSDPStrategy):
+    @property
+    def model(self) -> Optional[Module]:
+        """Returns the potentially wrapped LightningModule."""
+        return self._model
+
+    @model.setter
+    def model(self, new_model: Optional[Module]) -> None:
+        self._model = new_model
+
+    def lightning_module_state_dict(self) -> Dict[str, Any]:
+        """Returns model state."""
+        if self.model is None:
+            assert self.lightning_module is not None
+            return self.lightning_module.state_dict()
+        else:
+            prefix = "_forward_module."
+            state_dict = self.model.state_dict()
+            state_dict = {k[len(prefix) :]: v for k, v in state_dict.items()}
+            return state_dict
+
+    @classmethod
+    def register_strategies(cls, strategy_registry: Dict) -> None:
+        if not pl.strategies.fsdp._fsdp_available:
+            return
+        strategy_registry.register(
+            "fsdp",
+            cls,
+            description="Fully Sharded Data Parallel (FSDP) training",
+            override=True,
+        )
+        cls._registered_strategies.append("fsdp")
+
+        strategy_registry.register(
+            "fsdp_cpu_offload",
+            cls,
+            description="Fully Sharded Data Parallel (FSDP) training with Full Sharding and CPU Offloading",
+            cpu_offload=True,
+            override=True,
+        )
+        cls._registered_strategies.append("fsdp_cpu_offload")
+
+
+def apply_all_patches():
+    FSDPStrategy.register_strategies(pl.strategies.StrategyRegistry)
diff --git a/lit_train.py b/lit_train.py
index 6c65561..3e16cf0 100644
--- a/lit_train.py
+++ b/lit_train.py
@@ -15,6 +15,7 @@ from transformers import (
 )
 
 from lit_module import LitModule
+from lit_patches import apply_all_patches
 from utils import load_tokenizer
 
 
@@ -189,6 +190,7 @@ if __name__ == '__main__':
     )
 
     # trainer
+    apply_all_patches()
     torch.set_float32_matmul_precision('medium')
     if args.bf16:
         precision = 'bf16-mixed'
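
A minimal usage sketch, not part of the patch itself: assuming lit_patches.py is on the import path, calling apply_all_patches() before constructing the Trainer re-registers the "fsdp" and "fsdp_cpu_offload" aliases so they resolve to the patched FSDPStrategy. The Trainer arguments below are illustrative placeholders, not values taken from lit_train.py.

    import pytorch_lightning as pl

    from lit_patches import apply_all_patches

    # Re-register the "fsdp" / "fsdp_cpu_offload" aliases with the patched strategy.
    apply_all_patches()

    # A Trainer created afterwards resolves strategy="fsdp" to the patched class, so
    # checkpoint saving goes through the overridden lightning_module_state_dict(),
    # which strips the "_forward_module." prefix from the wrapped model's keys.
    # (accelerator/devices/precision here are hypothetical example settings.)
    trainer = pl.Trainer(accelerator="gpu", devices=2, strategy="fsdp", precision="bf16-mixed")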