From e5f97af291579f8b90a93b81e02aaf76be45fbe5 Mon Sep 17 00:00:00 2001 From: Colin Date: Sun, 25 Feb 2024 20:20:32 +0800 Subject: [PATCH] Add wit train support. --- .gitignore | 3 +- dataset/MNBVC.py | 5 -- wit/lit_module.py | 100 ++++++++++++++++++++++++ wit/lit_train.py | 164 +++++++++++++++++++++++++++++++++++++++ wit/modeling_wit.py | 20 ++--- wit/tokenization_qwen.py | 2 +- 6 files changed, 278 insertions(+), 16 deletions(-) delete mode 100644 dataset/MNBVC.py create mode 100644 wit/lit_module.py create mode 100644 wit/lit_train.py diff --git a/.gitignore b/.gitignore index 6cce550..fbafd7b 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ __pycache__ .vscode *.txt -temp \ No newline at end of file +temp +lightning_logs \ No newline at end of file diff --git a/dataset/MNBVC.py b/dataset/MNBVC.py deleted file mode 100644 index 2ec3258..0000000 --- a/dataset/MNBVC.py +++ /dev/null @@ -1,5 +0,0 @@ -from datasets import load_dataset - -dataset = load_dataset("liwu/MNBVC", "wikipedia", split="train", streaming=True) - -print(next(iter(dataset))) # get the first line diff --git a/wit/lit_module.py b/wit/lit_module.py new file mode 100644 index 0000000..3f8d2f1 --- /dev/null +++ b/wit/lit_module.py @@ -0,0 +1,100 @@ +from functools import cache +from typing import Dict, Optional + +import pytorch_lightning as pl +import torch +import torchmetrics + +# from utils import init_model +# from custom_models.gpt2.modeling_gpt2 import GPT2LMHeadModel + +from modeling_wit import QWenLMHeadModel +from configuration_qwen import QWenConfig + +from transformers import AutoConfig + + +class LitModule(pl.LightningModule): + def __init__( + self, + model_dir: str, + learning_rate: float = 0.0001, + use_tril_attention_mask: str = False, + ): + super().__init__() + self.save_hyperparameters() + config = QWenConfig() + model = QWenLMHeadModel(config) + model = model.from_pretrained(model_dir) + self.llm = self.register_core_module(model) + self.learning_rate = learning_rate + self.use_tril_attention_mask = use_tril_attention_mask + self.metric_loss = torchmetrics.MeanMetric() + self.metric_accuracy = torchmetrics.Accuracy( + task="multiclass", + num_classes=self.llm.config.vocab_size, + ) + + @cache + def get_batch_tril_matrix(self, block_size: int, batch_size: Optional[int] = None) -> torch.Tensor: + matrix = torch.ones(block_size, block_size).tril() + if batch_size is not None: + matrix = matrix.repeat(batch_size, 1, 1) + return matrix + + def register_core_module(self, module: torch.nn.Module) -> torch.nn.Module: + object.__setattr__(self, "__core_module__", module) + return module + + def training_step(self, batch: Dict[str, torch.Tensor], batch_idx): + batch_size, block_size = batch["input_ids"].shape + if self.use_tril_attention_mask: + batch["attention_mask"] = self.get_batch_tril_matrix(block_size, batch_size=batch_size).to(self.device) + outputs, loss = self.llm(**batch) + self.log("train_loss", loss, rank_zero_only=True) + return loss + + def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx): + outputs, loss = self.llm(**batch, return_dict=True) + logits = outputs[..., :-1, :] + labels = batch["labels"][..., 1:] + + self.metric_loss.update(loss) + + label_mask = labels != -100 + self.metric_accuracy.update(logits[label_mask], labels[label_mask]) + + def on_validation_epoch_end(self) -> None: + self.log("val_loss", self.metric_loss, rank_zero_only=True) + self.log("accuracy", self.metric_accuracy, rank_zero_only=True) + + def configure_optimizers(self): + strategy = self.trainer.strategy + if isinstance(strategy, pl.strategies.DeepSpeedStrategy): + assert "optimizer" not in strategy.config + zero_config = strategy.config.get("zero_optimization") + if zero_config is not None: + if "offload_optimizer" in zero_config: + import deepspeed + + optimizer = deepspeed.ops.adam.DeepSpeedCPUAdam( + self.trainer.model.parameters(), lr=self.learning_rate + ) + return optimizer + optimizer = torch.optim.AdamW(self.trainer.model.parameters(), lr=self.learning_rate) + return optimizer + + def configure_callbacks(self): + checkpoint_callback = pl.callbacks.ModelCheckpoint( + monitor="accuracy", + mode="max", + filename="{epoch:02d}-{accuracy:.4f}", + ) + early_stop_callback = pl.callbacks.EarlyStopping( + monitor="accuracy", + min_delta=0.001, + patience=3, + mode="max", + stopping_threshold=1, + ) + return [checkpoint_callback, early_stop_callback] diff --git a/wit/lit_train.py b/wit/lit_train.py new file mode 100644 index 0000000..fd52bab --- /dev/null +++ b/wit/lit_train.py @@ -0,0 +1,164 @@ +import argparse +from functools import partial +from itertools import chain +from typing import Dict, Tuple + +import datasets +import pytorch_lightning as pl +import torch +from torch.utils.data import ConcatDataset, DataLoader +from transformers import ( + BatchEncoding, + DefaultDataCollator, + PreTrainedTokenizer, + set_seed, +) +from modelscope import snapshot_download +from lit_module import LitModule +from tokenization_qwen import QWenTokenizer + +model_name = "qwen/Qwen-1_8B-Chat" +learning_rate = 0.0001 +use_tril_attention_mask = None +precision = "16-mixed" # "precision:bf16-mixed,16-mixed,32-true" +tokenizer_name_or_path = None +dataset_name = "/home/colin/develop/dataset/liwu/MNBVC/wiki/20230197/0.jsonl.gz" +dataset_name = "/home/colin/develop/dataset/liwu/MNBVC/wiki" +train_batch_size = 8 +val_batch_size = 1 +accumulate_grad_batches = 32 +num_proc = 8 +max_epochs = None +strategy = "fsdp" +resume_from_ckpt_path = None +seed = 42 + + +def split_raw_dataset( + raw_dataset: datasets.DatasetDict, +) -> Tuple[datasets.Dataset, datasets.Dataset]: + if "validation" in raw_dataset: + train_dataset, val_dataset = raw_dataset["train"], raw_dataset["validation"] + else: + raw_dataset = raw_dataset["train"].train_test_split(test_size=0.05, seed=seed) + train_dataset, val_dataset = raw_dataset["train"], raw_dataset["test"] + return train_dataset, val_dataset + + +def process_dataset(dataset: datasets.Dataset, tokenizer: PreTrainedTokenizer) -> datasets.Dataset: + def group_texts(examples: Dict[str, list], block_size: int = 512) -> BatchEncoding: + concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} + total_length = len(concatenated_examples[list(examples.keys())[0]]) + total_length = (total_length // block_size) * block_size + result = { + k: [t[i : i + block_size] for i in range(0, total_length, block_size)] + for k, t in concatenated_examples.items() + } + result["labels"] = result["input_ids"].copy() + result = BatchEncoding(result) + return result + + def format_inputs(examples): + p = examples["段落"] + mergeLine = "" + for line in p: + mergeLine += line["内容"] + "\n" + return {"text": mergeLine} + + def tokenize_inputs( + examples: Dict[str, list], + tokenizer: PreTrainedTokenizer, + column_name: str = "text", + ) -> BatchEncoding: + logits = tokenizer(examples[column_name], return_attention_mask=False) + return logits + + dataset_column_names = list(dataset.features) + dataset = dataset.map( + partial(format_inputs), + batched=False, + num_proc=num_proc, + remove_columns=dataset_column_names, + ) + dataset_column_names = list(dataset.features) + dataset = dataset.map( + partial(tokenize_inputs, tokenizer=tokenizer), + batched=True, + num_proc=num_proc, + remove_columns=dataset_column_names, + ) + dataset = dataset.map( + partial(group_texts, block_size=tokenizer.model_max_length), + batched=True, + num_proc=num_proc, + ) + + return dataset + + +if __name__ == "__main__": + if tokenizer_name_or_path is None: + tokenizer_name_or_path = model_name + + set_seed(seed) + + # lightning module + model_dir = snapshot_download(model_name) + lit_module = LitModule(model_dir, learning_rate, use_tril_attention_mask) + + # datasets + # tokenizer = load_tokenizer("./custom_models/gpt2") + tokenizer = QWenTokenizer("./wit_b64.tiktoken", "./wit_char.tiktoken") + train_dataset_list = [] + val_dataset_list = [] + for dataset_name in dataset_name: + dataset_args = dataset_name.split(":") + raw_dataset = datasets.load_dataset( + "json", data_files="/home/colin/develop/dataset/liwu/MNBVC/wiki/20230197/0.jsonl.gz" + ) + # raw_dataset = datasets.load_dataset(*dataset_args) + train_dataset, val_dataset = split_raw_dataset(raw_dataset) + train_dataset = process_dataset(train_dataset, tokenizer) + val_dataset = process_dataset(val_dataset, tokenizer) + train_dataset_list.append(train_dataset) + val_dataset_list.append(val_dataset) + train_dataset = ConcatDataset(train_dataset_list) + val_dataset = ConcatDataset(val_dataset_list) + + # dataloaders + train_dataloader = DataLoader( + train_dataset, + batch_size=train_batch_size, + num_workers=num_proc, + collate_fn=DefaultDataCollator(), + persistent_workers=True, + shuffle=True, + ) + val_dataloader = DataLoader( + val_dataset, + batch_size=val_batch_size, + num_workers=num_proc, + collate_fn=DefaultDataCollator(), + persistent_workers=True, + ) + + ne = next(train_dataloader._get_iterator()) + + # trainer + # apply_all_patches() + torch.set_float32_matmul_precision("medium") + precision = precision + lit_trainer = pl.Trainer( + accelerator="gpu", + precision=precision, + log_every_n_steps=5, + accumulate_grad_batches=accumulate_grad_batches, + strategy=strategy, + max_epochs=max_epochs, + ) + lit_trainer.fit( + lit_module, + train_dataloaders=train_dataloader, + val_dataloaders=val_dataloader, + ckpt_path=resume_from_ckpt_path, + ) diff --git a/wit/modeling_wit.py b/wit/modeling_wit.py index d7197d4..f01a725 100644 --- a/wit/modeling_wit.py +++ b/wit/modeling_wit.py @@ -137,6 +137,16 @@ class QWenLMHeadModel(nn.Module): self.transformer = QWenModel(config) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + **kwargs, + ): + runner = QwenRunner(self) + return runner.forwardQWen(input_ids, labels) + def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]]): pretrained_model_name_or_path = str(pretrained_model_name_or_path) resolved_archive_file = os.path.join(pretrained_model_name_or_path, "model.safetensors.index.json") @@ -343,15 +353,7 @@ class QwenRunner: loss_fct = CrossEntropyLoss() loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) - # shift_labels = torch.ones([1,19]).to(lm_logits.device).to(torch.int64) - # shift_logits = lm_logits[..., :-1, :].contiguous() - # loss_fct = CrossEntropyLoss() - # loss = loss_fct( - # shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1) - # ) - # loss.backward() - - return lm_logits + return lm_logits, loss def prepareInput(self, tokenizer, query, query_assistant, history, system): return make_context(tokenizer, query, query_assistant, history=history, system=system) diff --git a/wit/tokenization_qwen.py b/wit/tokenization_qwen.py index d5061d0..fceff38 100644 --- a/wit/tokenization_qwen.py +++ b/wit/tokenization_qwen.py @@ -63,7 +63,7 @@ class QWenTokenizer(PreTrainedTokenizer): self.mergeable_ranks = _load_tiktoken_b64(vocab_file_b64) self.mergeable_ranks.update(_load_tiktoken_char(vocab_file_char, len(self.mergeable_ranks))) - + self.model_max_length = 1024 special = ( "user", "assistant",