Use local dataset.
commit ac61c4d925
parent 087366c59b
@@ -8,9 +8,8 @@ import torchmetrics
 from utils import init_model
 from custom_models.gpt2.modeling_gpt2 import GPT2LMHeadModel
 
-from transformers import (
-    AutoConfig
-)
+from transformers import AutoConfig
+
 
 class LitModule(pl.LightningModule):
     def __init__(
@@ -22,7 +21,7 @@ class LitModule(pl.LightningModule):
     ):
         super().__init__()
         self.save_hyperparameters()
-        if path != "" :
+        if path != "":
             config = AutoConfig.for_model(model_type=model_name)
             model = GPT2LMHeadModel(config)
             model = model.from_pretrained(path)
@@ -33,7 +32,7 @@ class LitModule(pl.LightningModule):
         self.use_tril_attention_mask = use_tril_attention_mask
         self.metric_loss = torchmetrics.MeanMetric()
         self.metric_accuracy = torchmetrics.Accuracy(
-            task='multiclass',
+            task="multiclass",
             num_classes=self.llm.config.vocab_size,
         )
 
@@ -45,17 +44,17 @@ class LitModule(pl.LightningModule):
         return matrix
 
     def register_core_module(self, module: torch.nn.Module) -> torch.nn.Module:
-        object.__setattr__(self, '__core_module__', module)
+        object.__setattr__(self, "__core_module__", module)
         return module
 
     def training_step(self, batch: Dict[str, torch.Tensor], batch_idx):
-        batch_size, block_size = batch['input_ids'].shape
+        batch_size, block_size = batch["input_ids"].shape
         if self.use_tril_attention_mask:
-            batch['attention_mask'] = self.get_batch_tril_matrix(block_size, batch_size=batch_size).to(self.device)
+            batch["attention_mask"] = self.get_batch_tril_matrix(block_size, batch_size=batch_size).to(self.device)
         outputs = self.llm(**batch, return_dict=True)
         loss = outputs.loss
 
-        self.log('train_loss', loss, rank_zero_only=True)
+        self.log("train_loss", loss, rank_zero_only=True)
 
         return loss
 
@@ -63,7 +62,7 @@ class LitModule(pl.LightningModule):
         outputs = self.llm(**batch, return_dict=True)
         loss = outputs.loss
         logits = outputs.logits[..., :-1, :]
-        labels = batch['labels'][..., 1:]
+        labels = batch["labels"][..., 1:]
 
         self.metric_loss.update(loss)
 
@@ -71,8 +70,8 @@ class LitModule(pl.LightningModule):
         self.metric_accuracy.update(logits[label_mask], labels[label_mask])
 
     def on_validation_epoch_end(self) -> None:
-        self.log('val_loss', self.metric_loss, rank_zero_only=True)
-        self.log('accuracy', self.metric_accuracy, rank_zero_only=True)
+        self.log("val_loss", self.metric_loss, rank_zero_only=True)
+        self.log("accuracy", self.metric_accuracy, rank_zero_only=True)
 
     def configure_optimizers(self):
         strategy = self.trainer.strategy
@@ -92,15 +91,15 @@ class LitModule(pl.LightningModule):
 
     def configure_callbacks(self):
         checkpoint_callback = pl.callbacks.ModelCheckpoint(
-            monitor='accuracy',
-            mode='max',
-            filename='{epoch:02d}-{accuracy:.4f}',
+            monitor="accuracy",
+            mode="max",
+            filename="{epoch:02d}-{accuracy:.4f}",
         )
         early_stop_callback = pl.callbacks.EarlyStopping(
-            monitor='accuracy',
+            monitor="accuracy",
             min_delta=0.001,
             patience=3,
-            mode='max',
+            mode="max",
             stopping_threshold=1,
         )
         return [checkpoint_callback, early_stop_callback]
@@ -179,7 +179,7 @@ if __name__ == "__main__":
     set_seed(args.seed)
 
     # lightning module
-    lit_module = LitModule(args.model_name,"./custom_models/gpt2", args.learning_rate, args.use_tril_attention_mask)
+    lit_module = LitModule(args.model_name, "./custom_models/gpt2", args.learning_rate, args.use_tril_attention_mask)
 
     # datasets
     tokenizer = load_tokenizer(args.tokenizer_name_or_path)
@@ -187,9 +187,7 @@ if __name__ == "__main__":
     val_dataset_list = []
     for dataset_name in args.dataset_name:
         dataset_args = dataset_name.split(":")
-        raw_dataset = datasets.load_dataset(
-            "json", data_files="/home/colin/develop/dataset/liwu/MNBVC/wiki/20230197/0.jsonl.gz"
-        )
+        raw_dataset = datasets.load_dataset("json", data_files="./dataset/58.jsonl.gz")
         # raw_dataset = datasets.load_dataset(*dataset_args)
         train_dataset, val_dataset = split_raw_dataset(raw_dataset)
         train_dataset = process_dataset(train_dataset, tokenizer)
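The substantive change in the last hunk is the dataset source: a hard-coded absolute MNBVC path is replaced by a relative local file. A minimal sketch of what that loading call does, assuming a gzipped JSON-lines file exists at ./dataset/58.jsonl.gz:

import datasets

# The "json" builder reads JSON-lines records; .gz archives are
# decompressed transparently by the datasets library.
raw_dataset = datasets.load_dataset("json", data_files="./dataset/58.jsonl.gz")
print(raw_dataset)  # DatasetDict with a single "train" split by default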