diff --git a/lit_module.py b/lit_module.py
index c5c6279..8db8a46 100644
--- a/lit_module.py
+++ b/lit_module.py
@@ -8,9 +8,8 @@
 import torchmetrics
 from utils import init_model
 from custom_models.gpt2.modeling_gpt2 import GPT2LMHeadModel
-from transformers import (
-    AutoConfig
-)
+from transformers import AutoConfig
+
 
 class LitModule(pl.LightningModule):
     def __init__(
@@ -22,7 +21,7 @@ class LitModule(pl.LightningModule):
     ):
         super().__init__()
         self.save_hyperparameters()
-        if path != "" :
+        if path != "":
             config = AutoConfig.for_model(model_type=model_name)
             model = GPT2LMHeadModel(config)
             model = model.from_pretrained(path)
@@ -33,7 +32,7 @@ class LitModule(pl.LightningModule):
         self.use_tril_attention_mask = use_tril_attention_mask
         self.metric_loss = torchmetrics.MeanMetric()
         self.metric_accuracy = torchmetrics.Accuracy(
-            task='multiclass',
+            task="multiclass",
             num_classes=self.llm.config.vocab_size,
         )
 
@@ -45,17 +44,17 @@
         return matrix
 
     def register_core_module(self, module: torch.nn.Module) -> torch.nn.Module:
-        object.__setattr__(self, '__core_module__', module)
+        object.__setattr__(self, "__core_module__", module)
         return module
 
     def training_step(self, batch: Dict[str, torch.Tensor], batch_idx):
-        batch_size, block_size = batch['input_ids'].shape
+        batch_size, block_size = batch["input_ids"].shape
         if self.use_tril_attention_mask:
-            batch['attention_mask'] = self.get_batch_tril_matrix(block_size, batch_size=batch_size).to(self.device)
+            batch["attention_mask"] = self.get_batch_tril_matrix(block_size, batch_size=batch_size).to(self.device)
         outputs = self.llm(**batch, return_dict=True)
         loss = outputs.loss
 
-        self.log('train_loss', loss, rank_zero_only=True)
+        self.log("train_loss", loss, rank_zero_only=True)
         return loss
 
 
@@ -63,7 +62,7 @@ class LitModule(pl.LightningModule):
         outputs = self.llm(**batch, return_dict=True)
         loss = outputs.loss
         logits = outputs.logits[..., :-1, :]
-        labels = batch['labels'][..., 1:]
+        labels = batch["labels"][..., 1:]
 
         self.metric_loss.update(loss)
 
@@ -71,8 +70,8 @@
         self.metric_accuracy.update(logits[label_mask], labels[label_mask])
 
     def on_validation_epoch_end(self) -> None:
-        self.log('val_loss', self.metric_loss, rank_zero_only=True)
-        self.log('accuracy', self.metric_accuracy, rank_zero_only=True)
+        self.log("val_loss", self.metric_loss, rank_zero_only=True)
+        self.log("accuracy", self.metric_accuracy, rank_zero_only=True)
 
     def configure_optimizers(self):
         strategy = self.trainer.strategy
@@ -92,15 +91,15 @@
 
     def configure_callbacks(self):
         checkpoint_callback = pl.callbacks.ModelCheckpoint(
-            monitor='accuracy',
-            mode='max',
-            filename='{epoch:02d}-{accuracy:.4f}',
+            monitor="accuracy",
+            mode="max",
+            filename="{epoch:02d}-{accuracy:.4f}",
         )
         early_stop_callback = pl.callbacks.EarlyStopping(
-            monitor='accuracy',
+            monitor="accuracy",
             min_delta=0.001,
             patience=3,
-            mode='max',
+            mode="max",
             stopping_threshold=1,
         )
         return [checkpoint_callback, early_stop_callback]
diff --git a/lit_train.py b/lit_train.py
index 3360e24..c7e9873 100644
--- a/lit_train.py
+++ b/lit_train.py
@@ -179,7 +179,7 @@ if __name__ == "__main__":
     set_seed(args.seed)
 
     # lightning module
-    lit_module = LitModule(args.model_name,"./custom_models/gpt2", args.learning_rate, args.use_tril_attention_mask)
+    lit_module = LitModule(args.model_name, "./custom_models/gpt2", args.learning_rate, args.use_tril_attention_mask)
 
     # datasets
     tokenizer = load_tokenizer(args.tokenizer_name_or_path)
@@ -187,9 +187,7 @@ if __name__ == "__main__":
     val_dataset_list = []
     for dataset_name in args.dataset_name:
         dataset_args = dataset_name.split(":")
-        raw_dataset = datasets.load_dataset(
-            "json", data_files="/home/colin/develop/dataset/liwu/MNBVC/wiki/20230197/0.jsonl.gz"
-        )
+        raw_dataset = datasets.load_dataset("json", data_files="./dataset/58.jsonl.gz")
         # raw_dataset = datasets.load_dataset(*dataset_args)
         train_dataset, val_dataset = split_raw_dataset(raw_dataset)
         train_dataset = process_dataset(train_dataset, tokenizer)