do not use auto optimizer.
parent 990e27ba15
commit f411b1cc5e
@@ -172,6 +172,7 @@ class LightModule(pl.LightningModule):
         mconf = conf.model_config
         use_tril_attention_mask = conf.use_tril_attention_mask
         super().__init__()
+        self.automatic_optimization = False
         self.save_hyperparameters()
         if mconf == None:
             mconf = ModelConfig()
@@ -205,6 +206,19 @@ class LightModule(pl.LightningModule):
         if self.use_tril_attention_mask:
             batch["attention_mask"] = self.get_batch_tril_matrix(block_size, batch_size=batch_size).to(self.device)
         outputs, loss = self.llm(**batch)
+
+        # optimizer = self.optimizers()
+        # for o in optimizer:
+        #     o.zero_grad()
+        # self.manual_backward(loss)
+        # for o in optimizer:
+        #     o.step()
+
+        optimizer = self.optimizers()
+        optimizer.zero_grad()
+        self.manual_backward(loss)
+        optimizer.step()
+
         self.log("train_loss", loss, rank_zero_only=True)
         return loss

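For reference, below is a minimal, self-contained sketch of the manual-optimization pattern the hunk above switches to. It is not code from this repository: the ToyModule class, its linear layer, the loss, and the learning rate are stand-ins; only the automatic_optimization flag, self.optimizers(), and self.manual_backward() follow what the diff does.

import torch
import pytorch_lightning as pl

class ToyModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # Opt out of Lightning's automatic optimization loop.
        self.automatic_optimization = False
        self.net = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self.net(x), y)
        opt = self.optimizers()      # optimizer returned by configure_optimizers
        opt.zero_grad()
        self.manual_backward(loss)   # replaces loss.backward() under manual optimization
        opt.step()
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=1e-3)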
@@ -232,6 +246,24 @@ class LightModule(pl.LightningModule):
         optimizer = torch.optim.AdamW(self.trainer.model.parameters(), lr=self.learning_rate)
         return optimizer
+
+        # p = self.trainer.model.named_parameters()
+        # p0 = []
+        # p1 = []
+        # p2 = []
+        # for name, param in p:
+        #     if "h.0" in name or ".wte" in name:
+        #         p0.append(param)
+        #         continue
+        #     if "h.1" in name:
+        #         p1.append(param)
+        #         continue
+        #     p2.append(param)
+
+        # optimizer0 = torch.optim.AdamW(p0, lr=self.learning_rate / 1.5)
+        # optimizer1 = torch.optim.AdamW(p1, lr=self.learning_rate)
+        # optimizer2 = torch.optim.AdamW(p2, lr=self.learning_rate * 1.5)
+        # return [optimizer0, optimizer1, optimizer2]

     def configure_callbacks(self):
         checkpoint_callback = pl.callbacks.ModelCheckpoint(
             monitor="accuracy",
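The commented-out block in configure_optimizers sketches per-layer learning rates by returning three separate optimizers. An alternative sketch of the same idea, not what this commit ships, is a single AdamW with parameter groups; the model below is a stand-in, and the "wte" / "h.0" / "h.1" name filters only mirror the commented-out code.

import torch

# Stand-in model: an embedding plus a small stack of blocks, so that parameter
# names contain "wte", "h.0", "h.1", ... as in the commented-out filters.
model = torch.nn.ModuleDict({
    "wte": torch.nn.Embedding(100, 8),
    "h": torch.nn.ModuleList([torch.nn.Linear(8, 8) for _ in range(3)]),
})

base_lr = 1e-3
p0, p1, p2 = [], [], []
for name, param in model.named_parameters():
    if "h.0" in name or "wte" in name:
        p0.append(param)   # embedding and first block
    elif "h.1" in name:
        p1.append(param)   # second block
    else:
        p2.append(param)   # everything else

# One optimizer, three parameter groups with scaled learning rates.
optimizer = torch.optim.AdamW([
    {"params": p0, "lr": base_lr / 1.5},
    {"params": p1, "lr": base_lr},
    {"params": p2, "lr": base_lr * 1.5},
])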