[feature] new args learning_rate max_epochs
parent dc2941f790
commit 3f92bbbaa2
@@ -9,10 +9,16 @@ from utils import init_model
 
 
 class LitModule(pl.LightningModule):
-    def __init__(self, model_name: str, use_tril_attention_mask: str = False):
+    def __init__(
+        self,
+        model_name: str,
+        learning_rate: float = 0.0001,
+        use_tril_attention_mask: str = False,
+    ):
         super().__init__()
         self.save_hyperparameters()
         self.llm = self.register_core_module(init_model(model_name))
+        self.learning_rate = learning_rate
         self.use_tril_attention_mask = use_tril_attention_mask
         self.metric_loss = torchmetrics.MeanMetric()
         self.metric_accuracy = torchmetrics.Accuracy(
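Because the constructor calls save_hyperparameters(), the new learning_rate argument is captured into self.hparams and checkpointed automatically. A minimal sketch of that behavior, using a hypothetical stand-in module (the repo's init_model/register_core_module helpers are not reproduced, and the pytorch_lightning import name is an assumption). Note the diff keeps the annotation use_tril_attention_mask: str = False; the str looks like a typo for bool, and the sketch uses bool.

import pytorch_lightning as pl  # assumed import name
import torch

class TinyModule(pl.LightningModule):  # hypothetical stand-in for LitModule
    def __init__(self, learning_rate: float = 0.0001, use_tril_attention_mask: bool = False):
        super().__init__()
        self.save_hyperparameters()         # records both args into self.hparams
        self.layer = torch.nn.Linear(4, 4)  # placeholder for init_model(...)
        self.learning_rate = learning_rate

module = TinyModule(learning_rate=3e-4)
print(module.hparams.learning_rate)  # 0.0003, restored again via load_from_checkpoint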
@@ -62,7 +68,9 @@ class LitModule(pl.LightningModule):
         self.log('accuracy', self.metric_accuracy, rank_zero_only=True)
 
     def configure_optimizers(self):
-        optimizer = torch.optim.AdamW(self.trainer.model.parameters(), lr=0.0001)
+        optimizer = torch.optim.AdamW(
+            self.trainer.model.parameters(), lr=self.learning_rate
+        )
         return optimizer
 
     def configure_callbacks(self):
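This replaces the hard-coded lr=0.0001 with the constructor value. Since save_hyperparameters() already stores it, the hook could equally read it back from self.hparams; a sketch of that variant, under the same assumptions as above (the commit itself keeps a separate attribute and keeps self.trainer.model.parameters(), which targets the strategy-wrapped model rather than the bare module):

    def configure_optimizers(self):
        # equivalent sketch: read the saved hyperparameter instead of the attribute
        return torch.optim.AdamW(self.parameters(), lr=self.hparams.learning_rate)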
lit_train.py (17 lines changed)
@@ -80,6 +80,12 @@ def parse_args():
         help="Name of or path to model",
         default='gpt2',
     )
+    parser.add_argument(
+        "--learning_rate",
+        type=float,
+        help="Learning rate",
+        default=0.0001,
+    )
     parser.add_argument(
         "--use_tril_attention_mask",
         help="Use tril attention mask during training",
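With type=float, argparse converts the raw string itself, so scientific notation works on the command line. A self-contained sketch of just this flag (hypothetical parser, not lit_train.py):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--learning_rate", type=float, help="Learning rate", default=0.0001)

args = parser.parse_args(["--learning_rate", "3e-4"])
print(args.learning_rate)                   # 0.0003: argparse applied float() to the string
print(parser.parse_args([]).learning_rate)  # 0.0001: the default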
@@ -124,6 +130,12 @@ def parse_args():
         help="Number of data processes",
         default=16,
     )
+    parser.add_argument(
+        "--max_epochs",
+        type=int,
+        help="Max epochs",
+        default=None,
+    )
     parser.add_argument(
         "--strategy",
         type=str,
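Here default=None is deliberate: when the flag is omitted, None reaches the Trainer below, which then falls back to Lightning's own epoch limit instead of a value hard-coded in this script. The same pattern in isolation (hypothetical parser again):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--max_epochs", type=int, help="Max epochs", default=None)

print(parser.parse_args([]).max_epochs)                      # None: let the Trainer decide
print(parser.parse_args(["--max_epochs", "10"]).max_epochs)  # 10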
@@ -155,7 +167,9 @@ if __name__ == '__main__':
     set_seed(args.seed)
 
     # lightning module
-    lit_module = LitModule(args.model_name, args.use_tril_attention_mask)
+    lit_module = LitModule(
+        args.model_name, args.learning_rate, args.use_tril_attention_mask
+    )
 
     # datasets
     tokenizer = load_tokenizer(args.tokenizer_name_or_path)
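Because learning_rate was inserted as the second positional parameter, this call only stays correct while the argument order matches the new signature. A keyword-style call (sketch, same names as the diff) would not depend on that ordering:

lit_module = LitModule(
    args.model_name,
    learning_rate=args.learning_rate,
    use_tril_attention_mask=args.use_tril_attention_mask,
)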
@@ -204,6 +218,7 @@ if __name__ == '__main__':
         log_every_n_steps=5,
         accumulate_grad_batches=args.accumulate_grad_batches,
         strategy=args.strategy,
+        max_epochs=args.max_epochs,
     )
     lit_trainer.fit(
         lit_module,
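End to end, the new flag travels from the CLI into the Trainer. A trimmed sketch of that wiring (assumed import name; the script's other Trainer arguments are elided):

import pytorch_lightning as pl  # assumed import name

# e.g. invoked as: python lit_train.py --learning_rate 3e-4 --max_epochs 10
lit_trainer = pl.Trainer(
    log_every_n_steps=5,
    max_epochs=None,  # omitted flag: Lightning falls back to its own default limit
)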