[fix] add patch to fix FSDPStrategy checkpoint issue
This commit is contained in:
parent
5392a845f7
commit
70ff2acaf0
|
@ -0,0 +1,51 @@
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
import pytorch_lightning as pl
|
||||||
|
from torch.nn import Module
|
||||||
|
|
||||||
|
|
||||||
|
class FSDPStrategy(pl.strategies.FSDPStrategy):
|
||||||
|
@property
|
||||||
|
def model(self) -> Optional[Module]:
|
||||||
|
"""Returns the potentially wrapped LightningModule."""
|
||||||
|
return self._model
|
||||||
|
|
||||||
|
@model.setter
|
||||||
|
def model(self, new_model: Optional[Module]) -> None:
|
||||||
|
self._model = new_model
|
||||||
|
|
||||||
|
def lightning_module_state_dict(self) -> Dict[str, Any]:
|
||||||
|
"""Returns model state."""
|
||||||
|
if self.model is None:
|
||||||
|
assert self.lightning_module is not None
|
||||||
|
return self.lightning_module.state_dict()
|
||||||
|
else:
|
||||||
|
prefix = "_forward_module."
|
||||||
|
state_dict = self.model.state_dict()
|
||||||
|
state_dict = {k[len(prefix) :]: v for k, v in state_dict.items()}
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def register_strategies(cls, strategy_registry: Dict) -> None:
|
||||||
|
if not pl.strategies.fsdp._fsdp_available:
|
||||||
|
return
|
||||||
|
strategy_registry.register(
|
||||||
|
"fsdp",
|
||||||
|
cls,
|
||||||
|
description="Fully Sharded Data Parallel (FSDP) training",
|
||||||
|
override=True,
|
||||||
|
)
|
||||||
|
cls._registered_strategies.append("fsdp")
|
||||||
|
|
||||||
|
strategy_registry.register(
|
||||||
|
"fsdp_cpu_offload",
|
||||||
|
cls,
|
||||||
|
description="Fully Sharded Data Parallel (FSDP) training with Full Sharding and CPU Offloading",
|
||||||
|
cpu_offload=True,
|
||||||
|
override=True,
|
||||||
|
)
|
||||||
|
cls._registered_strategies.append("fsdp_cpu_offload")
|
||||||
|
|
||||||
|
|
||||||
|
def apply_all_patches():
|
||||||
|
FSDPStrategy.register_strategies(pl.strategies.StrategyRegistry)
|
|
@ -15,6 +15,7 @@ from transformers import (
|
||||||
)
|
)
|
||||||
|
|
||||||
from lit_module import LitModule
|
from lit_module import LitModule
|
||||||
|
from lit_patches import apply_all_patches
|
||||||
from utils import load_tokenizer
|
from utils import load_tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
@ -189,6 +190,7 @@ if __name__ == '__main__':
|
||||||
)
|
)
|
||||||
|
|
||||||
# trainer
|
# trainer
|
||||||
|
apply_all_patches()
|
||||||
torch.set_float32_matmul_precision('medium')
|
torch.set_float32_matmul_precision('medium')
|
||||||
if args.bf16:
|
if args.bf16:
|
||||||
precision = 'bf16-mixed'
|
precision = 'bf16-mixed'
|
||||||
|
|
Loading…
Reference in New Issue