[fix] add patch to fix FSDPStrategy checkpoint issue
parent 5392a845f7 · commit 70ff2acaf0
@@ -0,0 +1,51 @@
from typing import Any, Dict, Optional

import pytorch_lightning as pl
from torch.nn import Module


class FSDPStrategy(pl.strategies.FSDPStrategy):
    @property
    def model(self) -> Optional[Module]:
        """Returns the potentially wrapped LightningModule."""
        return self._model

    @model.setter
    def model(self, new_model: Optional[Module]) -> None:
        self._model = new_model

    def lightning_module_state_dict(self) -> Dict[str, Any]:
        """Returns model state."""
        if self.model is None:
            assert self.lightning_module is not None
            return self.lightning_module.state_dict()
        else:
            prefix = "_forward_module."
            state_dict = self.model.state_dict()
            state_dict = {k[len(prefix) :]: v for k, v in state_dict.items()}
            return state_dict

    @classmethod
    def register_strategies(cls, strategy_registry: Dict) -> None:
        if not pl.strategies.fsdp._fsdp_available:
            return
        strategy_registry.register(
            "fsdp",
            cls,
            description="Fully Sharded Data Parallel (FSDP) training",
            override=True,
        )
        cls._registered_strategies.append("fsdp")

        strategy_registry.register(
            "fsdp_cpu_offload",
            cls,
            description="Fully Sharded Data Parallel (FSDP) training with Full Sharding and CPU Offloading",
            cpu_offload=True,
            override=True,
        )
        cls._registered_strategies.append("fsdp_cpu_offload")


def apply_all_patches():
    FSDPStrategy.register_strategies(pl.strategies.StrategyRegistry)
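
Background on the prefix stripping above (context, not part of the commit): when Lightning wraps the LightningModule for FSDP, the wrapped module's state_dict keys carry a "_forward_module." prefix, so a checkpoint saved from the wrapper would not load back into a bare LightningModule. A minimal sketch of the renaming done by lightning_module_state_dict, using hypothetical parameter names:

# Illustration only; the parameter names are made up.
prefix = "_forward_module."
wrapped = {"_forward_module.lm.embed.weight": 0, "_forward_module.lm.head.weight": 1}
cleaned = {k[len(prefix) :]: v for k, v in wrapped.items()}
assert list(cleaned) == ["lm.embed.weight", "lm.head.weight"]
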
@@ -15,6 +15,7 @@ from transformers import (
)

from lit_module import LitModule
from lit_patches import apply_all_patches
from utils import load_tokenizer

@@ -189,6 +190,7 @@ if __name__ == '__main__':
    )

    # trainer
    apply_all_patches()
    torch.set_float32_matmul_precision('medium')
    if args.bf16:
        precision = 'bf16-mixed'
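
How the patch takes effect (a minimal sketch, not code from the commit; the Trainer arguments are illustrative assumptions): apply_all_patches() re-registers "fsdp" and "fsdp_cpu_offload" with override=True, so a Trainer built afterwards resolves those strategy names to the patched FSDPStrategy from lit_patches.

# Illustrative usage only; accelerator/devices/precision values are assumptions.
import pytorch_lightning as pl
from lit_patches import apply_all_patches

apply_all_patches()  # must run before the Trainer resolves the strategy name
trainer = pl.Trainer(strategy="fsdp", accelerator="gpu", devices=2, precision="bf16-mixed")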