diff --git a/.gitignore b/.gitignore
index e460e1a..a8afa61 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,2 +1,5 @@
+__pycache__
+.ipynb_checkpoints
+exports
 lightning_logs
 .env
diff --git a/generate.py b/generate.py
index 3951953..378214e 100644
--- a/generate.py
+++ b/generate.py
@@ -1,34 +1,10 @@
 import argparse
-import os
-from typing import List, Union
+from typing import List
 
 import torch
-from transformers import (
-    AutoModel,
-    AutoModelForCausalLM,
-    AutoTokenizer,
-    PreTrainedModel,
-    PreTrainedTokenizer,
-)
+from transformers import PreTrainedModel, PreTrainedTokenizer
 
-
-def load_model(model_name_or_path: Union[str, os.PathLike]) -> PreTrainedModel:
-    try:
-        model = AutoModelForCausalLM.from_pretrained(
-            model_name_or_path, trust_remote_code=True
-        )
-    except ValueError:
-        model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True)
-    return model
-
-
-def load_tokenizer(model_name_or_path: Union[str, os.PathLike]) -> PreTrainedTokenizer:
-    tokenizer = AutoTokenizer.from_pretrained(
-        model_name_or_path, padding_side='left', trust_remote_code=True
-    )
-    if tokenizer.pad_token_id is None:
-        tokenizer.pad_token = tokenizer.eos_token
-    return tokenizer
+from utils import load_model, load_tokenizer
 
 
 def eval_prompts(
diff --git a/lit_module.py b/lit_module.py
new file mode 100644
index 0000000..789c58e
--- /dev/null
+++ b/lit_module.py
@@ -0,0 +1,81 @@
+from functools import cache
+from typing import Dict, Optional
+
+import pytorch_lightning as pl
+import torch
+import torchmetrics
+
+from utils import init_model
+
+
+class LitModule(pl.LightningModule):
+    def __init__(self, model_name: str, use_tril_attention_mask: str = False):
+        super().__init__()
+        self.save_hyperparameters()
+        self.llm = self.register_core_module(init_model(model_name))
+        self.use_tril_attention_mask = use_tril_attention_mask
+        self.metric_loss = torchmetrics.MeanMetric()
+        self.metric_accuracy = torchmetrics.Accuracy(
+            task='multiclass',
+            num_classes=self.llm.config.vocab_size,
+        )
+
+    @cache
+    def get_batch_tril_matrix(
+        self, block_size: int, batch_size: Optional[int] = None
+    ) -> torch.Tensor:
+        matrix = torch.ones(block_size, block_size).tril()
+        if batch_size is not None:
+            matrix = matrix.repeat(batch_size, 1, 1)
+        return matrix
+
+    def register_core_module(self, module: torch.nn.Module) -> torch.nn.Module:
+        object.__setattr__(self, '__core_module__', module)
+        return module
+
+    def training_step(self, batch: Dict[str, torch.Tensor], batch_idx):
+        batch_size, block_size = batch['input_ids'].shape
+        if self.use_tril_attention_mask:
+            batch['attention_mask'] = self.get_batch_tril_matrix(
+                block_size, batch_size=batch_size
+            ).to(self.device)
+        outputs = self.llm(**batch, return_dict=True)
+        loss = outputs.loss
+
+        self.log('train_loss', loss, rank_zero_only=True)
+
+        return loss
+
+    def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx):
+        outputs = self.llm(**batch, return_dict=True)
+        loss = outputs.loss
+        logits = outputs.logits[..., :-1, :]
+        labels = batch['labels'][..., 1:]
+
+        self.metric_loss.update(loss)
+
+        label_mask = labels != -100
+        self.metric_accuracy.update(logits[label_mask], labels[label_mask])
+
+    def on_validation_epoch_end(self) -> None:
+        self.log('val_loss', self.metric_loss, rank_zero_only=True)
+        self.log('accuracy', self.metric_accuracy, rank_zero_only=True)
+
+    def configure_optimizers(self):
+        optimizer = torch.optim.AdamW(self.trainer.model.parameters(), lr=0.0001)
+        return optimizer
+
+    def configure_callbacks(self):
+        checkpoint_callback = pl.callbacks.ModelCheckpoint(
+            monitor='accuracy',
+            mode='max',
+            filename='{epoch:02d}-{accuracy:.4f}',
+        )
+        early_stop_callback = pl.callbacks.EarlyStopping(
+            monitor='accuracy',
+            min_delta=0.001,
+            patience=3,
+            mode='max',
+            stopping_threshold=1,
+        )
+        return [checkpoint_callback, early_stop_callback]
diff --git a/train.py b/lit_train.py
similarity index 63%
rename from train.py
rename to lit_train.py
index b6982dc..6c65561 100644
--- a/train.py
+++ b/lit_train.py
@@ -1,45 +1,21 @@
 import argparse
-import os
-from functools import cache, partial
+from functools import partial
 from itertools import chain
-from typing import Dict, Optional, Tuple, Union
+from typing import Dict, Tuple
 
 import datasets
 import pytorch_lightning as pl
 import torch
-import torchmetrics
 from torch.utils.data import ConcatDataset, DataLoader
 from transformers import (
-    AutoConfig,
-    AutoModel,
-    AutoModelForCausalLM,
-    AutoTokenizer,
     BatchEncoding,
     DefaultDataCollator,
-    PreTrainedModel,
     PreTrainedTokenizer,
     set_seed,
 )
 
-
-def init_model(model_name: Union[str, os.PathLike]) -> PreTrainedModel:
-    config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
-    try:
-        model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
-    except ValueError:
-        model = AutoModel.from_config(config, trust_remote_code=True)
-    return model
-
-
-def load_tokenizer(
-    tokenizer_name_or_path: Union[str, os.PathLike]
-) -> PreTrainedTokenizer:
-    tokenizer = AutoTokenizer.from_pretrained(
-        tokenizer_name_or_path, padding_side='left', trust_remote_code=True
-    )
-    if tokenizer.pad_token_id is None:
-        tokenizer.pad_token = tokenizer.eos_token
-    return tokenizer
+from lit_module import LitModule
+from utils import load_tokenizer
 
 
 def split_raw_dataset(
@@ -169,79 +145,6 @@ def parse_args():
     return args
 
 
-class LitModule(pl.LightningModule):
-    def __init__(self, model_name: str, use_tril_attention_mask: str = False):
-        super().__init__()
-        self.save_hyperparameters()
-        self.llm = self.register_core_module(init_model(model_name))
-        self.use_tril_attention_mask = use_tril_attention_mask
-        self.metric_loss = torchmetrics.MeanMetric()
-        self.metric_accuracy = torchmetrics.Accuracy(
-            task='multiclass',
-            num_classes=self.llm.config.vocab_size,
-        )
-
-    @cache
-    def get_batch_tril_matrix(
-        self, block_size: int, batch_size: Optional[int] = None
-    ) -> torch.Tensor:
-        matrix = torch.ones(block_size, block_size).tril()
-        if batch_size is not None:
-            matrix = matrix.repeat(batch_size, 1, 1)
-        return matrix
-
-    def register_core_module(self, module: torch.nn.Module) -> torch.nn.Module:
-        object.__setattr__(self, '__core_module__', module)
-        return module
-
-    def training_step(self, batch: Dict[str, torch.Tensor], batch_idx):
-        batch_size, block_size = batch['input_ids'].shape
-        if self.use_tril_attention_mask:
-            batch['attention_mask'] = self.get_batch_tril_matrix(
-                block_size, batch_size=batch_size
-            ).to(self.device)
-        outputs = self.llm(**batch, return_dict=True)
-        loss = outputs.loss
-
-        self.log('train_loss', loss, rank_zero_only=True)
-
-        return loss
-
-    def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx):
-        outputs = self.llm(**batch, return_dict=True)
-        loss = outputs.loss
-        logits = outputs.logits[..., :-1, :]
-        labels = batch['labels'][..., 1:]
-
-        self.metric_loss.update(loss)
-
-        label_mask = labels != -100
-        self.metric_accuracy.update(logits[label_mask], labels[label_mask])
-
-    def on_validation_epoch_end(self) -> None:
-        self.log('val_loss', self.metric_loss, rank_zero_only=True)
-        self.log('accuracy', self.metric_accuracy, rank_zero_only=True)
-
-    def configure_optimizers(self):
-        optimizer = torch.optim.AdamW(self.trainer.model.parameters(), lr=0.0001)
-        return optimizer
-
-    def configure_callbacks(self):
-        checkpoint_callback = pl.callbacks.ModelCheckpoint(
-            monitor='accuracy',
-            mode='max',
-            filename='{epoch:02d}-{accuracy:.4f}',
-        )
-        early_stop_callback = pl.callbacks.EarlyStopping(
-            monitor='accuracy',
-            min_delta=0.001,
-            patience=3,
-            mode='max',
-            stopping_threshold=1,
-        )
-        return [checkpoint_callback, early_stop_callback]
-
-
 if __name__ == '__main__':
     args = parse_args()
 
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000..5675bd4
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,41 @@
+import os
+from typing import Union
+
+from transformers import (
+    AutoConfig,
+    AutoModel,
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    PreTrainedModel,
+    PreTrainedTokenizer,
+)
+
+
+def init_model(model_name: Union[str, os.PathLike]) -> PreTrainedModel:
+    config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
+    try:
+        model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
+    except ValueError:
+        model = AutoModel.from_config(config, trust_remote_code=True)
+    return model
+
+
+def load_model(model_name_or_path: Union[str, os.PathLike]) -> PreTrainedModel:
+    try:
+        model = AutoModelForCausalLM.from_pretrained(
+            model_name_or_path, trust_remote_code=True
+        )
+    except ValueError:
+        model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True)
+    return model
+
+
+def load_tokenizer(
+    tokenizer_name_or_path: Union[str, os.PathLike]
+) -> PreTrainedTokenizer:
+    tokenizer = AutoTokenizer.from_pretrained(
+        tokenizer_name_or_path, padding_side='left', trust_remote_code=True
+    )
+    if tokenizer.pad_token_id is None:
+        tokenizer.pad_token = tokenizer.eos_token
+    return tokenizer
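
Note (illustrative only, not part of the patch): a minimal sketch of how the refactored pieces fit together -- load_tokenizer from utils.py prepares a toy batch, and LitModule from lit_module.py is driven by a pytorch_lightning Trainer. It assumes 'gpt2' as a stand-in model name, a small in-memory dataset instead of the datasets pipeline in lit_train.py, and reasonably recent pytorch_lightning/transformers versions; the real entry point remains lit_train.py.

import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader

from lit_module import LitModule
from utils import load_tokenizer

# Tokenize a couple of toy sentences; load_tokenizer pads on the left and
# falls back to the EOS token as the pad token (as defined in utils.py).
tokenizer = load_tokenizer('gpt2')
encodings = tokenizer(
    ['hello world', 'refactor train.py into lightning modules'],
    padding='max_length',
    max_length=16,
    return_tensors='pt',
)

# For causal-LM training the labels are the input ids, with padding positions
# set to -100 so they are ignored by the loss and by the accuracy mask used
# in LitModule.validation_step.
labels = encodings['input_ids'].clone()
labels[encodings['attention_mask'] == 0] = -100

dataset = [
    {'input_ids': ids, 'attention_mask': mask, 'labels': lab}
    for ids, mask, lab in zip(
        encodings['input_ids'], encodings['attention_mask'], labels
    )
]
loader = DataLoader(dataset, batch_size=2)

# init_model() inside LitModule builds the architecture from config only
# (random weights), so this trains from scratch rather than fine-tuning.
lit = LitModule('gpt2', use_tril_attention_mask=False)
trainer = pl.Trainer(max_epochs=1, accelerator='cpu', log_every_n_steps=1)
trainer.fit(lit, train_dataloaders=loader, val_dataloaders=loader)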