[code] refactor
This commit is contained in:
parent
939be31c10
commit
09507449f7
|
@ -1,2 +1,5 @@
|
||||||
|
__pycache__
|
||||||
|
.ipynb_checkpoints
|
||||||
|
exports
|
||||||
lightning_logs
|
lightning_logs
|
||||||
.env
|
.env
|
||||||
|
|
30
generate.py
30
generate.py
|
@ -1,34 +1,10 @@
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
from typing import List
|
||||||
from typing import List, Union
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from transformers import (
|
from transformers import PreTrainedModel, PreTrainedTokenizer
|
||||||
AutoModel,
|
|
||||||
AutoModelForCausalLM,
|
|
||||||
AutoTokenizer,
|
|
||||||
PreTrainedModel,
|
|
||||||
PreTrainedTokenizer,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
from utils import load_model, load_tokenizer
|
||||||
def load_model(model_name_or_path: Union[str, os.PathLike]) -> PreTrainedModel:
|
|
||||||
try:
|
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
|
||||||
model_name_or_path, trust_remote_code=True
|
|
||||||
)
|
|
||||||
except ValueError:
|
|
||||||
model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True)
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def load_tokenizer(model_name_or_path: Union[str, os.PathLike]) -> PreTrainedTokenizer:
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
|
||||||
model_name_or_path, padding_side='left', trust_remote_code=True
|
|
||||||
)
|
|
||||||
if tokenizer.pad_token_id is None:
|
|
||||||
tokenizer.pad_token = tokenizer.eos_token
|
|
||||||
return tokenizer
|
|
||||||
|
|
||||||
|
|
||||||
def eval_prompts(
|
def eval_prompts(
|
||||||
|
|
|
@ -0,0 +1,81 @@
|
||||||
|
from functools import cache
|
||||||
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
import pytorch_lightning as pl
|
||||||
|
import torch
|
||||||
|
import torchmetrics
|
||||||
|
|
||||||
|
from utils import init_model
|
||||||
|
|
||||||
|
|
||||||
|
class LitModule(pl.LightningModule):
|
||||||
|
def __init__(self, model_name: str, use_tril_attention_mask: str = False):
|
||||||
|
super().__init__()
|
||||||
|
self.save_hyperparameters()
|
||||||
|
self.llm = self.register_core_module(init_model(model_name))
|
||||||
|
self.use_tril_attention_mask = use_tril_attention_mask
|
||||||
|
self.metric_loss = torchmetrics.MeanMetric()
|
||||||
|
self.metric_accuracy = torchmetrics.Accuracy(
|
||||||
|
task='multiclass',
|
||||||
|
num_classes=self.llm.config.vocab_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
@cache
|
||||||
|
def get_batch_tril_matrix(
|
||||||
|
self, block_size: int, batch_size: Optional[int] = None
|
||||||
|
) -> torch.Tensor:
|
||||||
|
matrix = torch.ones(block_size, block_size).tril()
|
||||||
|
if batch_size is not None:
|
||||||
|
matrix = matrix.repeat(batch_size, 1, 1)
|
||||||
|
return matrix
|
||||||
|
|
||||||
|
def register_core_module(self, module: torch.nn.Module) -> torch.nn.Module:
|
||||||
|
object.__setattr__(self, '__core_module__', module)
|
||||||
|
return module
|
||||||
|
|
||||||
|
def training_step(self, batch: Dict[str, torch.Tensor], batch_idx):
|
||||||
|
batch_size, block_size = batch['input_ids'].shape
|
||||||
|
if self.use_tril_attention_mask:
|
||||||
|
batch['attention_mask'] = self.get_batch_tril_matrix(
|
||||||
|
block_size, batch_size=batch_size
|
||||||
|
).to(self.device)
|
||||||
|
outputs = self.llm(**batch, return_dict=True)
|
||||||
|
loss = outputs.loss
|
||||||
|
|
||||||
|
self.log('train_loss', loss, rank_zero_only=True)
|
||||||
|
|
||||||
|
return loss
|
||||||
|
|
||||||
|
def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx):
|
||||||
|
outputs = self.llm(**batch, return_dict=True)
|
||||||
|
loss = outputs.loss
|
||||||
|
logits = outputs.logits[..., :-1, :]
|
||||||
|
labels = batch['labels'][..., 1:]
|
||||||
|
|
||||||
|
self.metric_loss.update(loss)
|
||||||
|
|
||||||
|
label_mask = labels != -100
|
||||||
|
self.metric_accuracy.update(logits[label_mask], labels[label_mask])
|
||||||
|
|
||||||
|
def on_validation_epoch_end(self) -> None:
|
||||||
|
self.log('val_loss', self.metric_loss, rank_zero_only=True)
|
||||||
|
self.log('accuracy', self.metric_accuracy, rank_zero_only=True)
|
||||||
|
|
||||||
|
def configure_optimizers(self):
|
||||||
|
optimizer = torch.optim.AdamW(self.trainer.model.parameters(), lr=0.0001)
|
||||||
|
return optimizer
|
||||||
|
|
||||||
|
def configure_callbacks(self):
|
||||||
|
checkpoint_callback = pl.callbacks.ModelCheckpoint(
|
||||||
|
monitor='accuracy',
|
||||||
|
mode='max',
|
||||||
|
filename='{epoch:02d}-{accuracy:.4f}',
|
||||||
|
)
|
||||||
|
early_stop_callback = pl.callbacks.EarlyStopping(
|
||||||
|
monitor='accuracy',
|
||||||
|
min_delta=0.001,
|
||||||
|
patience=3,
|
||||||
|
mode='max',
|
||||||
|
stopping_threshold=1,
|
||||||
|
)
|
||||||
|
return [checkpoint_callback, early_stop_callback]
|
|
@ -1,45 +1,21 @@
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
from functools import partial
|
||||||
from functools import cache, partial
|
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
from typing import Dict, Optional, Tuple, Union
|
from typing import Dict, Tuple
|
||||||
|
|
||||||
import datasets
|
import datasets
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
import torchmetrics
|
|
||||||
from torch.utils.data import ConcatDataset, DataLoader
|
from torch.utils.data import ConcatDataset, DataLoader
|
||||||
from transformers import (
|
from transformers import (
|
||||||
AutoConfig,
|
|
||||||
AutoModel,
|
|
||||||
AutoModelForCausalLM,
|
|
||||||
AutoTokenizer,
|
|
||||||
BatchEncoding,
|
BatchEncoding,
|
||||||
DefaultDataCollator,
|
DefaultDataCollator,
|
||||||
PreTrainedModel,
|
|
||||||
PreTrainedTokenizer,
|
PreTrainedTokenizer,
|
||||||
set_seed,
|
set_seed,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from lit_module import LitModule
|
||||||
def init_model(model_name: Union[str, os.PathLike]) -> PreTrainedModel:
|
from utils import load_tokenizer
|
||||||
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
|
|
||||||
try:
|
|
||||||
model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
|
|
||||||
except ValueError:
|
|
||||||
model = AutoModel.from_config(config, trust_remote_code=True)
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def load_tokenizer(
|
|
||||||
tokenizer_name_or_path: Union[str, os.PathLike]
|
|
||||||
) -> PreTrainedTokenizer:
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
|
||||||
tokenizer_name_or_path, padding_side='left', trust_remote_code=True
|
|
||||||
)
|
|
||||||
if tokenizer.pad_token_id is None:
|
|
||||||
tokenizer.pad_token = tokenizer.eos_token
|
|
||||||
return tokenizer
|
|
||||||
|
|
||||||
|
|
||||||
def split_raw_dataset(
|
def split_raw_dataset(
|
||||||
|
@ -169,79 +145,6 @@ def parse_args():
|
||||||
return args
|
return args
|
||||||
|
|
||||||
|
|
||||||
class LitModule(pl.LightningModule):
|
|
||||||
def __init__(self, model_name: str, use_tril_attention_mask: str = False):
|
|
||||||
super().__init__()
|
|
||||||
self.save_hyperparameters()
|
|
||||||
self.llm = self.register_core_module(init_model(model_name))
|
|
||||||
self.use_tril_attention_mask = use_tril_attention_mask
|
|
||||||
self.metric_loss = torchmetrics.MeanMetric()
|
|
||||||
self.metric_accuracy = torchmetrics.Accuracy(
|
|
||||||
task='multiclass',
|
|
||||||
num_classes=self.llm.config.vocab_size,
|
|
||||||
)
|
|
||||||
|
|
||||||
@cache
|
|
||||||
def get_batch_tril_matrix(
|
|
||||||
self, block_size: int, batch_size: Optional[int] = None
|
|
||||||
) -> torch.Tensor:
|
|
||||||
matrix = torch.ones(block_size, block_size).tril()
|
|
||||||
if batch_size is not None:
|
|
||||||
matrix = matrix.repeat(batch_size, 1, 1)
|
|
||||||
return matrix
|
|
||||||
|
|
||||||
def register_core_module(self, module: torch.nn.Module) -> torch.nn.Module:
|
|
||||||
object.__setattr__(self, '__core_module__', module)
|
|
||||||
return module
|
|
||||||
|
|
||||||
def training_step(self, batch: Dict[str, torch.Tensor], batch_idx):
|
|
||||||
batch_size, block_size = batch['input_ids'].shape
|
|
||||||
if self.use_tril_attention_mask:
|
|
||||||
batch['attention_mask'] = self.get_batch_tril_matrix(
|
|
||||||
block_size, batch_size=batch_size
|
|
||||||
).to(self.device)
|
|
||||||
outputs = self.llm(**batch, return_dict=True)
|
|
||||||
loss = outputs.loss
|
|
||||||
|
|
||||||
self.log('train_loss', loss, rank_zero_only=True)
|
|
||||||
|
|
||||||
return loss
|
|
||||||
|
|
||||||
def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx):
|
|
||||||
outputs = self.llm(**batch, return_dict=True)
|
|
||||||
loss = outputs.loss
|
|
||||||
logits = outputs.logits[..., :-1, :]
|
|
||||||
labels = batch['labels'][..., 1:]
|
|
||||||
|
|
||||||
self.metric_loss.update(loss)
|
|
||||||
|
|
||||||
label_mask = labels != -100
|
|
||||||
self.metric_accuracy.update(logits[label_mask], labels[label_mask])
|
|
||||||
|
|
||||||
def on_validation_epoch_end(self) -> None:
|
|
||||||
self.log('val_loss', self.metric_loss, rank_zero_only=True)
|
|
||||||
self.log('accuracy', self.metric_accuracy, rank_zero_only=True)
|
|
||||||
|
|
||||||
def configure_optimizers(self):
|
|
||||||
optimizer = torch.optim.AdamW(self.trainer.model.parameters(), lr=0.0001)
|
|
||||||
return optimizer
|
|
||||||
|
|
||||||
def configure_callbacks(self):
|
|
||||||
checkpoint_callback = pl.callbacks.ModelCheckpoint(
|
|
||||||
monitor='accuracy',
|
|
||||||
mode='max',
|
|
||||||
filename='{epoch:02d}-{accuracy:.4f}',
|
|
||||||
)
|
|
||||||
early_stop_callback = pl.callbacks.EarlyStopping(
|
|
||||||
monitor='accuracy',
|
|
||||||
min_delta=0.001,
|
|
||||||
patience=3,
|
|
||||||
mode='max',
|
|
||||||
stopping_threshold=1,
|
|
||||||
)
|
|
||||||
return [checkpoint_callback, early_stop_callback]
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
|
|
|
@ -0,0 +1,41 @@
|
||||||
|
import os
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
from transformers import (
|
||||||
|
AutoConfig,
|
||||||
|
AutoModel,
|
||||||
|
AutoModelForCausalLM,
|
||||||
|
AutoTokenizer,
|
||||||
|
PreTrainedModel,
|
||||||
|
PreTrainedTokenizer,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def init_model(model_name: Union[str, os.PathLike]) -> PreTrainedModel:
|
||||||
|
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
|
||||||
|
try:
|
||||||
|
model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
|
||||||
|
except ValueError:
|
||||||
|
model = AutoModel.from_config(config, trust_remote_code=True)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def load_model(model_name_or_path: Union[str, os.PathLike]) -> PreTrainedModel:
|
||||||
|
try:
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
model_name_or_path, trust_remote_code=True
|
||||||
|
)
|
||||||
|
except ValueError:
|
||||||
|
model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def load_tokenizer(
|
||||||
|
tokenizer_name_or_path: Union[str, os.PathLike]
|
||||||
|
) -> PreTrainedTokenizer:
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
|
tokenizer_name_or_path, padding_side='left', trust_remote_code=True
|
||||||
|
)
|
||||||
|
if tokenizer.pad_token_id is None:
|
||||||
|
tokenizer.pad_token = tokenizer.eos_token
|
||||||
|
return tokenizer
|
Loading…
Reference in New Issue