diff --git a/wit/model/light_module.py b/wit/model/light_module.py
index e766746..767f4be 100644
--- a/wit/model/light_module.py
+++ b/wit/model/light_module.py
@@ -172,6 +172,7 @@ class LightModule(pl.LightningModule):
         mconf = conf.model_config
         use_tril_attention_mask = conf.use_tril_attention_mask
         super().__init__()
+        self.automatic_optimization = False
         self.save_hyperparameters()
         if mconf == None:
             mconf = ModelConfig()
@@ -205,6 +206,19 @@ class LightModule(pl.LightningModule):
         if self.use_tril_attention_mask:
             batch["attention_mask"] = self.get_batch_tril_matrix(block_size, batch_size=batch_size).to(self.device)
         outputs, loss = self.llm(**batch)
+
+        # optimizer = self.optimizers()
+        # for o in optimizer:
+        #     o.zero_grad()
+        # self.manual_backward(loss)
+        # for o in optimizer:
+        #     o.step()
+
+        optimizer = self.optimizers()
+        optimizer.zero_grad()
+        self.manual_backward(loss)
+        optimizer.step()
+
         self.log("train_loss", loss, rank_zero_only=True)
         return loss
 
@@ -232,6 +246,24 @@ class LightModule(pl.LightningModule):
             optimizer = torch.optim.AdamW(self.trainer.model.parameters(), lr=self.learning_rate)
         return optimizer
 
+        # p = self.trainer.model.named_parameters()
+        # p0 = []
+        # p1 = []
+        # p2 = []
+        # for name, param in p:
+        #     if "h.0" in name or ".wte" in name:
+        #         p0.append(param)
+        #         continue
+        #     if "h.1" in name:
+        #         p1.append(param)
+        #         continue
+        #     p2.append(param)
+
+        # optimizer0 = torch.optim.AdamW(p0, lr=self.learning_rate / 1.5)
+        # optimizer1 = torch.optim.AdamW(p1, lr=self.learning_rate)
+        # optimizer2 = torch.optim.AdamW(p2, lr=self.learning_rate * 1.5)
+        # return [optimizer0, optimizer1, optimizer2]
+
     def configure_callbacks(self):
         checkpoint_callback = pl.callbacks.ModelCheckpoint(
             monitor="accuracy",
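
For context, here is a minimal sketch of how the commented-out multi-optimizer variant would fit together once self.automatic_optimization = False is set: configure_optimizers returns several optimizers with scaled learning rates, and training_step drives them by hand via self.optimizers() and self.manual_backward. The grouping keys ("h.0", ".wte", "h.1") and the 1.5x scaling are taken from the commented code in the diff; the class and constructor arguments below are illustrative assumptions, not part of this change.

# Sketch only: manual optimization with per-group optimizers in PyTorch Lightning.
# Assumes a module shaped like LightModule above, with self.llm returning (outputs, loss).
import torch
import pytorch_lightning as pl


class MultiOptSketch(pl.LightningModule):
    def __init__(self, llm, learning_rate):
        super().__init__()
        self.automatic_optimization = False  # required before stepping optimizers manually
        self.llm = llm
        self.learning_rate = learning_rate

    def training_step(self, batch, batch_idx):
        outputs, loss = self.llm(**batch)

        # With several optimizers configured, self.optimizers() returns a list.
        optimizers = self.optimizers()
        for opt in optimizers:
            opt.zero_grad()
        self.manual_backward(loss)  # replaces loss.backward() under manual optimization
        for opt in optimizers:
            opt.step()

        self.log("train_loss", loss, rank_zero_only=True)
        return loss

    def configure_optimizers(self):
        # Split parameters into three groups and give each its own learning rate,
        # mirroring the commented-out grouping by "h.0"/".wte", "h.1", and the rest.
        # (The diff iterates self.trainer.model.named_parameters(); plain
        # self.named_parameters() is used here to keep the sketch self-contained.)
        p0, p1, p2 = [], [], []
        for name, param in self.named_parameters():
            if "h.0" in name or ".wte" in name:
                p0.append(param)
            elif "h.1" in name:
                p1.append(param)
            else:
                p2.append(param)
        optimizer0 = torch.optim.AdamW(p0, lr=self.learning_rate / 1.5)
        optimizer1 = torch.optim.AdamW(p1, lr=self.learning_rate)
        optimizer2 = torch.optim.AdamW(p2, lr=self.learning_rate * 1.5)
        return [optimizer0, optimizer1, optimizer2]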