[feature] persistent_workers; new args accumulate_grad_batches strategy
This commit is contained in:
parent
45fa065530
commit
9b7f9b9d60
19
train.py
19
train.py
|
@ -8,7 +8,6 @@ import datasets
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
import torchmetrics
|
import torchmetrics
|
||||||
from pytorch_lightning import strategies
|
|
||||||
from torch.utils.data import ConcatDataset, DataLoader
|
from torch.utils.data import ConcatDataset, DataLoader
|
||||||
from transformers import (
|
from transformers import (
|
||||||
AutoConfig,
|
AutoConfig,
|
||||||
|
@ -131,12 +130,24 @@ def parse_args():
|
||||||
help="Batch size of validating",
|
help="Batch size of validating",
|
||||||
default=16,
|
default=16,
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--accumulate_grad_batches",
|
||||||
|
type=int,
|
||||||
|
help="Accumulate grad batches",
|
||||||
|
default=32,
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--num_proc",
|
"--num_proc",
|
||||||
type=str,
|
type=str,
|
||||||
help="Number of data processes",
|
help="Number of data processes",
|
||||||
default=16,
|
default=16,
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--strategy",
|
||||||
|
type=str,
|
||||||
|
help="Name of pytorch lightning distribution strategy",
|
||||||
|
default='fsdp',
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--resume_from_ckpt_path",
|
"--resume_from_ckpt_path",
|
||||||
type=str,
|
type=str,
|
||||||
|
@ -256,6 +267,7 @@ if __name__ == '__main__':
|
||||||
batch_size=args.train_batch_size,
|
batch_size=args.train_batch_size,
|
||||||
num_workers=args.num_proc,
|
num_workers=args.num_proc,
|
||||||
collate_fn=DefaultDataCollator(),
|
collate_fn=DefaultDataCollator(),
|
||||||
|
persistent_workers=True,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
)
|
)
|
||||||
val_dataloader = DataLoader(
|
val_dataloader = DataLoader(
|
||||||
|
@ -263,6 +275,7 @@ if __name__ == '__main__':
|
||||||
batch_size=args.val_batch_size,
|
batch_size=args.val_batch_size,
|
||||||
num_workers=args.num_proc,
|
num_workers=args.num_proc,
|
||||||
collate_fn=DefaultDataCollator(),
|
collate_fn=DefaultDataCollator(),
|
||||||
|
persistent_workers=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# trainer
|
# trainer
|
||||||
|
@ -277,8 +290,8 @@ if __name__ == '__main__':
|
||||||
accelerator='gpu',
|
accelerator='gpu',
|
||||||
precision=precision,
|
precision=precision,
|
||||||
log_every_n_steps=5,
|
log_every_n_steps=5,
|
||||||
accumulate_grad_batches=32,
|
accumulate_grad_batches=args.accumulate_grad_batches,
|
||||||
strategy=strategies.DeepSpeedStrategy(stage=3),
|
strategy=args.strategy,
|
||||||
)
|
)
|
||||||
lit_trainer.fit(
|
lit_trainer.fit(
|
||||||
lit_module,
|
lit_module,
|
||||||
|
|
Loading…
Reference in New Issue