[feature] persistent_workers; new args accumulate_grad_batches strategy
This commit is contained in:
parent
45fa065530
commit
9b7f9b9d60
19
train.py
19
train.py
|
@ -8,7 +8,6 @@ import datasets
|
|||
import pytorch_lightning as pl
|
||||
import torch
|
||||
import torchmetrics
|
||||
from pytorch_lightning import strategies
|
||||
from torch.utils.data import ConcatDataset, DataLoader
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
|
@ -131,12 +130,24 @@ def parse_args():
|
|||
help="Batch size of validating",
|
||||
default=16,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--accumulate_grad_batches",
|
||||
type=int,
|
||||
help="Accumulate grad batches",
|
||||
default=32,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_proc",
|
||||
type=str,
|
||||
help="Number of data processes",
|
||||
default=16,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--strategy",
|
||||
type=str,
|
||||
help="Name of pytorch lightning distribution strategy",
|
||||
default='fsdp',
|
||||
)
|
||||
parser.add_argument(
|
||||
"--resume_from_ckpt_path",
|
||||
type=str,
|
||||
|
@ -256,6 +267,7 @@ if __name__ == '__main__':
|
|||
batch_size=args.train_batch_size,
|
||||
num_workers=args.num_proc,
|
||||
collate_fn=DefaultDataCollator(),
|
||||
persistent_workers=True,
|
||||
shuffle=True,
|
||||
)
|
||||
val_dataloader = DataLoader(
|
||||
|
@ -263,6 +275,7 @@ if __name__ == '__main__':
|
|||
batch_size=args.val_batch_size,
|
||||
num_workers=args.num_proc,
|
||||
collate_fn=DefaultDataCollator(),
|
||||
persistent_workers=True,
|
||||
)
|
||||
|
||||
# trainer
|
||||
|
@ -277,8 +290,8 @@ if __name__ == '__main__':
|
|||
accelerator='gpu',
|
||||
precision=precision,
|
||||
log_every_n_steps=5,
|
||||
accumulate_grad_batches=32,
|
||||
strategy=strategies.DeepSpeedStrategy(stage=3),
|
||||
accumulate_grad_batches=args.accumulate_grad_batches,
|
||||
strategy=args.strategy,
|
||||
)
|
||||
lit_trainer.fit(
|
||||
lit_module,
|
||||
|
|
Loading…
Reference in New Issue