[code] refactor

Yiqing-Zhou 2023-05-07 13:01:02 +08:00
parent 939be31c10
commit 09507449f7
5 changed files with 132 additions and 128 deletions

.gitignore (+3)

@@ -1,2 +1,5 @@
+__pycache__
+.ipynb_checkpoints
+exports
 lightning_logs
 .env


@@ -1,34 +1,10 @@
 import argparse
-import os
-from typing import List, Union
+from typing import List
 
 import torch
-from transformers import (
-    AutoModel,
-    AutoModelForCausalLM,
-    AutoTokenizer,
-    PreTrainedModel,
-    PreTrainedTokenizer,
-)
-
-
-def load_model(model_name_or_path: Union[str, os.PathLike]) -> PreTrainedModel:
-    try:
-        model = AutoModelForCausalLM.from_pretrained(
-            model_name_or_path, trust_remote_code=True
-        )
-    except ValueError:
-        model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True)
-    return model
-
-
-def load_tokenizer(model_name_or_path: Union[str, os.PathLike]) -> PreTrainedTokenizer:
-    tokenizer = AutoTokenizer.from_pretrained(
-        model_name_or_path, padding_side='left', trust_remote_code=True
-    )
-    if tokenizer.pad_token_id is None:
-        tokenizer.pad_token = tokenizer.eos_token
-    return tokenizer
+from transformers import PreTrainedModel, PreTrainedTokenizer
+
+from utils import load_model, load_tokenizer
 
 
 def eval_prompts(
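The body of eval_prompts is untouched by this commit and not shown. For orientation, here is a minimal sketch of how the relocated helpers are meant to be combined for batched generation; the model name 'gpt2', the prompt strings, and max_new_tokens are illustrative assumptions, not taken from this commit.

# Hypothetical usage sketch; none of this appears in the diff itself.
from utils import load_model, load_tokenizer

model = load_model('gpt2')          # any causal-LM name or local path
tokenizer = load_tokenizer('gpt2')  # left padding, pad_token falls back to eos

prompts = ['Hello, my name is', 'The capital of France is']
inputs = tokenizer(prompts, return_tensors='pt', padding=True)
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))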

lit_module.py (new file, +81)

@@ -0,0 +1,81 @@
from functools import cache
from typing import Dict, Optional

import pytorch_lightning as pl
import torch
import torchmetrics

from utils import init_model


class LitModule(pl.LightningModule):
    def __init__(self, model_name: str, use_tril_attention_mask: bool = False):
        super().__init__()
        self.save_hyperparameters()
        self.llm = self.register_core_module(init_model(model_name))
        self.use_tril_attention_mask = use_tril_attention_mask
        self.metric_loss = torchmetrics.MeanMetric()
        self.metric_accuracy = torchmetrics.Accuracy(
            task='multiclass',
            num_classes=self.llm.config.vocab_size,
        )

    @cache
    def get_batch_tril_matrix(
        self, block_size: int, batch_size: Optional[int] = None
    ) -> torch.Tensor:
        # Lower-triangular (causal) mask: position i may only attend to positions <= i.
        matrix = torch.ones(block_size, block_size).tril()
        if batch_size is not None:
            matrix = matrix.repeat(batch_size, 1, 1)
        return matrix

    def register_core_module(self, module: torch.nn.Module) -> torch.nn.Module:
        # Store an extra reference via object.__setattr__, bypassing nn.Module.__setattr__
        # so it is not registered as a submodule under this name.
        object.__setattr__(self, '__core_module__', module)
        return module

    def training_step(self, batch: Dict[str, torch.Tensor], batch_idx):
        batch_size, block_size = batch['input_ids'].shape
        if self.use_tril_attention_mask:
            batch['attention_mask'] = self.get_batch_tril_matrix(
                block_size, batch_size=batch_size
            ).to(self.device)
        outputs = self.llm(**batch, return_dict=True)
        loss = outputs.loss
        self.log('train_loss', loss, rank_zero_only=True)
        return loss

    def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx):
        outputs = self.llm(**batch, return_dict=True)
        loss = outputs.loss
        # Shift so that tokens < n predict token n, then drop positions labelled -100.
        logits = outputs.logits[..., :-1, :]
        labels = batch['labels'][..., 1:]
        self.metric_loss.update(loss)
        label_mask = labels != -100
        self.metric_accuracy.update(logits[label_mask], labels[label_mask])

    def on_validation_epoch_end(self) -> None:
        self.log('val_loss', self.metric_loss, rank_zero_only=True)
        self.log('accuracy', self.metric_accuracy, rank_zero_only=True)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.trainer.model.parameters(), lr=0.0001)
        return optimizer

    def configure_callbacks(self):
        checkpoint_callback = pl.callbacks.ModelCheckpoint(
            monitor='accuracy',
            mode='max',
            filename='{epoch:02d}-{accuracy:.4f}',
        )
        early_stop_callback = pl.callbacks.EarlyStopping(
            monitor='accuracy',
            min_delta=0.001,
            patience=3,
            mode='max',
            stopping_threshold=1,
        )
        return [checkpoint_callback, early_stop_callback]
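As a side note on get_batch_tril_matrix: the mask is just a lower-triangular matrix of ones, and the batch dimension is added by repeat. A quick illustration, with a block size of 4 chosen arbitrarily:

import torch

mask = torch.ones(4, 4).tril()     # same construction as get_batch_tril_matrix
print(mask)
# tensor([[1., 0., 0., 0.],
#         [1., 1., 0., 0.],
#         [1., 1., 1., 0.],
#         [1., 1., 1., 1.]])
print(mask.repeat(2, 1, 1).shape)  # torch.Size([2, 4, 4]) -> (batch_size, block_size, block_size)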


@@ -1,45 +1,21 @@
 import argparse
-import os
-from functools import cache, partial
+from functools import partial
 from itertools import chain
-from typing import Dict, Optional, Tuple, Union
+from typing import Dict, Tuple
 
 import datasets
 import pytorch_lightning as pl
 import torch
-import torchmetrics
 from torch.utils.data import ConcatDataset, DataLoader
 from transformers import (
-    AutoConfig,
-    AutoModel,
-    AutoModelForCausalLM,
-    AutoTokenizer,
     BatchEncoding,
     DefaultDataCollator,
-    PreTrainedModel,
     PreTrainedTokenizer,
     set_seed,
 )
 
-
-def init_model(model_name: Union[str, os.PathLike]) -> PreTrainedModel:
-    config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
-    try:
-        model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
-    except ValueError:
-        model = AutoModel.from_config(config, trust_remote_code=True)
-    return model
-
-
-def load_tokenizer(
-    tokenizer_name_or_path: Union[str, os.PathLike]
-) -> PreTrainedTokenizer:
-    tokenizer = AutoTokenizer.from_pretrained(
-        tokenizer_name_or_path, padding_side='left', trust_remote_code=True
-    )
-    if tokenizer.pad_token_id is None:
-        tokenizer.pad_token = tokenizer.eos_token
-    return tokenizer
+from lit_module import LitModule
+from utils import load_tokenizer
 
 
 def split_raw_dataset(
@@ -169,79 +145,6 @@ def parse_args():
     return args
 
 
-class LitModule(pl.LightningModule):
-    def __init__(self, model_name: str, use_tril_attention_mask: str = False):
-        super().__init__()
-        self.save_hyperparameters()
-        self.llm = self.register_core_module(init_model(model_name))
-        self.use_tril_attention_mask = use_tril_attention_mask
-        self.metric_loss = torchmetrics.MeanMetric()
-        self.metric_accuracy = torchmetrics.Accuracy(
-            task='multiclass',
-            num_classes=self.llm.config.vocab_size,
-        )
-
-    @cache
-    def get_batch_tril_matrix(
-        self, block_size: int, batch_size: Optional[int] = None
-    ) -> torch.Tensor:
-        matrix = torch.ones(block_size, block_size).tril()
-        if batch_size is not None:
-            matrix = matrix.repeat(batch_size, 1, 1)
-        return matrix
-
-    def register_core_module(self, module: torch.nn.Module) -> torch.nn.Module:
-        object.__setattr__(self, '__core_module__', module)
-        return module
-
-    def training_step(self, batch: Dict[str, torch.Tensor], batch_idx):
-        batch_size, block_size = batch['input_ids'].shape
-        if self.use_tril_attention_mask:
-            batch['attention_mask'] = self.get_batch_tril_matrix(
-                block_size, batch_size=batch_size
-            ).to(self.device)
-        outputs = self.llm(**batch, return_dict=True)
-        loss = outputs.loss
-        self.log('train_loss', loss, rank_zero_only=True)
-        return loss
-
-    def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx):
-        outputs = self.llm(**batch, return_dict=True)
-        loss = outputs.loss
-        logits = outputs.logits[..., :-1, :]
-        labels = batch['labels'][..., 1:]
-        self.metric_loss.update(loss)
-        label_mask = labels != -100
-        self.metric_accuracy.update(logits[label_mask], labels[label_mask])
-
-    def on_validation_epoch_end(self) -> None:
-        self.log('val_loss', self.metric_loss, rank_zero_only=True)
-        self.log('accuracy', self.metric_accuracy, rank_zero_only=True)
-
-    def configure_optimizers(self):
-        optimizer = torch.optim.AdamW(self.trainer.model.parameters(), lr=0.0001)
-        return optimizer
-
-    def configure_callbacks(self):
-        checkpoint_callback = pl.callbacks.ModelCheckpoint(
-            monitor='accuracy',
-            mode='max',
-            filename='{epoch:02d}-{accuracy:.4f}',
-        )
-        early_stop_callback = pl.callbacks.EarlyStopping(
-            monitor='accuracy',
-            min_delta=0.001,
-            patience=3,
-            mode='max',
-            stopping_threshold=1,
-        )
-        return [checkpoint_callback, early_stop_callback]
-
-
 if __name__ == '__main__':
     args = parse_args()
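The dataset pipeline and Trainer setup of this script lie outside the changed hunks, so the following is only an assumed sketch of how the extracted LitModule is typically driven; the model name 'gpt2', the toy token ids, and the Trainer arguments are illustrative and not taken from this commit.

# Assumed wiring, not part of this diff.
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from transformers import DefaultDataCollator

from lit_module import LitModule

lit_module = LitModule('gpt2')  # init_model() yields a freshly initialized gpt2

# Toy data: each example is already tokenized to a fixed block size.
block = [464, 3290, 318, 257, 922, 3290, 13, 50256]
examples = [{'input_ids': block, 'labels': block} for _ in range(8)]
loader = DataLoader(examples, batch_size=4, collate_fn=DefaultDataCollator())

trainer = pl.Trainer(max_epochs=1, accelerator='cpu', logger=False)
trainer.fit(lit_module, train_dataloaders=loader, val_dataloaders=loader)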

utils.py (new file, +41)

@@ -0,0 +1,41 @@
import os
from typing import Union

from transformers import (
    AutoConfig,
    AutoModel,
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizer,
)


def init_model(model_name: Union[str, os.PathLike]) -> PreTrainedModel:
    # Build a freshly initialized (untrained) model from its config only.
    config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    try:
        model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
    except ValueError:
        # Fall back to the base architecture if no causal-LM head is registered.
        model = AutoModel.from_config(config, trust_remote_code=True)
    return model


def load_model(model_name_or_path: Union[str, os.PathLike]) -> PreTrainedModel:
    # Load pretrained weights, preferring a causal-LM head when available.
    try:
        model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path, trust_remote_code=True
        )
    except ValueError:
        model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True)
    return model


def load_tokenizer(
    tokenizer_name_or_path: Union[str, os.PathLike]
) -> PreTrainedTokenizer:
    # Left padding keeps prompts right-aligned with generated tokens in batched generation.
    tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_name_or_path, padding_side='left', trust_remote_code=True
    )
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer
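The only difference between the two model helpers is where the weights come from: init_model builds a randomly initialized network from the config alone (what the training path wants), while load_model fetches pretrained weights (what the evaluation path wants). A small check, with 'gpt2' used purely as an example name:

from utils import init_model, load_model

fresh = init_model('gpt2')       # architecture from config, randomly initialized weights
pretrained = load_model('gpt2')  # same architecture, pretrained weights
print(type(fresh).__name__, type(pretrained).__name__)  # GPT2LMHeadModel GPT2LMHeadModel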