[fix] add patch to fix DeepSpeedStrategy offload 'zero_force_ds_cpu_optimizer' issue
This commit is contained in:
parent
8a5e2043bb
commit
5e6b747baf
|
@ -68,6 +68,18 @@ class LitModule(pl.LightningModule):
|
||||||
self.log('accuracy', self.metric_accuracy, rank_zero_only=True)
|
self.log('accuracy', self.metric_accuracy, rank_zero_only=True)
|
||||||
|
|
||||||
def configure_optimizers(self):
|
def configure_optimizers(self):
|
||||||
|
strategy = self.trainer.strategy
|
||||||
|
if isinstance(strategy, pl.strategies.DeepSpeedStrategy):
|
||||||
|
assert "optimizer" not in strategy.config
|
||||||
|
zero_config = strategy.config.get("zero_optimization")
|
||||||
|
if zero_config is not None:
|
||||||
|
if "offload_optimizer" in zero_config:
|
||||||
|
import deepspeed
|
||||||
|
|
||||||
|
optimizer = deepspeed.ops.adam.DeepSpeedCPUAdam(
|
||||||
|
self.trainer.model.parameters(), lr=self.learning_rate
|
||||||
|
)
|
||||||
|
return optimizer
|
||||||
optimizer = torch.optim.AdamW(
|
optimizer = torch.optim.AdamW(
|
||||||
self.trainer.model.parameters(), lr=self.learning_rate
|
self.trainer.model.parameters(), lr=self.learning_rate
|
||||||
)
|
)
|
||||||
|
|
133
lit_patches.py
133
lit_patches.py
|
@ -1,4 +1,4 @@
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional, Union
|
||||||
|
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
from torch.nn import Module
|
from torch.nn import Module
|
||||||
|
@ -47,5 +47,134 @@ class FSDPStrategy(pl.strategies.FSDPStrategy):
|
||||||
cls._registered_strategies.append("fsdp_cpu_offload")
|
cls._registered_strategies.append("fsdp_cpu_offload")
|
||||||
|
|
||||||
|
|
||||||
def apply_all_patches():
|
class DeepSpeedStrategy(pl.strategies.DeepSpeedStrategy):
|
||||||
|
def _create_default_config(
|
||||||
|
self,
|
||||||
|
zero_optimization: bool,
|
||||||
|
zero_allow_untested_optimizer: bool,
|
||||||
|
logging_batch_size_per_gpu: Union[str, int],
|
||||||
|
partition_activations: bool,
|
||||||
|
cpu_checkpointing: bool,
|
||||||
|
contiguous_memory_optimization: bool,
|
||||||
|
synchronize_checkpoint_boundary: bool,
|
||||||
|
offload_optimizer: bool,
|
||||||
|
offload_parameters: bool,
|
||||||
|
nvme_path: str,
|
||||||
|
offload_params_device: str,
|
||||||
|
params_buffer_count: int,
|
||||||
|
params_buffer_size: int,
|
||||||
|
max_in_cpu: int,
|
||||||
|
offload_optimizer_device: str,
|
||||||
|
optimizer_buffer_count: int,
|
||||||
|
pin_memory: bool,
|
||||||
|
block_size: int,
|
||||||
|
queue_depth: int,
|
||||||
|
single_submit: bool,
|
||||||
|
overlap_events: bool,
|
||||||
|
thread_count: int,
|
||||||
|
**zero_kwargs: Any,
|
||||||
|
) -> Dict:
|
||||||
|
cfg = super()._create_default_config(
|
||||||
|
zero_optimization,
|
||||||
|
zero_allow_untested_optimizer,
|
||||||
|
logging_batch_size_per_gpu,
|
||||||
|
partition_activations,
|
||||||
|
cpu_checkpointing,
|
||||||
|
contiguous_memory_optimization,
|
||||||
|
synchronize_checkpoint_boundary,
|
||||||
|
offload_optimizer,
|
||||||
|
offload_parameters,
|
||||||
|
nvme_path,
|
||||||
|
offload_params_device,
|
||||||
|
params_buffer_count,
|
||||||
|
params_buffer_size,
|
||||||
|
max_in_cpu,
|
||||||
|
offload_optimizer_device,
|
||||||
|
optimizer_buffer_count,
|
||||||
|
pin_memory,
|
||||||
|
block_size,
|
||||||
|
queue_depth,
|
||||||
|
single_submit,
|
||||||
|
overlap_events,
|
||||||
|
thread_count,
|
||||||
|
**zero_kwargs,
|
||||||
|
)
|
||||||
|
if zero_optimization:
|
||||||
|
if offload_parameters:
|
||||||
|
cfg = {
|
||||||
|
"zero_force_ds_cpu_optimizer": False,
|
||||||
|
**cfg,
|
||||||
|
}
|
||||||
|
return cfg
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def register_strategies(cls, strategy_registry: Dict) -> None:
|
||||||
|
strategy_registry.register(
|
||||||
|
"deepspeed",
|
||||||
|
cls,
|
||||||
|
description="Default DeepSpeed Strategy",
|
||||||
|
override=True,
|
||||||
|
)
|
||||||
|
strategy_registry.register(
|
||||||
|
"deepspeed_stage_1",
|
||||||
|
cls,
|
||||||
|
description="DeepSpeed with ZeRO Stage 1 enabled",
|
||||||
|
stage=1,
|
||||||
|
override=True,
|
||||||
|
)
|
||||||
|
strategy_registry.register(
|
||||||
|
"deepspeed_stage_2",
|
||||||
|
cls,
|
||||||
|
description="DeepSpeed with ZeRO Stage 2 enabled",
|
||||||
|
stage=2,
|
||||||
|
override=True,
|
||||||
|
)
|
||||||
|
strategy_registry.register(
|
||||||
|
"deepspeed_stage_2_offload",
|
||||||
|
cls,
|
||||||
|
description="DeepSpeed ZeRO Stage 2 and CPU Offload",
|
||||||
|
stage=2,
|
||||||
|
offload_optimizer=True,
|
||||||
|
override=True,
|
||||||
|
)
|
||||||
|
strategy_registry.register(
|
||||||
|
"deepspeed_stage_3",
|
||||||
|
cls,
|
||||||
|
description="DeepSpeed ZeRO Stage 3",
|
||||||
|
stage=3,
|
||||||
|
override=True,
|
||||||
|
)
|
||||||
|
strategy_registry.register(
|
||||||
|
"deepspeed_stage_3_offload",
|
||||||
|
cls,
|
||||||
|
description="DeepSpeed ZeRO Stage 3 and CPU Offload",
|
||||||
|
stage=3,
|
||||||
|
offload_optimizer=True,
|
||||||
|
offload_parameters=True,
|
||||||
|
override=True,
|
||||||
|
)
|
||||||
|
strategy_registry.register(
|
||||||
|
"deepspeed_stage_3_offload_nvme",
|
||||||
|
cls,
|
||||||
|
description="DeepSpeed ZeRO Stage 3 and NVMe Offload",
|
||||||
|
stage=3,
|
||||||
|
offload_optimizer=True,
|
||||||
|
offload_parameters=True,
|
||||||
|
remote_device="nvme",
|
||||||
|
offload_params_device="nvme",
|
||||||
|
offload_optimizer_device="nvme",
|
||||||
|
override=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_fsdp_strategy_patch():
|
||||||
FSDPStrategy.register_strategies(pl.strategies.StrategyRegistry)
|
FSDPStrategy.register_strategies(pl.strategies.StrategyRegistry)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_deepspeed_strategy_patch():
|
||||||
|
DeepSpeedStrategy.register_strategies(pl.strategies.StrategyRegistry)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_all_patches():
|
||||||
|
apply_fsdp_strategy_patch()
|
||||||
|
apply_deepspeed_strategy_patch()
|
||||||
|
|
|
@ -150,7 +150,7 @@ def parse_args():
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--seed",
|
"--seed",
|
||||||
type=str,
|
type=int,
|
||||||
help="Random seed",
|
help="Random seed",
|
||||||
default=42,
|
default=42,
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in New Issue