From 45fa065530646e48b8bb916c65e4a064e28844ff Mon Sep 17 00:00:00 2001
From: Yiqing-Zhou
Date: Thu, 4 May 2023 21:52:25 +0800
Subject: [PATCH] Initial Commit

---
 .gitignore          |   2 +
 .vscode/launch.json |  28 +++++
 generate.py         |  94 +++++++++++++++
 requirements.txt    |   5 +
 train.py            | 288 ++++++++++++++++++++++++++++++++++++++++++++
 5 files changed, 417 insertions(+)
 create mode 100644 .gitignore
 create mode 100644 .vscode/launch.json
 create mode 100644 generate.py
 create mode 100644 requirements.txt
 create mode 100644 train.py

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..e460e1a
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+lightning_logs
+.env
diff --git a/.vscode/launch.json b/.vscode/launch.json
new file mode 100644
index 0000000..aa0cbae
--- /dev/null
+++ b/.vscode/launch.json
@@ -0,0 +1,28 @@
+{
+    // Use IntelliSense to learn about possible attributes.
+    // Hover to view descriptions of existing attributes.
+    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
+    "version": "0.2.0",
+    "configurations": [
+        {
+            "name": "Python: train",
+            "type": "python",
+            "request": "launch",
+            "program": "train.py",
+            "args": [
+                "--dataset_name", "wikitext:wikitext-2-v1",
+            ],
+            "console": "integratedTerminal",
+            "justMyCode": true
+        },
+        {
+            "name": "Python: generate",
+            "type": "python",
+            "request": "launch",
+            "program": "generate.py",
+            "args": [],
+            "console": "integratedTerminal",
+            "justMyCode": true
+        }
+    ]
+}
\ No newline at end of file
diff --git a/generate.py b/generate.py
new file mode 100644
index 0000000..ee39652
--- /dev/null
+++ b/generate.py
@@ -0,0 +1,94 @@
+import argparse
+import os
+from typing import List, Union
+
+import torch
+from transformers import (
+    AutoModel,
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    PreTrainedModel,
+    PreTrainedTokenizer,
+)
+
+
+def load_model(model_name_or_path: Union[str, os.PathLike]) -> PreTrainedModel:
+    try:
+        model = AutoModelForCausalLM.from_pretrained(
+            model_name_or_path, trust_remote_code=True
+        )
+    except ValueError:
+        model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True)
+    return model
+
+
+def load_tokenizer(model_name_or_path: Union[str, os.PathLike]) -> PreTrainedTokenizer:
+    tokenizer = AutoTokenizer.from_pretrained(
+        model_name_or_path, padding_side='left', trust_remote_code=True
+    )
+    if tokenizer.pad_token_id is None:
+        tokenizer.pad_token = tokenizer.eos_token
+    return tokenizer
+
+
+def eval_prompts(
+    model: PreTrainedModel, tokenizer: PreTrainedTokenizer, prompts: List[str]
+) -> List[str]:
+    inputs = tokenizer(
+        prompts, padding=True, return_tensors='pt', return_attention_mask=True
+    )
+    inputs['position_ids'] = inputs.attention_mask.cumsum(-1) - 1
+    inputs['position_ids'].masked_fill_(inputs.attention_mask == 0, 1)
+    inputs['attention_mask'] = (
+        inputs.attention_mask.unsqueeze(1) * inputs.attention_mask.unsqueeze(2)
+    ).tril()
+    inputs = inputs.to(model.device)
+    with torch.inference_mode():
+        output_ids = model.generate(
+            **inputs,
+            do_sample=False,
+            num_beams=16,
+            max_new_tokens=100,
+            eos_token_id=tokenizer.eos_token_id,
+            early_stopping=True,
+        )
+    completes = tokenizer.batch_decode(
+        output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
+    )
+    return completes
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--model_name_or_path",
+        type=str,
+        help="Name of or path to model",
+        default='gpt2',
+    )
+    args = parser.parse_args()
+    return args
+
+
+if __name__ == '__main__':
+    args = parse_args()
+
+    device = torch.device(0)
+
+    model = load_model(args.model_name_or_path)
+    tokenizer = load_tokenizer(args.model_name_or_path)
+
+    model = model.to(device)
+    prompts = [
+        "Shall I compare thee to a summer's day? Thou art more lovely and more temperate.",
+        "Shall I compare thee to a summer's day? Thou art more lovely and",
+        "Belle! C'est un mot qu'on dirait inventé pour elle.",
+        "Belle! C'est un mot qu'on dirait inventé",
+        "这是一个最好的时代,这是一个最坏的时代。",
+        "这是一个最好的时代,这是一个最坏的",
+    ]
+    completes = eval_prompts(model, tokenizer, prompts)
+
+    for prompt, complete in zip(prompts, completes):
+        print("[p]", prompt)
+        print("[c]", complete)
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..08a6d2c
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,5 @@
+datasets==2.11.0
+deepspeed==0.9.1
+pytorch-lightning==2.0.2
+torch==2.0.0
+transformers==4.28.1
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..ef7ded8
--- /dev/null
+++ b/train.py
@@ -0,0 +1,288 @@
+import argparse
+import os
+from functools import cache, partial
+from itertools import chain
+from typing import Dict, Optional, Tuple, Union
+
+import datasets
+import pytorch_lightning as pl
+import torch
+import torchmetrics
+from pytorch_lightning import strategies
+from torch.utils.data import ConcatDataset, DataLoader
+from transformers import (
+    AutoConfig,
+    AutoModel,
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    BatchEncoding,
+    DefaultDataCollator,
+    PreTrainedModel,
+    PreTrainedTokenizer,
+    set_seed,
+)
+
+
+def init_model(model_name: Union[str, os.PathLike]) -> PreTrainedModel:
+    config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
+    try:
+        model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
+    except ValueError:
+        model = AutoModel.from_config(config, trust_remote_code=True)
+    return model
+
+
+def load_tokenizer(
+    tokenizer_name_or_path: Union[str, os.PathLike]
+) -> PreTrainedTokenizer:
+    tokenizer = AutoTokenizer.from_pretrained(
+        tokenizer_name_or_path, padding_side='left', trust_remote_code=True
+    )
+    if tokenizer.pad_token_id is None:
+        tokenizer.pad_token = tokenizer.eos_token
+    return tokenizer
+
+
+def split_raw_dataset(
+    raw_dataset: datasets.DatasetDict,
+) -> Tuple[datasets.Dataset, datasets.Dataset]:
+    if 'validation' in raw_dataset:
+        train_dataset, val_dataset = raw_dataset['train'], raw_dataset['validation']
+    else:
+        raw_dataset = raw_dataset['train'].train_test_split(
+            test_size=0.05, seed=args.seed
+        )
+        train_dataset, val_dataset = raw_dataset['train'], raw_dataset['test']
+    return train_dataset, val_dataset
+
+
+def process_dataset(
+    dataset: datasets.Dataset, tokenizer: PreTrainedTokenizer
+) -> datasets.Dataset:
+    def group_texts(examples: Dict[str, list], block_size: int = 512) -> BatchEncoding:
+        concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
+        total_length = len(concatenated_examples[list(examples.keys())[0]])
+        total_length = (total_length // block_size) * block_size
+        result = {
+            k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
+            for k, t in concatenated_examples.items()
+        }
+        result['labels'] = result['input_ids'].copy()
+        result = BatchEncoding(result)
+        return result
+
+    def tokenize_inputs(
+        examples: Dict[str, list],
+        tokenizer: PreTrainedTokenizer,
+        column_name: str = 'text',
+    ) -> BatchEncoding:
+        return tokenizer(examples[column_name], return_attention_mask=False)
+
+    dataset_column_names = list(dataset.features)
+    dataset = dataset.map(
+        partial(
+            tokenize_inputs,
+            tokenizer=tokenizer,
+            column_name=dataset_column_names[0],
+        ),
+        batched=True,
+        num_proc=args.num_proc,
+        remove_columns=dataset_column_names,
+    ).map(
+        partial(group_texts, block_size=tokenizer.model_max_length),
+        batched=True,
+        num_proc=args.num_proc,
+    )
+    return dataset
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--model_name",
+        type=str,
+        help="Name of or path to model",
+        default='gpt2',
+    )
+    parser.add_argument("--fp16", help="Enable fp16", action="store_true")
+    parser.add_argument("--bf16", help="Enable bf16", action="store_true")
+    parser.add_argument(
+        "--tokenizer_name_or_path",
+        type=str,
+        help="Name of or path to tokenizer",
+        default=None,
+    )
+    parser.add_argument(
+        "--dataset_name",
+        nargs='+',
+        type=str,
+        help="Name(s) of dataset. To specify a config, pass a <name>:<config>",
+        default=["wikitext:wikitext-2-v1"],
+    )
+    parser.add_argument(
+        "--train_batch_size",
+        type=int,
+        help="Batch size of training",
+        default=8,
+    )
+    parser.add_argument(
+        "--val_batch_size",
+        type=int,
+        help="Batch size of validating",
+        default=16,
+    )
+    parser.add_argument(
+        "--num_proc",
+        type=int,
+        help="Number of data processes",
+        default=16,
+    )
+    parser.add_argument(
+        "--resume_from_ckpt_path",
+        type=str,
+        help="Checkpoint file path to resume from",
+        default=None,
+    )
+    parser.add_argument(
+        "--seed",
+        type=int,
+        help="Random seed",
+        default=42,
+    )
+    args = parser.parse_args()
+    return args
+
+
+class LitModule(pl.LightningModule):
+    def __init__(self, model_name: str):
+        super().__init__()
+        self.save_hyperparameters()
+        self.llm = self.register_core_module(init_model(model_name))
+        self.metric_loss = torchmetrics.MeanMetric()
+        self.metric_accuracy = torchmetrics.Accuracy(
+            task='multiclass',
+            num_classes=self.llm.config.vocab_size,
+        )
+
+    @cache
+    def get_tril_matrix(
+        self, block_size: int, batch_size: Optional[int] = None
+    ) -> torch.Tensor:
+        matrix = torch.ones(block_size, block_size).tril()
+        if batch_size is not None:
+            matrix = matrix.repeat(batch_size, 1, 1)
+        return matrix
+
+    def register_core_module(self, module: torch.nn.Module) -> torch.nn.Module:
+        object.__setattr__(self, '__core_module__', module)
+        return module
+
+    def training_step(self, batch: Dict[str, torch.Tensor], batch_idx):
+        batch_size, block_size = batch['input_ids'].shape
+        batch['attention_mask'] = self.get_tril_matrix(
+            block_size, batch_size=batch_size
+        ).to(self.device)
+        outputs = self.llm(**batch, return_dict=True)
+        loss = outputs.loss
+
+        self.log('train_loss', loss, rank_zero_only=True)
+
+        return loss
+
+    def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx):
+        outputs = self.llm(**batch, return_dict=True)
+        loss = outputs.loss
+        logits = outputs.logits[..., :-1, :]
+        labels = batch['labels'][..., 1:]
+
+        self.metric_loss.update(loss)
+
+        label_mask = labels != -100
+        self.metric_accuracy.update(logits[label_mask], labels[label_mask])
+
+    def on_validation_epoch_end(self) -> None:
+        self.log('val_loss', self.metric_loss, rank_zero_only=True)
+        self.log('accuracy', self.metric_accuracy, rank_zero_only=True)
+
+    def configure_optimizers(self):
+        optimizer = torch.optim.AdamW(self.trainer.model.parameters(), lr=0.0001)
+        return optimizer
+
+    def configure_callbacks(self):
+        checkpoint_callback = pl.callbacks.ModelCheckpoint(
+            monitor='accuracy',
+            mode='max',
+            filename='{epoch:02d}-{accuracy:.4f}',
+        )
+        early_stop_callback = pl.callbacks.EarlyStopping(
+            monitor='accuracy',
+            min_delta=0.001,
+            patience=3,
+            mode='max',
+            stopping_threshold=1,
+        )
+        return [checkpoint_callback, early_stop_callback]
+
+
+if __name__ == '__main__':
+    args = parse_args()
+
+    if args.tokenizer_name_or_path is None:
+        args.tokenizer_name_or_path = args.model_name
+
+    set_seed(args.seed)
+
+    # lightning module
+    lit_module = LitModule(args.model_name)
+
+    # datasets
+    tokenizer = load_tokenizer(args.tokenizer_name_or_path)
+    train_dataset_list = []
+    val_dataset_list = []
+    for dataset_name in args.dataset_name:
+        dataset_args = dataset_name.split(':')
+        raw_dataset = datasets.load_dataset(*dataset_args)
+        train_dataset, val_dataset = split_raw_dataset(raw_dataset)
+        train_dataset = process_dataset(train_dataset, tokenizer)
+        val_dataset = process_dataset(val_dataset, tokenizer)
+        train_dataset_list.append(train_dataset)
+        val_dataset_list.append(val_dataset)
+    train_dataset = ConcatDataset(train_dataset_list)
+    val_dataset = ConcatDataset(val_dataset_list)
+
+    # dataloaders
+    train_dataloader = DataLoader(
+        train_dataset,
+        batch_size=args.train_batch_size,
+        num_workers=args.num_proc,
+        collate_fn=DefaultDataCollator(),
+        shuffle=True,
+    )
+    val_dataloader = DataLoader(
+        val_dataset,
+        batch_size=args.val_batch_size,
+        num_workers=args.num_proc,
+        collate_fn=DefaultDataCollator(),
+    )
+
+    # trainer
+    torch.set_float32_matmul_precision('medium')
+    if args.bf16:
+        precision = 'bf16-mixed'
+    elif args.fp16:
+        precision = '16-mixed'
+    else:
+        precision = "32-true"
+    lit_trainer = pl.Trainer(
+        accelerator='gpu',
+        precision=precision,
+        log_every_n_steps=5,
+        accumulate_grad_batches=32,
+        strategy=strategies.DeepSpeedStrategy(stage=3),
+    )
+    lit_trainer.fit(
+        lit_module,
+        train_dataloaders=train_dataloader,
+        val_dataloaders=val_dataloader,
+        ckpt_path=args.resume_from_ckpt_path,
+    )