[code] refactor

Yiqing-Zhou 2023-05-07 13:01:02 +08:00
parent 939be31c10
commit 09507449f7
5 changed files with 132 additions and 128 deletions

.gitignore (vendored, 3 additions)

@@ -1,2 +1,5 @@
__pycache__
.ipynb_checkpoints
exports
lightning_logs
.env

(modified file)

@@ -1,34 +1,10 @@
import argparse
import os
from typing import List, Union
from typing import List

import torch
from transformers import (
    AutoModel,
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizer,
)
from transformers import PreTrainedModel, PreTrainedTokenizer


def load_model(model_name_or_path: Union[str, os.PathLike]) -> PreTrainedModel:
    try:
        model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path, trust_remote_code=True
        )
    except ValueError:
        model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True)
    return model


def load_tokenizer(model_name_or_path: Union[str, os.PathLike]) -> PreTrainedTokenizer:
    tokenizer = AutoTokenizer.from_pretrained(
        model_name_or_path, padding_side='left', trust_remote_code=True
    )
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer


from utils import load_model, load_tokenizer


def eval_prompts(

lit_module.py (new file, 81 lines)

@@ -0,0 +1,81 @@
from functools import cache
from typing import Dict, Optional

import pytorch_lightning as pl
import torch
import torchmetrics

from utils import init_model


class LitModule(pl.LightningModule):
    def __init__(self, model_name: str, use_tril_attention_mask: str = False):
        super().__init__()
        self.save_hyperparameters()
        self.llm = self.register_core_module(init_model(model_name))
        self.use_tril_attention_mask = use_tril_attention_mask
        self.metric_loss = torchmetrics.MeanMetric()
        self.metric_accuracy = torchmetrics.Accuracy(
            task='multiclass',
            num_classes=self.llm.config.vocab_size,
        )

    @cache
    def get_batch_tril_matrix(
        self, block_size: int, batch_size: Optional[int] = None
    ) -> torch.Tensor:
        matrix = torch.ones(block_size, block_size).tril()
        if batch_size is not None:
            matrix = matrix.repeat(batch_size, 1, 1)
        return matrix

    def register_core_module(self, module: torch.nn.Module) -> torch.nn.Module:
        object.__setattr__(self, '__core_module__', module)
        return module

    def training_step(self, batch: Dict[str, torch.Tensor], batch_idx):
        batch_size, block_size = batch['input_ids'].shape
        if self.use_tril_attention_mask:
            batch['attention_mask'] = self.get_batch_tril_matrix(
                block_size, batch_size=batch_size
            ).to(self.device)
        outputs = self.llm(**batch, return_dict=True)
        loss = outputs.loss
        self.log('train_loss', loss, rank_zero_only=True)
        return loss

    def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx):
        outputs = self.llm(**batch, return_dict=True)
        loss = outputs.loss
        logits = outputs.logits[..., :-1, :]
        labels = batch['labels'][..., 1:]
        self.metric_loss.update(loss)
        label_mask = labels != -100
        self.metric_accuracy.update(logits[label_mask], labels[label_mask])

    def on_validation_epoch_end(self) -> None:
        self.log('val_loss', self.metric_loss, rank_zero_only=True)
        self.log('accuracy', self.metric_accuracy, rank_zero_only=True)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.trainer.model.parameters(), lr=0.0001)
        return optimizer

    def configure_callbacks(self):
        checkpoint_callback = pl.callbacks.ModelCheckpoint(
            monitor='accuracy',
            mode='max',
            filename='{epoch:02d}-{accuracy:.4f}',
        )
        early_stop_callback = pl.callbacks.EarlyStopping(
            monitor='accuracy',
            min_delta=0.001,
            patience=3,
            mode='max',
            stopping_threshold=1,
        )
        return [checkpoint_callback, early_stop_callback]
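
For orientation, a minimal sketch (not part of the commit) of how this LitModule might be driven by a PyTorch Lightning Trainer. The model name and the dataloaders below are assumptions; the real wiring lives in the training script modified later in this diff.

# Hypothetical usage sketch, not part of this commit.
# 'gpt2' is only an example model name; the dataloaders are placeholders for
# tokenized batches containing 'input_ids' and 'labels'.
import pytorch_lightning as pl

from lit_module import LitModule

lit_model = LitModule('gpt2', use_tril_attention_mask=True)
trainer = pl.Trainer(max_epochs=1, accelerator='auto')
# trainer.fit(lit_model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)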

(modified file)

@@ -1,45 +1,21 @@
import argparse
import os
from functools import cache, partial
from functools import partial
from itertools import chain
from typing import Dict, Optional, Tuple, Union
from typing import Dict, Tuple

import datasets
import pytorch_lightning as pl
import torch
import torchmetrics
from torch.utils.data import ConcatDataset, DataLoader
from transformers import (
    AutoConfig,
    AutoModel,
    AutoModelForCausalLM,
    AutoTokenizer,
    BatchEncoding,
    DefaultDataCollator,
    PreTrainedModel,
    PreTrainedTokenizer,
    set_seed,
)


def init_model(model_name: Union[str, os.PathLike]) -> PreTrainedModel:
    config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    try:
        model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
    except ValueError:
        model = AutoModel.from_config(config, trust_remote_code=True)
    return model


def load_tokenizer(
    tokenizer_name_or_path: Union[str, os.PathLike]
) -> PreTrainedTokenizer:
    tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_name_or_path, padding_side='left', trust_remote_code=True
    )
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer


from lit_module import LitModule
from utils import load_tokenizer


def split_raw_dataset(
@@ -169,79 +145,6 @@ def parse_args():
    return args


class LitModule(pl.LightningModule):
    def __init__(self, model_name: str, use_tril_attention_mask: str = False):
        super().__init__()
        self.save_hyperparameters()
        self.llm = self.register_core_module(init_model(model_name))
        self.use_tril_attention_mask = use_tril_attention_mask
        self.metric_loss = torchmetrics.MeanMetric()
        self.metric_accuracy = torchmetrics.Accuracy(
            task='multiclass',
            num_classes=self.llm.config.vocab_size,
        )

    @cache
    def get_batch_tril_matrix(
        self, block_size: int, batch_size: Optional[int] = None
    ) -> torch.Tensor:
        matrix = torch.ones(block_size, block_size).tril()
        if batch_size is not None:
            matrix = matrix.repeat(batch_size, 1, 1)
        return matrix

    def register_core_module(self, module: torch.nn.Module) -> torch.nn.Module:
        object.__setattr__(self, '__core_module__', module)
        return module

    def training_step(self, batch: Dict[str, torch.Tensor], batch_idx):
        batch_size, block_size = batch['input_ids'].shape
        if self.use_tril_attention_mask:
            batch['attention_mask'] = self.get_batch_tril_matrix(
                block_size, batch_size=batch_size
            ).to(self.device)
        outputs = self.llm(**batch, return_dict=True)
        loss = outputs.loss
        self.log('train_loss', loss, rank_zero_only=True)
        return loss

    def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx):
        outputs = self.llm(**batch, return_dict=True)
        loss = outputs.loss
        logits = outputs.logits[..., :-1, :]
        labels = batch['labels'][..., 1:]
        self.metric_loss.update(loss)
        label_mask = labels != -100
        self.metric_accuracy.update(logits[label_mask], labels[label_mask])

    def on_validation_epoch_end(self) -> None:
        self.log('val_loss', self.metric_loss, rank_zero_only=True)
        self.log('accuracy', self.metric_accuracy, rank_zero_only=True)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.trainer.model.parameters(), lr=0.0001)
        return optimizer

    def configure_callbacks(self):
        checkpoint_callback = pl.callbacks.ModelCheckpoint(
            monitor='accuracy',
            mode='max',
            filename='{epoch:02d}-{accuracy:.4f}',
        )
        early_stop_callback = pl.callbacks.EarlyStopping(
            monitor='accuracy',
            min_delta=0.001,
            patience=3,
            mode='max',
            stopping_threshold=1,
        )
        return [checkpoint_callback, early_stop_callback]


if __name__ == '__main__':
    args = parse_args()

utils.py (new file, 41 lines)

@@ -0,0 +1,41 @@
import os
from typing import Union

from transformers import (
    AutoConfig,
    AutoModel,
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizer,
)


def init_model(model_name: Union[str, os.PathLike]) -> PreTrainedModel:
    config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    try:
        model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
    except ValueError:
        model = AutoModel.from_config(config, trust_remote_code=True)
    return model


def load_model(model_name_or_path: Union[str, os.PathLike]) -> PreTrainedModel:
    try:
        model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path, trust_remote_code=True
        )
    except ValueError:
        model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True)
    return model


def load_tokenizer(
    tokenizer_name_or_path: Union[str, os.PathLike]
) -> PreTrainedTokenizer:
    tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_name_or_path, padding_side='left', trust_remote_code=True
    )
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer
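
As a usage note (again, not part of the commit), these helpers split between training and inference: init_model builds a freshly initialized model from a config, while load_model and load_tokenizer fetch pretrained weights and a left-padded tokenizer. A small sketch, assuming 'gpt2' as an example checkpoint name:

# Hypothetical usage sketch; 'gpt2' stands in for any Hugging Face model name.
from utils import init_model, load_model, load_tokenizer

tokenizer = load_tokenizer('gpt2')      # pad_token falls back to eos_token
fresh_model = init_model('gpt2')        # randomly initialized from config (train from scratch)
pretrained_model = load_model('gpt2')   # pretrained weights (evaluation / fine-tuning)

inputs = tokenizer(['Hello, world'], return_tensors='pt', padding=True)
outputs = pretrained_model.generate(**inputs, max_new_tokens=8)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))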