set use local dataset.
This commit is contained in:
parent
087366c59b
commit
ac61c4d925
|
@ -8,9 +8,8 @@ import torchmetrics
|
||||||
from utils import init_model
|
from utils import init_model
|
||||||
from custom_models.gpt2.modeling_gpt2 import GPT2LMHeadModel
|
from custom_models.gpt2.modeling_gpt2 import GPT2LMHeadModel
|
||||||
|
|
||||||
from transformers import (
|
from transformers import AutoConfig
|
||||||
AutoConfig
|
|
||||||
)
|
|
||||||
|
|
||||||
class LitModule(pl.LightningModule):
|
class LitModule(pl.LightningModule):
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -22,7 +21,7 @@ class LitModule(pl.LightningModule):
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.save_hyperparameters()
|
self.save_hyperparameters()
|
||||||
if path != "" :
|
if path != "":
|
||||||
config = AutoConfig.for_model(model_type=model_name)
|
config = AutoConfig.for_model(model_type=model_name)
|
||||||
model = GPT2LMHeadModel(config)
|
model = GPT2LMHeadModel(config)
|
||||||
model = model.from_pretrained(path)
|
model = model.from_pretrained(path)
|
||||||
|
@ -33,7 +32,7 @@ class LitModule(pl.LightningModule):
|
||||||
self.use_tril_attention_mask = use_tril_attention_mask
|
self.use_tril_attention_mask = use_tril_attention_mask
|
||||||
self.metric_loss = torchmetrics.MeanMetric()
|
self.metric_loss = torchmetrics.MeanMetric()
|
||||||
self.metric_accuracy = torchmetrics.Accuracy(
|
self.metric_accuracy = torchmetrics.Accuracy(
|
||||||
task='multiclass',
|
task="multiclass",
|
||||||
num_classes=self.llm.config.vocab_size,
|
num_classes=self.llm.config.vocab_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -45,17 +44,17 @@ class LitModule(pl.LightningModule):
|
||||||
return matrix
|
return matrix
|
||||||
|
|
||||||
def register_core_module(self, module: torch.nn.Module) -> torch.nn.Module:
|
def register_core_module(self, module: torch.nn.Module) -> torch.nn.Module:
|
||||||
object.__setattr__(self, '__core_module__', module)
|
object.__setattr__(self, "__core_module__", module)
|
||||||
return module
|
return module
|
||||||
|
|
||||||
def training_step(self, batch: Dict[str, torch.Tensor], batch_idx):
|
def training_step(self, batch: Dict[str, torch.Tensor], batch_idx):
|
||||||
batch_size, block_size = batch['input_ids'].shape
|
batch_size, block_size = batch["input_ids"].shape
|
||||||
if self.use_tril_attention_mask:
|
if self.use_tril_attention_mask:
|
||||||
batch['attention_mask'] = self.get_batch_tril_matrix(block_size, batch_size=batch_size).to(self.device)
|
batch["attention_mask"] = self.get_batch_tril_matrix(block_size, batch_size=batch_size).to(self.device)
|
||||||
outputs = self.llm(**batch, return_dict=True)
|
outputs = self.llm(**batch, return_dict=True)
|
||||||
loss = outputs.loss
|
loss = outputs.loss
|
||||||
|
|
||||||
self.log('train_loss', loss, rank_zero_only=True)
|
self.log("train_loss", loss, rank_zero_only=True)
|
||||||
|
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
@ -63,7 +62,7 @@ class LitModule(pl.LightningModule):
|
||||||
outputs = self.llm(**batch, return_dict=True)
|
outputs = self.llm(**batch, return_dict=True)
|
||||||
loss = outputs.loss
|
loss = outputs.loss
|
||||||
logits = outputs.logits[..., :-1, :]
|
logits = outputs.logits[..., :-1, :]
|
||||||
labels = batch['labels'][..., 1:]
|
labels = batch["labels"][..., 1:]
|
||||||
|
|
||||||
self.metric_loss.update(loss)
|
self.metric_loss.update(loss)
|
||||||
|
|
||||||
|
@ -71,8 +70,8 @@ class LitModule(pl.LightningModule):
|
||||||
self.metric_accuracy.update(logits[label_mask], labels[label_mask])
|
self.metric_accuracy.update(logits[label_mask], labels[label_mask])
|
||||||
|
|
||||||
def on_validation_epoch_end(self) -> None:
|
def on_validation_epoch_end(self) -> None:
|
||||||
self.log('val_loss', self.metric_loss, rank_zero_only=True)
|
self.log("val_loss", self.metric_loss, rank_zero_only=True)
|
||||||
self.log('accuracy', self.metric_accuracy, rank_zero_only=True)
|
self.log("accuracy", self.metric_accuracy, rank_zero_only=True)
|
||||||
|
|
||||||
def configure_optimizers(self):
|
def configure_optimizers(self):
|
||||||
strategy = self.trainer.strategy
|
strategy = self.trainer.strategy
|
||||||
|
@ -92,15 +91,15 @@ class LitModule(pl.LightningModule):
|
||||||
|
|
||||||
def configure_callbacks(self):
|
def configure_callbacks(self):
|
||||||
checkpoint_callback = pl.callbacks.ModelCheckpoint(
|
checkpoint_callback = pl.callbacks.ModelCheckpoint(
|
||||||
monitor='accuracy',
|
monitor="accuracy",
|
||||||
mode='max',
|
mode="max",
|
||||||
filename='{epoch:02d}-{accuracy:.4f}',
|
filename="{epoch:02d}-{accuracy:.4f}",
|
||||||
)
|
)
|
||||||
early_stop_callback = pl.callbacks.EarlyStopping(
|
early_stop_callback = pl.callbacks.EarlyStopping(
|
||||||
monitor='accuracy',
|
monitor="accuracy",
|
||||||
min_delta=0.001,
|
min_delta=0.001,
|
||||||
patience=3,
|
patience=3,
|
||||||
mode='max',
|
mode="max",
|
||||||
stopping_threshold=1,
|
stopping_threshold=1,
|
||||||
)
|
)
|
||||||
return [checkpoint_callback, early_stop_callback]
|
return [checkpoint_callback, early_stop_callback]
|
||||||
|
|
|
@ -179,7 +179,7 @@ if __name__ == "__main__":
|
||||||
set_seed(args.seed)
|
set_seed(args.seed)
|
||||||
|
|
||||||
# lightning module
|
# lightning module
|
||||||
lit_module = LitModule(args.model_name,"./custom_models/gpt2", args.learning_rate, args.use_tril_attention_mask)
|
lit_module = LitModule(args.model_name, "./custom_models/gpt2", args.learning_rate, args.use_tril_attention_mask)
|
||||||
|
|
||||||
# datasets
|
# datasets
|
||||||
tokenizer = load_tokenizer(args.tokenizer_name_or_path)
|
tokenizer = load_tokenizer(args.tokenizer_name_or_path)
|
||||||
|
@ -187,9 +187,7 @@ if __name__ == "__main__":
|
||||||
val_dataset_list = []
|
val_dataset_list = []
|
||||||
for dataset_name in args.dataset_name:
|
for dataset_name in args.dataset_name:
|
||||||
dataset_args = dataset_name.split(":")
|
dataset_args = dataset_name.split(":")
|
||||||
raw_dataset = datasets.load_dataset(
|
raw_dataset = datasets.load_dataset("json", data_files="./dataset/58.jsonl.gz")
|
||||||
"json", data_files="/home/colin/develop/dataset/liwu/MNBVC/wiki/20230197/0.jsonl.gz"
|
|
||||||
)
|
|
||||||
# raw_dataset = datasets.load_dataset(*dataset_args)
|
# raw_dataset = datasets.load_dataset(*dataset_args)
|
||||||
train_dataset, val_dataset = split_raw_dataset(raw_dataset)
|
train_dataset, val_dataset = split_raw_dataset(raw_dataset)
|
||||||
train_dataset = process_dataset(train_dataset, tokenizer)
|
train_dataset = process_dataset(train_dataset, tokenizer)
|
||||||
|
|
Loading…
Reference in New Issue