Add wit train support.
parent fc071dce70
commit e5f97af291
.gitignore
@@ -2,3 +2,4 @@ __pycache__
.vscode
*.txt
temp
lightning_logs
@@ -1,5 +0,0 @@
from datasets import load_dataset

dataset = load_dataset("liwu/MNBVC", "wikipedia", split="train", streaming=True)

print(next(iter(dataset)))  # get the first line
lit_module.py
@@ -0,0 +1,100 @@
from functools import cache
from typing import Dict, Optional

import pytorch_lightning as pl
import torch
import torchmetrics

# from utils import init_model
# from custom_models.gpt2.modeling_gpt2 import GPT2LMHeadModel

from modeling_wit import QWenLMHeadModel
from configuration_qwen import QWenConfig

from transformers import AutoConfig


class LitModule(pl.LightningModule):
    def __init__(
        self,
        model_dir: str,
        learning_rate: float = 0.0001,
        use_tril_attention_mask: bool = False,
    ):
        super().__init__()
        self.save_hyperparameters()
        config = QWenConfig()
        model = QWenLMHeadModel(config)
        model = model.from_pretrained(model_dir)
        self.llm = self.register_core_module(model)
        self.learning_rate = learning_rate
        self.use_tril_attention_mask = use_tril_attention_mask
        self.metric_loss = torchmetrics.MeanMetric()
        self.metric_accuracy = torchmetrics.Accuracy(
            task="multiclass",
            num_classes=self.llm.config.vocab_size,
        )

    @cache
    def get_batch_tril_matrix(self, block_size: int, batch_size: Optional[int] = None) -> torch.Tensor:
        matrix = torch.ones(block_size, block_size).tril()
        if batch_size is not None:
            matrix = matrix.repeat(batch_size, 1, 1)
        return matrix

    def register_core_module(self, module: torch.nn.Module) -> torch.nn.Module:
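        # object.__setattr__ bypasses nn.Module.__setattr__, so "__core_module__"
        # is kept as a plain attribute instead of being registered as a second
        # sub-module; the caller still registers the model via `self.llm = ...`.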
        object.__setattr__(self, "__core_module__", module)
        return module

    def training_step(self, batch: Dict[str, torch.Tensor], batch_idx):
        batch_size, block_size = batch["input_ids"].shape
        if self.use_tril_attention_mask:
            batch["attention_mask"] = self.get_batch_tril_matrix(block_size, batch_size=batch_size).to(self.device)
        outputs, loss = self.llm(**batch)
        self.log("train_loss", loss, rank_zero_only=True)
        return loss

    def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx):
        outputs, loss = self.llm(**batch, return_dict=True)
        logits = outputs[..., :-1, :]
        labels = batch["labels"][..., 1:]

        self.metric_loss.update(loss)

        label_mask = labels != -100
        self.metric_accuracy.update(logits[label_mask], labels[label_mask])

    def on_validation_epoch_end(self) -> None:
        self.log("val_loss", self.metric_loss, rank_zero_only=True)
        self.log("accuracy", self.metric_accuracy, rank_zero_only=True)

    def configure_optimizers(self):
        strategy = self.trainer.strategy
        if isinstance(strategy, pl.strategies.DeepSpeedStrategy):
            assert "optimizer" not in strategy.config
            zero_config = strategy.config.get("zero_optimization")
            if zero_config is not None:
                if "offload_optimizer" in zero_config:
                    import deepspeed

                    optimizer = deepspeed.ops.adam.DeepSpeedCPUAdam(
                        self.trainer.model.parameters(), lr=self.learning_rate
                    )
                    return optimizer
        optimizer = torch.optim.AdamW(self.trainer.model.parameters(), lr=self.learning_rate)
        return optimizer

    def configure_callbacks(self):
        checkpoint_callback = pl.callbacks.ModelCheckpoint(
            monitor="accuracy",
            mode="max",
            filename="{epoch:02d}-{accuracy:.4f}",
        )
        early_stop_callback = pl.callbacks.EarlyStopping(
            monitor="accuracy",
            min_delta=0.001,
            patience=3,
            mode="max",
            stopping_threshold=1,
        )
        return [checkpoint_callback, early_stop_callback]
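For reference, a small sketch (not part of the commit) of the causal mask that get_batch_tril_matrix produces when use_tril_attention_mask is enabled; block_size=4 and batch_size=2 are arbitrary illustration values:

import torch

# Lower-triangular mask: position i may attend only to positions 0..i.
mask = torch.ones(4, 4).tril()
# tensor([[1., 0., 0., 0.],
#         [1., 1., 0., 0.],
#         [1., 1., 1., 0.],
#         [1., 1., 1., 1.]])

# Repeated per sample, as training_step requests it: shape (2, 4, 4).
batched = mask.repeat(2, 1, 1)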
@@ -0,0 +1,164 @@
import argparse
from functools import partial
from itertools import chain
from typing import Dict, Tuple

import datasets
import pytorch_lightning as pl
import torch
from torch.utils.data import ConcatDataset, DataLoader
from transformers import (
    BatchEncoding,
    DefaultDataCollator,
    PreTrainedTokenizer,
    set_seed,
)
from modelscope import snapshot_download
from lit_module import LitModule
from tokenization_qwen import QWenTokenizer

model_name = "qwen/Qwen-1_8B-Chat"
learning_rate = 0.0001
use_tril_attention_mask = False
precision = "16-mixed"  # one of "bf16-mixed", "16-mixed", "32-true"
tokenizer_name_or_path = None
dataset_names = [
    "/home/colin/develop/dataset/liwu/MNBVC/wiki/20230197/0.jsonl.gz",
    # "/home/colin/develop/dataset/liwu/MNBVC/wiki",
]
train_batch_size = 8
val_batch_size = 1
accumulate_grad_batches = 32
num_proc = 8
max_epochs = None
strategy = "fsdp"
resume_from_ckpt_path = None
seed = 42


def split_raw_dataset(
    raw_dataset: datasets.DatasetDict,
) -> Tuple[datasets.Dataset, datasets.Dataset]:
    if "validation" in raw_dataset:
        train_dataset, val_dataset = raw_dataset["train"], raw_dataset["validation"]
    else:
        raw_dataset = raw_dataset["train"].train_test_split(test_size=0.05, seed=seed)
        train_dataset, val_dataset = raw_dataset["train"], raw_dataset["test"]
    return train_dataset, val_dataset


def process_dataset(dataset: datasets.Dataset, tokenizer: PreTrainedTokenizer) -> datasets.Dataset:
    def group_texts(examples: Dict[str, list], block_size: int = 512) -> BatchEncoding:
        concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
        total_length = len(concatenated_examples[list(examples.keys())[0]])
        total_length = (total_length // block_size) * block_size
        result = {
            k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
            for k, t in concatenated_examples.items()
        }
        result["labels"] = result["input_ids"].copy()
        result = BatchEncoding(result)
        return result
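
    # Worked illustration of group_texts (an added note, with made-up token ids):
    # given examples = {"input_ids": [[1, 2, 3], [4, 5, 6, 7], [8, 9]]} and
    # block_size=4, concatenation yields [1, 2, 3, 4, 5, 6, 7, 8, 9], the length
    # is truncated to 8, and the output is input_ids == [[1, 2, 3, 4], [5, 6, 7, 8]]
    # with "labels" an identical copy.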

    def format_inputs(examples):
        # Join the "段落" (paragraphs) of one record, concatenating each
        # paragraph's "内容" (content) field into a single text string.
        paragraphs = examples["段落"]
        merged = ""
        for line in paragraphs:
            merged += line["内容"] + "\n"
        return {"text": merged}

    def tokenize_inputs(
        examples: Dict[str, list],
        tokenizer: PreTrainedTokenizer,
        column_name: str = "text",
    ) -> BatchEncoding:
        encodings = tokenizer(examples[column_name], return_attention_mask=False)
        return encodings

    dataset_column_names = list(dataset.features)
    dataset = dataset.map(
        format_inputs,
        batched=False,
        num_proc=num_proc,
        remove_columns=dataset_column_names,
    )
    dataset_column_names = list(dataset.features)
    dataset = dataset.map(
        partial(tokenize_inputs, tokenizer=tokenizer),
        batched=True,
        num_proc=num_proc,
        remove_columns=dataset_column_names,
    )
    dataset = dataset.map(
        partial(group_texts, block_size=tokenizer.model_max_length),
        batched=True,
        num_proc=num_proc,
    )

    return dataset


if __name__ == "__main__":
    if tokenizer_name_or_path is None:
        tokenizer_name_or_path = model_name

    set_seed(seed)

    # lightning module
    model_dir = snapshot_download(model_name)
    lit_module = LitModule(model_dir, learning_rate, use_tril_attention_mask)

    # datasets
    # tokenizer = load_tokenizer("./custom_models/gpt2")
    tokenizer = QWenTokenizer("./wit_b64.tiktoken", "./wit_char.tiktoken")
    train_dataset_list = []
    val_dataset_list = []
    for dataset_name in dataset_names:
        dataset_args = dataset_name.split(":")
        raw_dataset = datasets.load_dataset("json", data_files=dataset_name)
        # raw_dataset = datasets.load_dataset(*dataset_args)
        train_dataset, val_dataset = split_raw_dataset(raw_dataset)
        train_dataset = process_dataset(train_dataset, tokenizer)
        val_dataset = process_dataset(val_dataset, tokenizer)
        train_dataset_list.append(train_dataset)
        val_dataset_list.append(val_dataset)
    train_dataset = ConcatDataset(train_dataset_list)
    val_dataset = ConcatDataset(val_dataset_list)

    # dataloaders
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=train_batch_size,
        num_workers=num_proc,
        collate_fn=DefaultDataCollator(),
        persistent_workers=True,
        shuffle=True,
    )
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=val_batch_size,
        num_workers=num_proc,
        collate_fn=DefaultDataCollator(),
        persistent_workers=True,
    )

    # sanity-check that one batch collates correctly before training starts
    next(iter(train_dataloader))

    # trainer
    # apply_all_patches()
    torch.set_float32_matmul_precision("medium")
    lit_trainer = pl.Trainer(
        accelerator="gpu",
        precision=precision,
        log_every_n_steps=5,
        accumulate_grad_batches=accumulate_grad_batches,
        strategy=strategy,
        max_epochs=max_epochs,
    )
    lit_trainer.fit(
        lit_module,
        train_dataloaders=train_dataloader,
        val_dataloaders=val_dataloader,
        ckpt_path=resume_from_ckpt_path,
    )
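The loop over dataset_names splits each entry on ":" so that, under the commented-out load_dataset(*dataset_args) path, a Hub dataset and its config could be written as one string. A sketch of both forms; the Hub entry mirrors the deleted smoke-test script above and is only an assumption about the intended convention:

dataset_names = [
    # local JSONL shard, loaded with the "json" builder as in the loop above
    "/home/colin/develop/dataset/liwu/MNBVC/wiki/20230197/0.jsonl.gz",
    # Hub dataset plus config, i.e. datasets.load_dataset("liwu/MNBVC", "wikipedia")
    "liwu/MNBVC:wikipedia",
]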
modeling_wit.py
@@ -137,6 +137,16 @@ class QWenLMHeadModel(nn.Module):
        self.transformer = QWenModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        **kwargs,
    ):
        # token_type_ids and extra kwargs (e.g. return_dict) are accepted
        # but not forwarded to the runner.
        runner = QwenRunner(self)
        return runner.forwardQWen(input_ids, labels)

    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]]):
        pretrained_model_name_or_path = str(pretrained_model_name_or_path)
        resolved_archive_file = os.path.join(pretrained_model_name_or_path, "model.safetensors.index.json")
@@ -343,15 +353,7 @@ class QwenRunner:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        # shift_labels = torch.ones([1,19]).to(lm_logits.device).to(torch.int64)
        # shift_logits = lm_logits[..., :-1, :].contiguous()
        # loss_fct = CrossEntropyLoss()
        # loss = loss_fct(
        #     shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
        # )
        # loss.backward()

        return lm_logits
        return lm_logits, loss

    def prepareInput(self, tokenizer, query, query_assistant, history, system):
        return make_context(tokenizer, query, query_assistant, history=history, system=system)
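With this hunk forwardQWen returns the logits together with the loss, which is what LitModule unpacks in training_step and validation_step. A minimal call sketch (not part of the commit; model_dir and input_ids are assumed to exist, with input_ids a LongTensor of shape (batch, seq)):

from configuration_qwen import QWenConfig
from modeling_wit import QWenLMHeadModel

model = QWenLMHeadModel(QWenConfig())
model = model.from_pretrained(model_dir)  # loads the safetensors weights, as LitModule does
logits, loss = model(input_ids=input_ids, labels=input_ids)
loss.backward()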
tokenization_qwen.py
@@ -63,7 +63,7 @@ class QWenTokenizer(PreTrainedTokenizer):

        self.mergeable_ranks = _load_tiktoken_b64(vocab_file_b64)
        self.mergeable_ranks.update(_load_tiktoken_char(vocab_file_char, len(self.mergeable_ranks)))

        self.model_max_length = 1024
        special = (
            "user",
            "assistant",