[code] refactor
This commit is contained in:
parent
939be31c10
commit
09507449f7
|
@ -1,2 +1,5 @@
|
|||
__pycache__
|
||||
.ipynb_checkpoints
|
||||
exports
|
||||
lightning_logs
|
||||
.env
|
||||
|
|
30
generate.py
30
generate.py
|
@ -1,34 +1,10 @@
|
|||
import argparse
|
||||
import os
|
||||
from typing import List, Union
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
from transformers import (
|
||||
AutoModel,
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
PreTrainedModel,
|
||||
PreTrainedTokenizer,
|
||||
)
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizer
|
||||
|
||||
|
||||
def load_model(model_name_or_path: Union[str, os.PathLike]) -> PreTrainedModel:
|
||||
try:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name_or_path, trust_remote_code=True
|
||||
)
|
||||
except ValueError:
|
||||
model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True)
|
||||
return model
|
||||
|
||||
|
||||
def load_tokenizer(model_name_or_path: Union[str, os.PathLike]) -> PreTrainedTokenizer:
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_name_or_path, padding_side='left', trust_remote_code=True
|
||||
)
|
||||
if tokenizer.pad_token_id is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
return tokenizer
|
||||
from utils import load_model, load_tokenizer
|
||||
|
||||
|
||||
def eval_prompts(
|
||||
|
|
|
@ -0,0 +1,81 @@
|
|||
from functools import cache
|
||||
from typing import Dict, Optional
|
||||
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
import torchmetrics
|
||||
|
||||
from utils import init_model
|
||||
|
||||
|
||||
class LitModule(pl.LightningModule):
|
||||
def __init__(self, model_name: str, use_tril_attention_mask: str = False):
|
||||
super().__init__()
|
||||
self.save_hyperparameters()
|
||||
self.llm = self.register_core_module(init_model(model_name))
|
||||
self.use_tril_attention_mask = use_tril_attention_mask
|
||||
self.metric_loss = torchmetrics.MeanMetric()
|
||||
self.metric_accuracy = torchmetrics.Accuracy(
|
||||
task='multiclass',
|
||||
num_classes=self.llm.config.vocab_size,
|
||||
)
|
||||
|
||||
@cache
|
||||
def get_batch_tril_matrix(
|
||||
self, block_size: int, batch_size: Optional[int] = None
|
||||
) -> torch.Tensor:
|
||||
matrix = torch.ones(block_size, block_size).tril()
|
||||
if batch_size is not None:
|
||||
matrix = matrix.repeat(batch_size, 1, 1)
|
||||
return matrix
|
||||
|
||||
def register_core_module(self, module: torch.nn.Module) -> torch.nn.Module:
|
||||
object.__setattr__(self, '__core_module__', module)
|
||||
return module
|
||||
|
||||
def training_step(self, batch: Dict[str, torch.Tensor], batch_idx):
|
||||
batch_size, block_size = batch['input_ids'].shape
|
||||
if self.use_tril_attention_mask:
|
||||
batch['attention_mask'] = self.get_batch_tril_matrix(
|
||||
block_size, batch_size=batch_size
|
||||
).to(self.device)
|
||||
outputs = self.llm(**batch, return_dict=True)
|
||||
loss = outputs.loss
|
||||
|
||||
self.log('train_loss', loss, rank_zero_only=True)
|
||||
|
||||
return loss
|
||||
|
||||
def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx):
|
||||
outputs = self.llm(**batch, return_dict=True)
|
||||
loss = outputs.loss
|
||||
logits = outputs.logits[..., :-1, :]
|
||||
labels = batch['labels'][..., 1:]
|
||||
|
||||
self.metric_loss.update(loss)
|
||||
|
||||
label_mask = labels != -100
|
||||
self.metric_accuracy.update(logits[label_mask], labels[label_mask])
|
||||
|
||||
def on_validation_epoch_end(self) -> None:
|
||||
self.log('val_loss', self.metric_loss, rank_zero_only=True)
|
||||
self.log('accuracy', self.metric_accuracy, rank_zero_only=True)
|
||||
|
||||
def configure_optimizers(self):
|
||||
optimizer = torch.optim.AdamW(self.trainer.model.parameters(), lr=0.0001)
|
||||
return optimizer
|
||||
|
||||
def configure_callbacks(self):
|
||||
checkpoint_callback = pl.callbacks.ModelCheckpoint(
|
||||
monitor='accuracy',
|
||||
mode='max',
|
||||
filename='{epoch:02d}-{accuracy:.4f}',
|
||||
)
|
||||
early_stop_callback = pl.callbacks.EarlyStopping(
|
||||
monitor='accuracy',
|
||||
min_delta=0.001,
|
||||
patience=3,
|
||||
mode='max',
|
||||
stopping_threshold=1,
|
||||
)
|
||||
return [checkpoint_callback, early_stop_callback]
|
|
@ -1,45 +1,21 @@
|
|||
import argparse
|
||||
import os
|
||||
from functools import cache, partial
|
||||
from functools import partial
|
||||
from itertools import chain
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import datasets
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
import torchmetrics
|
||||
from torch.utils.data import ConcatDataset, DataLoader
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModel,
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
BatchEncoding,
|
||||
DefaultDataCollator,
|
||||
PreTrainedModel,
|
||||
PreTrainedTokenizer,
|
||||
set_seed,
|
||||
)
|
||||
|
||||
|
||||
def init_model(model_name: Union[str, os.PathLike]) -> PreTrainedModel:
|
||||
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
|
||||
try:
|
||||
model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
|
||||
except ValueError:
|
||||
model = AutoModel.from_config(config, trust_remote_code=True)
|
||||
return model
|
||||
|
||||
|
||||
def load_tokenizer(
|
||||
tokenizer_name_or_path: Union[str, os.PathLike]
|
||||
) -> PreTrainedTokenizer:
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
tokenizer_name_or_path, padding_side='left', trust_remote_code=True
|
||||
)
|
||||
if tokenizer.pad_token_id is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
return tokenizer
|
||||
from lit_module import LitModule
|
||||
from utils import load_tokenizer
|
||||
|
||||
|
||||
def split_raw_dataset(
|
||||
|
@ -169,79 +145,6 @@ def parse_args():
|
|||
return args
|
||||
|
||||
|
||||
class LitModule(pl.LightningModule):
|
||||
def __init__(self, model_name: str, use_tril_attention_mask: str = False):
|
||||
super().__init__()
|
||||
self.save_hyperparameters()
|
||||
self.llm = self.register_core_module(init_model(model_name))
|
||||
self.use_tril_attention_mask = use_tril_attention_mask
|
||||
self.metric_loss = torchmetrics.MeanMetric()
|
||||
self.metric_accuracy = torchmetrics.Accuracy(
|
||||
task='multiclass',
|
||||
num_classes=self.llm.config.vocab_size,
|
||||
)
|
||||
|
||||
@cache
|
||||
def get_batch_tril_matrix(
|
||||
self, block_size: int, batch_size: Optional[int] = None
|
||||
) -> torch.Tensor:
|
||||
matrix = torch.ones(block_size, block_size).tril()
|
||||
if batch_size is not None:
|
||||
matrix = matrix.repeat(batch_size, 1, 1)
|
||||
return matrix
|
||||
|
||||
def register_core_module(self, module: torch.nn.Module) -> torch.nn.Module:
|
||||
object.__setattr__(self, '__core_module__', module)
|
||||
return module
|
||||
|
||||
def training_step(self, batch: Dict[str, torch.Tensor], batch_idx):
|
||||
batch_size, block_size = batch['input_ids'].shape
|
||||
if self.use_tril_attention_mask:
|
||||
batch['attention_mask'] = self.get_batch_tril_matrix(
|
||||
block_size, batch_size=batch_size
|
||||
).to(self.device)
|
||||
outputs = self.llm(**batch, return_dict=True)
|
||||
loss = outputs.loss
|
||||
|
||||
self.log('train_loss', loss, rank_zero_only=True)
|
||||
|
||||
return loss
|
||||
|
||||
def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx):
|
||||
outputs = self.llm(**batch, return_dict=True)
|
||||
loss = outputs.loss
|
||||
logits = outputs.logits[..., :-1, :]
|
||||
labels = batch['labels'][..., 1:]
|
||||
|
||||
self.metric_loss.update(loss)
|
||||
|
||||
label_mask = labels != -100
|
||||
self.metric_accuracy.update(logits[label_mask], labels[label_mask])
|
||||
|
||||
def on_validation_epoch_end(self) -> None:
|
||||
self.log('val_loss', self.metric_loss, rank_zero_only=True)
|
||||
self.log('accuracy', self.metric_accuracy, rank_zero_only=True)
|
||||
|
||||
def configure_optimizers(self):
|
||||
optimizer = torch.optim.AdamW(self.trainer.model.parameters(), lr=0.0001)
|
||||
return optimizer
|
||||
|
||||
def configure_callbacks(self):
|
||||
checkpoint_callback = pl.callbacks.ModelCheckpoint(
|
||||
monitor='accuracy',
|
||||
mode='max',
|
||||
filename='{epoch:02d}-{accuracy:.4f}',
|
||||
)
|
||||
early_stop_callback = pl.callbacks.EarlyStopping(
|
||||
monitor='accuracy',
|
||||
min_delta=0.001,
|
||||
patience=3,
|
||||
mode='max',
|
||||
stopping_threshold=1,
|
||||
)
|
||||
return [checkpoint_callback, early_stop_callback]
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = parse_args()
|
||||
|
|
@ -0,0 +1,41 @@
|
|||
import os
|
||||
from typing import Union
|
||||
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModel,
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
PreTrainedModel,
|
||||
PreTrainedTokenizer,
|
||||
)
|
||||
|
||||
|
||||
def init_model(model_name: Union[str, os.PathLike]) -> PreTrainedModel:
|
||||
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
|
||||
try:
|
||||
model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
|
||||
except ValueError:
|
||||
model = AutoModel.from_config(config, trust_remote_code=True)
|
||||
return model
|
||||
|
||||
|
||||
def load_model(model_name_or_path: Union[str, os.PathLike]) -> PreTrainedModel:
|
||||
try:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name_or_path, trust_remote_code=True
|
||||
)
|
||||
except ValueError:
|
||||
model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True)
|
||||
return model
|
||||
|
||||
|
||||
def load_tokenizer(
|
||||
tokenizer_name_or_path: Union[str, os.PathLike]
|
||||
) -> PreTrainedTokenizer:
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
tokenizer_name_or_path, padding_side='left', trust_remote_code=True
|
||||
)
|
||||
if tokenizer.pad_token_id is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
return tokenizer
|
Loading…
Reference in New Issue