[fix] add patch to fix DeepSpeedStrategy offload 'zero_force_ds_cpu_optimizer' issue

This commit is contained in:
Yiqing-Zhou 2023-05-09 23:00:28 +08:00
parent 8a5e2043bb
commit 5e6b747baf
3 changed files with 144 additions and 3 deletions

View File

@ -68,6 +68,18 @@ class LitModule(pl.LightningModule):
self.log('accuracy', self.metric_accuracy, rank_zero_only=True) self.log('accuracy', self.metric_accuracy, rank_zero_only=True)
def configure_optimizers(self): def configure_optimizers(self):
strategy = self.trainer.strategy
if isinstance(strategy, pl.strategies.DeepSpeedStrategy):
assert "optimizer" not in strategy.config
zero_config = strategy.config.get("zero_optimization")
if zero_config is not None:
if "offload_optimizer" in zero_config:
import deepspeed
optimizer = deepspeed.ops.adam.DeepSpeedCPUAdam(
self.trainer.model.parameters(), lr=self.learning_rate
)
return optimizer
optimizer = torch.optim.AdamW( optimizer = torch.optim.AdamW(
self.trainer.model.parameters(), lr=self.learning_rate self.trainer.model.parameters(), lr=self.learning_rate
) )

View File

@ -1,4 +1,4 @@
from typing import Any, Dict, Optional from typing import Any, Dict, Optional, Union
import pytorch_lightning as pl import pytorch_lightning as pl
from torch.nn import Module from torch.nn import Module
@ -47,5 +47,134 @@ class FSDPStrategy(pl.strategies.FSDPStrategy):
cls._registered_strategies.append("fsdp_cpu_offload") cls._registered_strategies.append("fsdp_cpu_offload")
class DeepSpeedStrategy(pl.strategies.DeepSpeedStrategy):
    """Patched variant of Lightning's ``DeepSpeedStrategy``.

    Two things differ from the parent class:

    * ``_create_default_config`` adds ``"zero_force_ds_cpu_optimizer": False``
      to the generated config when both ZeRO optimization and parameter
      offload are requested.
    * ``register_strategies`` re-registers the ``deepspeed*`` strategy names
      with ``override=True`` so they resolve to this class.
    """

    def _create_default_config(
        self,
        zero_optimization: bool,
        zero_allow_untested_optimizer: bool,
        logging_batch_size_per_gpu: Union[str, int],
        partition_activations: bool,
        cpu_checkpointing: bool,
        contiguous_memory_optimization: bool,
        synchronize_checkpoint_boundary: bool,
        offload_optimizer: bool,
        offload_parameters: bool,
        nvme_path: str,
        offload_params_device: str,
        params_buffer_count: int,
        params_buffer_size: int,
        max_in_cpu: int,
        offload_optimizer_device: str,
        optimizer_buffer_count: int,
        pin_memory: bool,
        block_size: int,
        queue_depth: int,
        single_submit: bool,
        overlap_events: bool,
        thread_count: int,
        **zero_kwargs: Any,
    ) -> Dict:
        """Build the default config and patch in the CPU-optimizer override.

        Every argument is forwarded unchanged (and in the same order) to the
        parent implementation. When ``zero_optimization`` and
        ``offload_parameters`` are both truthy, the result is a new dict that
        starts with ``"zero_force_ds_cpu_optimizer": False`` and is then
        overlaid with the parent's config, so a value already present in the
        parent's config takes precedence.

        Returns:
            The parent's config, or the patched copy described above.
        """
        base_config = super()._create_default_config(
            zero_optimization,
            zero_allow_untested_optimizer,
            logging_batch_size_per_gpu,
            partition_activations,
            cpu_checkpointing,
            contiguous_memory_optimization,
            synchronize_checkpoint_boundary,
            offload_optimizer,
            offload_parameters,
            nvme_path,
            offload_params_device,
            params_buffer_count,
            params_buffer_size,
            max_in_cpu,
            offload_optimizer_device,
            optimizer_buffer_count,
            pin_memory,
            block_size,
            queue_depth,
            single_submit,
            overlap_events,
            thread_count,
            **zero_kwargs,
        )
        if not (zero_optimization and offload_parameters):
            return base_config

        # Seed with the override first so keys from the parent config win.
        patched_config = {"zero_force_ds_cpu_optimizer": False}
        patched_config.update(base_config)
        return patched_config

    @classmethod
    def register_strategies(cls, strategy_registry: Dict) -> None:
        """Register this class under every ``deepspeed*`` strategy name.

        Each entry is registered with ``override=True``; the extra keyword
        arguments listed per entry are passed through to
        ``strategy_registry.register`` unchanged.
        """
        # (name, description, extra keyword arguments), in registration order.
        presets = (
            ("deepspeed", "Default DeepSpeed Strategy", {}),
            (
                "deepspeed_stage_1",
                "DeepSpeed with ZeRO Stage 1 enabled",
                {"stage": 1},
            ),
            (
                "deepspeed_stage_2",
                "DeepSpeed with ZeRO Stage 2 enabled",
                {"stage": 2},
            ),
            (
                "deepspeed_stage_2_offload",
                "DeepSpeed ZeRO Stage 2 and CPU Offload",
                {"stage": 2, "offload_optimizer": True},
            ),
            ("deepspeed_stage_3", "DeepSpeed ZeRO Stage 3", {"stage": 3}),
            (
                "deepspeed_stage_3_offload",
                "DeepSpeed ZeRO Stage 3 and CPU Offload",
                {
                    "stage": 3,
                    "offload_optimizer": True,
                    "offload_parameters": True,
                },
            ),
            (
                "deepspeed_stage_3_offload_nvme",
                "DeepSpeed ZeRO Stage 3 and NVMe Offload",
                {
                    "stage": 3,
                    "offload_optimizer": True,
                    "offload_parameters": True,
                    "remote_device": "nvme",
                    "offload_params_device": "nvme",
                    "offload_optimizer_device": "nvme",
                },
            ),
        )
        for preset_name, preset_description, extra_kwargs in presets:
            strategy_registry.register(
                preset_name,
                cls,
                description=preset_description,
                **extra_kwargs,
                override=True,
            )
def apply_fsdp_strategy_patch():
    """Register the patched ``FSDPStrategy`` with Lightning's strategy registry."""
    registry = pl.strategies.StrategyRegistry
    FSDPStrategy.register_strategies(registry)
def apply_deepspeed_strategy_patch():
    """Register the patched ``DeepSpeedStrategy`` with Lightning's strategy registry."""
    registry = pl.strategies.StrategyRegistry
    DeepSpeedStrategy.register_strategies(registry)
def apply_all_patches():
    """Apply every strategy patch: FSDP first, then DeepSpeed."""
    for apply_patch in (apply_fsdp_strategy_patch, apply_deepspeed_strategy_patch):
        apply_patch()

View File

@ -150,7 +150,7 @@ def parse_args():
) )
parser.add_argument( parser.add_argument(
"--seed", "--seed",
type=str, type=int,
help="Random seed", help="Random seed",
default=42, default=42,
) )