Use a local dataset.

Colin 2024-02-24 13:44:22 +08:00
parent 087366c59b
commit ac61c4d925
2 changed files with 18 additions and 21 deletions

View File

@@ -8,9 +8,8 @@ import torchmetrics
 from utils import init_model
 from custom_models.gpt2.modeling_gpt2 import GPT2LMHeadModel
-from transformers import (
-    AutoConfig
-)
+from transformers import AutoConfig

 class LitModule(pl.LightningModule):
     def __init__(
@@ -22,7 +21,7 @@ class LitModule(pl.LightningModule):
     ):
         super().__init__()
         self.save_hyperparameters()
-        if path != "" :
+        if path != "":
             config = AutoConfig.for_model(model_type=model_name)
             model = GPT2LMHeadModel(config)
             model = model.from_pretrained(path)
@@ -33,7 +32,7 @@ class LitModule(pl.LightningModule):
         self.use_tril_attention_mask = use_tril_attention_mask
         self.metric_loss = torchmetrics.MeanMetric()
         self.metric_accuracy = torchmetrics.Accuracy(
-            task='multiclass',
+            task="multiclass",
             num_classes=self.llm.config.vocab_size,
         )
@@ -45,17 +44,17 @@ class LitModule(pl.LightningModule):
         return matrix

     def register_core_module(self, module: torch.nn.Module) -> torch.nn.Module:
-        object.__setattr__(self, '__core_module__', module)
+        object.__setattr__(self, "__core_module__", module)
         return module

     def training_step(self, batch: Dict[str, torch.Tensor], batch_idx):
-        batch_size, block_size = batch['input_ids'].shape
+        batch_size, block_size = batch["input_ids"].shape
         if self.use_tril_attention_mask:
-            batch['attention_mask'] = self.get_batch_tril_matrix(block_size, batch_size=batch_size).to(self.device)
+            batch["attention_mask"] = self.get_batch_tril_matrix(block_size, batch_size=batch_size).to(self.device)
         outputs = self.llm(**batch, return_dict=True)
         loss = outputs.loss
-        self.log('train_loss', loss, rank_zero_only=True)
+        self.log("train_loss", loss, rank_zero_only=True)
         return loss
@@ -63,7 +62,7 @@ class LitModule(pl.LightningModule):
         outputs = self.llm(**batch, return_dict=True)
         loss = outputs.loss
         logits = outputs.logits[..., :-1, :]
-        labels = batch['labels'][..., 1:]
+        labels = batch["labels"][..., 1:]

         self.metric_loss.update(loss)
@@ -71,8 +70,8 @@ class LitModule(pl.LightningModule):
         self.metric_accuracy.update(logits[label_mask], labels[label_mask])

     def on_validation_epoch_end(self) -> None:
-        self.log('val_loss', self.metric_loss, rank_zero_only=True)
-        self.log('accuracy', self.metric_accuracy, rank_zero_only=True)
+        self.log("val_loss", self.metric_loss, rank_zero_only=True)
+        self.log("accuracy", self.metric_accuracy, rank_zero_only=True)

     def configure_optimizers(self):
         strategy = self.trainer.strategy
@@ -92,15 +91,15 @@ class LitModule(pl.LightningModule):
     def configure_callbacks(self):
         checkpoint_callback = pl.callbacks.ModelCheckpoint(
-            monitor='accuracy',
-            mode='max',
-            filename='{epoch:02d}-{accuracy:.4f}',
+            monitor="accuracy",
+            mode="max",
+            filename="{epoch:02d}-{accuracy:.4f}",
         )
         early_stop_callback = pl.callbacks.EarlyStopping(
-            monitor='accuracy',
+            monitor="accuracy",
             min_delta=0.001,
             patience=3,
-            mode='max',
+            mode="max",
             stopping_threshold=1,
         )
         return [checkpoint_callback, early_stop_callback]
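
Note: training_step calls get_batch_tril_matrix, whose body lies outside the hunks shown here (only its trailing "return matrix" appears). The following is a minimal sketch of what such a helper is assumed to do, namely build a lower-triangular (causal) attention mask and repeat it across the batch; it is not the repository's actual implementation.

# Hedged sketch only: the real get_batch_tril_matrix body is not part of this diff.
# Assumption: it returns a (batch_size, block_size, block_size) lower-triangular
# mask of ones, used as a causal attention mask.
import torch

def get_batch_tril_matrix(block_size: int, batch_size: int) -> torch.Tensor:
    matrix = torch.tril(torch.ones(block_size, block_size))  # causal mask for one sample
    return matrix.unsqueeze(0).repeat(batch_size, 1, 1)      # repeat across the batch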

View File

@ -179,7 +179,7 @@ if __name__ == "__main__":
set_seed(args.seed) set_seed(args.seed)
# lightning module # lightning module
lit_module = LitModule(args.model_name,"./custom_models/gpt2", args.learning_rate, args.use_tril_attention_mask) lit_module = LitModule(args.model_name, "./custom_models/gpt2", args.learning_rate, args.use_tril_attention_mask)
# datasets # datasets
tokenizer = load_tokenizer(args.tokenizer_name_or_path) tokenizer = load_tokenizer(args.tokenizer_name_or_path)
@ -187,9 +187,7 @@ if __name__ == "__main__":
val_dataset_list = [] val_dataset_list = []
for dataset_name in args.dataset_name: for dataset_name in args.dataset_name:
dataset_args = dataset_name.split(":") dataset_args = dataset_name.split(":")
raw_dataset = datasets.load_dataset( raw_dataset = datasets.load_dataset("json", data_files="./dataset/58.jsonl.gz")
"json", data_files="/home/colin/develop/dataset/liwu/MNBVC/wiki/20230197/0.jsonl.gz"
)
# raw_dataset = datasets.load_dataset(*dataset_args) # raw_dataset = datasets.load_dataset(*dataset_args)
train_dataset, val_dataset = split_raw_dataset(raw_dataset) train_dataset, val_dataset = split_raw_dataset(raw_dataset)
train_dataset = process_dataset(train_dataset, tokenizer) train_dataset = process_dataset(train_dataset, tokenizer)
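
Note: the local-dataset call introduced above can be exercised on its own; the "json" builder of datasets.load_dataset reads gzipped JSON Lines files transparently and returns a DatasetDict with a single "train" split when no split is requested. The path below is the one from the diff; the print statements are only illustrative and are not part of the commit.

# Quick standalone check of the dataset loading introduced by this commit.
# "./dataset/58.jsonl.gz" is the path from the diff; adjust it to your checkout.
import datasets

raw_dataset = datasets.load_dataset("json", data_files="./dataset/58.jsonl.gz")
print(raw_dataset)              # DatasetDict({"train": Dataset(...)})
print(raw_dataset["train"][0])  # first record of the gzipped JSON Lines file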