[feature] new args learning_rate max_epochs

Author: Yiqing-Zhou
Date:   2023-05-09 00:02:29 +08:00
parent dc2941f790
commit 3f92bbbaa2

2 changed files with 26 additions and 3 deletions


@@ -9,10 +9,16 @@ from utils import init_model
 class LitModule(pl.LightningModule):
-    def __init__(self, model_name: str, use_tril_attention_mask: str = False):
+    def __init__(
+        self,
+        model_name: str,
+        learning_rate: float = 0.0001,
+        use_tril_attention_mask: str = False,
+    ):
         super().__init__()
         self.save_hyperparameters()
         self.llm = self.register_core_module(init_model(model_name))
+        self.learning_rate = learning_rate
         self.use_tril_attention_mask = use_tril_attention_mask
         self.metric_loss = torchmetrics.MeanMetric()
         self.metric_accuracy = torchmetrics.Accuracy(
@@ -62,7 +68,9 @@ class LitModule(pl.LightningModule):
         self.log('accuracy', self.metric_accuracy, rank_zero_only=True)
 
     def configure_optimizers(self):
-        optimizer = torch.optim.AdamW(self.trainer.model.parameters(), lr=0.0001)
+        optimizer = torch.optim.AdamW(
+            self.trainer.model.parameters(), lr=self.learning_rate
+        )
         return optimizer
 
     def configure_callbacks(self):
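The pattern in these two hunks — accept learning_rate in __init__, record it with save_hyperparameters(), read it back in configure_optimizers() — can be exercised in isolation. The sketch below uses a hypothetical TinyLitModule, not the repo's LitModule, and assumes pytorch_lightning is installed and imported as pl; it also calls self.parameters() instead of self.trainer.model.parameters() so it runs without an attached Trainer.

    # Minimal sketch (hypothetical TinyLitModule, not the repo's LitModule):
    # learning_rate flows from __init__ through save_hyperparameters() into
    # configure_optimizers(), replacing the previously hard-coded lr=0.0001.
    import torch
    import pytorch_lightning as pl

    class TinyLitModule(pl.LightningModule):
        def __init__(self, learning_rate: float = 0.0001):
            super().__init__()
            self.save_hyperparameters()           # records learning_rate into self.hparams
            self.layer = torch.nn.Linear(4, 2)    # stand-in for the real LLM

        def training_step(self, batch, batch_idx):
            x, y = batch
            return torch.nn.functional.mse_loss(self.layer(x), y)

        def configure_optimizers(self):
            # self.parameters() here rather than self.trainer.model.parameters(),
            # so this method works without an attached Trainer
            return torch.optim.AdamW(self.parameters(), lr=self.hparams.learning_rate)

    module = TinyLitModule(learning_rate=3e-4)
    optimizer = module.configure_optimizers()
    print(optimizer.param_groups[0]["lr"])        # 0.0003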


@@ -80,6 +80,12 @@ def parse_args():
         help="Name of or path to model",
         default='gpt2',
     )
+    parser.add_argument(
+        "--learning_rate",
+        type=float,
+        help="Learning rate",
+        default=0.0001,
+    )
     parser.add_argument(
         "--use_tril_attention_mask",
         help="Use tril attention mask during training",
@@ -124,6 +130,12 @@ def parse_args():
         help="Number of data processes",
         default=16,
     )
+    parser.add_argument(
+        "--max_epochs",
+        type=int,
+        help="Max epochs",
+        default=None,
+    )
     parser.add_argument(
         "--strategy",
         type=str,
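A quick way to see how the two new flags behave is to feed a throwaway parser the same argument definitions. The snippet below is illustrative only: it rebuilds just these two arguments, not the script's full parse_args().

    # Illustrative only: rebuilds just the two new arguments, not the full parser.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--learning_rate", type=float, help="Learning rate", default=0.0001)
    parser.add_argument("--max_epochs", type=int, help="Max epochs", default=None)

    args = parser.parse_args(["--learning_rate", "3e-4", "--max_epochs", "5"])
    print(args.learning_rate, args.max_epochs)   # 0.0003 5

    args = parser.parse_args([])                 # both flags are optional
    print(args.learning_rate, args.max_epochs)   # 0.0001 None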
@@ -155,7 +167,9 @@ if __name__ == '__main__':
     set_seed(args.seed)
 
     # lightning module
-    lit_module = LitModule(args.model_name, args.use_tril_attention_mask)
+    lit_module = LitModule(
+        args.model_name, args.learning_rate, args.use_tril_attention_mask
+    )
 
     # datasets
     tokenizer = load_tokenizer(args.tokenizer_name_or_path)
@@ -204,6 +218,7 @@ if __name__ == '__main__':
         log_every_n_steps=5,
         accumulate_grad_batches=args.accumulate_grad_batches,
         strategy=args.strategy,
+        max_epochs=args.max_epochs,
     )
     lit_trainer.fit(
         lit_module,
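Passing args.max_epochs straight through keeps the flag optional: with the default of None, the Trainer falls back to its own stopping behaviour (recent Lightning versions default to 1000 epochs when max_steps is not set). A small sketch, assuming pytorch_lightning is importable as pl and ignoring the other Trainer arguments:

    # Sketch only: shows how a parsed max_epochs value reaches the Trainer.
    import pytorch_lightning as pl

    trainer = pl.Trainer(max_epochs=5, log_every_n_steps=5)   # e.g. max_epochs=args.max_epochs
    print(trainer.max_epochs)                                  # 5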