Add wit training support.
commit e5f97af291 (parent fc071dce70)
.gitignore
@@ -2,3 +2,4 @@ __pycache__
 .vscode
 *.txt
 temp
+lightning_logs
@@ -1,5 +0,0 @@
-from datasets import load_dataset
-
-dataset = load_dataset("liwu/MNBVC", "wikipedia", split="train", streaming=True)
-
-print(next(iter(dataset)))  # get the first line
lit_module.py (new file)
@@ -0,0 +1,100 @@
+from functools import cache
+from typing import Dict, Optional
+
+import pytorch_lightning as pl
+import torch
+import torchmetrics
+
+# from utils import init_model
+# from custom_models.gpt2.modeling_gpt2 import GPT2LMHeadModel
+
+from modeling_wit import QWenLMHeadModel
+from configuration_qwen import QWenConfig
+
+
+class LitModule(pl.LightningModule):
+    def __init__(
+        self,
+        model_dir: str,
+        learning_rate: float = 0.0001,
+        use_tril_attention_mask: bool = False,
+    ):
+        super().__init__()
+        self.save_hyperparameters()
+        config = QWenConfig()
+        model = QWenLMHeadModel(config)
+        # Load the pretrained safetensors weights from model_dir into the model.
+        model = model.from_pretrained(model_dir)
+        self.llm = self.register_core_module(model)
+        self.learning_rate = learning_rate
+        self.use_tril_attention_mask = use_tril_attention_mask
+        self.metric_loss = torchmetrics.MeanMetric()
+        self.metric_accuracy = torchmetrics.Accuracy(
+            task="multiclass",
+            num_classes=self.llm.config.vocab_size,
+        )
+
+    @cache
+    def get_batch_tril_matrix(self, block_size: int, batch_size: Optional[int] = None) -> torch.Tensor:
+        # Lower-triangular (causal) attention mask, memoized per (block_size, batch_size).
+        matrix = torch.ones(block_size, block_size).tril()
+        if batch_size is not None:
+            matrix = matrix.repeat(batch_size, 1, 1)
+        return matrix
+
+    def register_core_module(self, module: torch.nn.Module) -> torch.nn.Module:
+        # Bypass nn.Module.__setattr__ so the reference is kept without
+        # registering the module a second time as a submodule.
+        object.__setattr__(self, "__core_module__", module)
+        return module
+
+    def training_step(self, batch: Dict[str, torch.Tensor], batch_idx):
+        batch_size, block_size = batch["input_ids"].shape
+        if self.use_tril_attention_mask:
+            batch["attention_mask"] = self.get_batch_tril_matrix(block_size, batch_size=batch_size).to(self.device)
+        outputs, loss = self.llm(**batch)
+        self.log("train_loss", loss, rank_zero_only=True)
+        return loss
+
+    def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx):
+        outputs, loss = self.llm(**batch, return_dict=True)
+        logits = outputs[..., :-1, :]
+        labels = batch["labels"][..., 1:]
+
+        self.metric_loss.update(loss)
+
+        label_mask = labels != -100
+        self.metric_accuracy.update(logits[label_mask], labels[label_mask])
+
+    def on_validation_epoch_end(self) -> None:
+        self.log("val_loss", self.metric_loss, rank_zero_only=True)
+        self.log("accuracy", self.metric_accuracy, rank_zero_only=True)
+
+    def configure_optimizers(self):
+        strategy = self.trainer.strategy
+        if isinstance(strategy, pl.strategies.DeepSpeedStrategy):
+            assert "optimizer" not in strategy.config
+            zero_config = strategy.config.get("zero_optimization")
+            if zero_config is not None:
+                if "offload_optimizer" in zero_config:
+                    import deepspeed
+
+                    # ZeRO optimizer offload requires DeepSpeed's CPU Adam.
+                    optimizer = deepspeed.ops.adam.DeepSpeedCPUAdam(
+                        self.trainer.model.parameters(), lr=self.learning_rate
+                    )
+                    return optimizer
+        optimizer = torch.optim.AdamW(self.trainer.model.parameters(), lr=self.learning_rate)
+        return optimizer
+
+    def configure_callbacks(self):
+        checkpoint_callback = pl.callbacks.ModelCheckpoint(
+            monitor="accuracy",
+            mode="max",
+            filename="{epoch:02d}-{accuracy:.4f}",
+        )
+        early_stop_callback = pl.callbacks.EarlyStopping(
+            monitor="accuracy",
+            min_delta=0.001,
+            patience=3,
+            mode="max",
+            stopping_threshold=1.0,
+        )
+        return [checkpoint_callback, early_stop_callback]
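For intuition, here is what the causal mask from get_batch_tril_matrix looks like (a standalone illustration, not part of the commit):

    import torch

    mask = torch.ones(4, 4).tril()     # block_size = 4
    # tensor([[1., 0., 0., 0.],
    #         [1., 1., 0., 0.],
    #         [1., 1., 1., 0.],
    #         [1., 1., 1., 1.]])
    batch_mask = mask.repeat(2, 1, 1)  # batch_size = 2 -> shape [2, 4, 4]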
@@ -0,0 +1,164 @@
+from functools import partial
+from itertools import chain
+from typing import Dict, Tuple
+
+import datasets
+import pytorch_lightning as pl
+import torch
+from torch.utils.data import ConcatDataset, DataLoader
+from transformers import (
+    BatchEncoding,
+    DefaultDataCollator,
+    PreTrainedTokenizer,
+    set_seed,
+)
+from modelscope import snapshot_download
+from lit_module import LitModule
+from tokenization_qwen import QWenTokenizer
+
+model_name = "qwen/Qwen-1_8B-Chat"
+learning_rate = 0.0001
+use_tril_attention_mask = False
+precision = "16-mixed"  # one of: "bf16-mixed", "16-mixed", "32-true"
+tokenizer_name_or_path = None
+# Local MNBVC wiki shards; each entry is one dataset to load and concatenate.
+dataset_names = [
+    "/home/colin/develop/dataset/liwu/MNBVC/wiki/20230197/0.jsonl.gz",
+]
+train_batch_size = 8
+val_batch_size = 1
+accumulate_grad_batches = 32
+num_proc = 8
+max_epochs = None
+strategy = "fsdp"
+resume_from_ckpt_path = None
+seed = 42
+
+
+def split_raw_dataset(
+    raw_dataset: datasets.DatasetDict,
+) -> Tuple[datasets.Dataset, datasets.Dataset]:
+    if "validation" in raw_dataset:
+        train_dataset, val_dataset = raw_dataset["train"], raw_dataset["validation"]
+    else:
+        # No validation split available: hold out 5% of train for validation.
+        raw_dataset = raw_dataset["train"].train_test_split(test_size=0.05, seed=seed)
+        train_dataset, val_dataset = raw_dataset["train"], raw_dataset["test"]
+    return train_dataset, val_dataset
+
+
+def process_dataset(dataset: datasets.Dataset, tokenizer: PreTrainedTokenizer) -> datasets.Dataset:
+    def group_texts(examples: Dict[str, list], block_size: int = 512) -> BatchEncoding:
+        # Concatenate all tokenized texts, then chop them into fixed-size
+        # blocks, dropping the tail remainder shorter than block_size.
+        concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
+        total_length = len(concatenated_examples[list(examples.keys())[0]])
+        total_length = (total_length // block_size) * block_size
+        result = {
+            k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
+            for k, t in concatenated_examples.items()
+        }
+        result["labels"] = result["input_ids"].copy()
+        result = BatchEncoding(result)
+        return result
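+
+    # Worked example for group_texts (illustrative): with block_size=4 and ten
+    # concatenated token ids [t0..t9], total_length is trimmed to 8, producing
+    # blocks [t0..t3] and [t4..t7]; t8 and t9 are dropped. "labels" is an exact
+    # copy of "input_ids"; the one-token shift happens inside the model's loss.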
+
+    def format_inputs(examples):
+        # Join every paragraph's "内容" (content) field into one "text" blob;
+        # "段落" means paragraphs in the MNBVC wiki JSON schema.
+        paragraphs = examples["段落"]
+        text = "".join(line["内容"] + "\n" for line in paragraphs)
+        return {"text": text}
+
+    def tokenize_inputs(
+        examples: Dict[str, list],
+        tokenizer: PreTrainedTokenizer,
+        column_name: str = "text",
+    ) -> BatchEncoding:
+        # Tokenize without attention masks; masking is handled in training_step.
+        return tokenizer(examples[column_name], return_attention_mask=False)
+
+    dataset_column_names = list(dataset.features)
+    dataset = dataset.map(
+        format_inputs,
+        batched=False,
+        num_proc=num_proc,
+        remove_columns=dataset_column_names,
+    )
+    dataset_column_names = list(dataset.features)
+    dataset = dataset.map(
+        partial(tokenize_inputs, tokenizer=tokenizer),
+        batched=True,
+        num_proc=num_proc,
+        remove_columns=dataset_column_names,
+    )
+    dataset = dataset.map(
+        partial(group_texts, block_size=tokenizer.model_max_length),
+        batched=True,
+        num_proc=num_proc,
+    )
+
+    return dataset
+
+
+if __name__ == "__main__":
+    if tokenizer_name_or_path is None:
+        tokenizer_name_or_path = model_name
+
+    set_seed(seed)
+
+    # lightning module
+    model_dir = snapshot_download(model_name)
+    lit_module = LitModule(model_dir, learning_rate, use_tril_attention_mask)
+
+    # datasets
+    # tokenizer = load_tokenizer("./custom_models/gpt2")
+    tokenizer = QWenTokenizer("./wit_b64.tiktoken", "./wit_char.tiktoken")
+    train_dataset_list = []
+    val_dataset_list = []
+    for dataset_name in dataset_names:
+        raw_dataset = datasets.load_dataset("json", data_files=dataset_name)
+        # Alternative: load a hub dataset from a "name:config" spec:
+        # raw_dataset = datasets.load_dataset(*dataset_name.split(":"))
+        train_dataset, val_dataset = split_raw_dataset(raw_dataset)
+        train_dataset = process_dataset(train_dataset, tokenizer)
+        val_dataset = process_dataset(val_dataset, tokenizer)
+        train_dataset_list.append(train_dataset)
+        val_dataset_list.append(val_dataset)
+    train_dataset = ConcatDataset(train_dataset_list)
+    val_dataset = ConcatDataset(val_dataset_list)
+
+    # dataloaders
+    train_dataloader = DataLoader(
+        train_dataset,
+        batch_size=train_batch_size,
+        num_workers=num_proc,
+        collate_fn=DefaultDataCollator(),
+        persistent_workers=True,
+        shuffle=True,
+    )
+    val_dataloader = DataLoader(
+        val_dataset,
+        batch_size=val_batch_size,
+        num_workers=num_proc,
+        collate_fn=DefaultDataCollator(),
+        persistent_workers=True,
+    )
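+
+    # Each collated batch is a dict of LongTensors, e.g. {"input_ids": [B, block],
+    # "labels": [B, block]} (shapes assumed from the pipeline above); training_step
+    # adds "attention_mask" itself when use_tril_attention_mask is enabled.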
+
+    # Smoke test: pull one batch up front to fail fast on data problems.
+    next(iter(train_dataloader))
+
+    # trainer
+    # apply_all_patches()
+    torch.set_float32_matmul_precision("medium")
+    lit_trainer = pl.Trainer(
+        accelerator="gpu",
+        precision=precision,
+        log_every_n_steps=5,
+        accumulate_grad_batches=accumulate_grad_batches,
+        strategy=strategy,
+        max_epochs=max_epochs,
+    )
+    lit_trainer.fit(
+        lit_module,
+        train_dataloaders=train_dataloader,
+        val_dataloaders=val_dataloader,
+        ckpt_path=resume_from_ckpt_path,
+    )
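A note on running this: every setting is hard-coded at the top of the script, so launching is simply "python" on the file itself (its name is not visible in this diff). With strategy="fsdp", Lightning should shard the model across all GPUs it detects.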
modeling_wit.py
@@ -137,6 +137,16 @@ class QWenLMHeadModel(nn.Module):
         self.transformer = QWenModel(config)
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
 
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        **kwargs,
+    ):
+        # Delegate to QwenRunner; token_type_ids and extra kwargs are accepted
+        # for interface compatibility but ignored.
+        runner = QwenRunner(self)
+        return runner.forwardQWen(input_ids, labels)
+
     def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]]):
         pretrained_model_name_or_path = str(pretrained_model_name_or_path)
         resolved_archive_file = os.path.join(pretrained_model_name_or_path, "model.safetensors.index.json")
@@ -343,15 +353,7 @@ class QwenRunner:
             loss_fct = CrossEntropyLoss()
             loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
 
-        # shift_labels = torch.ones([1,19]).to(lm_logits.device).to(torch.int64)
-        # shift_logits = lm_logits[..., :-1, :].contiguous()
-        # loss_fct = CrossEntropyLoss()
-        # loss = loss_fct(
-        #     shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
-        # )
-        # loss.backward()
-
-        return lm_logits
+        return lm_logits, loss
 
     def prepareInput(self, tokenizer, query, query_assistant, history, system):
         return make_context(tokenizer, query, query_assistant, history=history, system=system)
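Note that forwardQWen now returns a (lm_logits, loss) tuple instead of bare logits; that is what lets LitModule unpack "outputs, loss = self.llm(**batch)". Any other caller of forwardQWen would need the same update.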
tokenization_qwen.py
@@ -63,7 +63,7 @@ class QWenTokenizer(PreTrainedTokenizer):
 
         self.mergeable_ranks = _load_tiktoken_b64(vocab_file_b64)
         self.mergeable_ranks.update(_load_tiktoken_char(vocab_file_char, len(self.mergeable_ranks)))
-
+        self.model_max_length = 1024
         special = (
             "user",
             "assistant",
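Setting model_max_length here matters for training: process_dataset passes tokenizer.model_max_length to group_texts as block_size, so this one line fixes the training sequence length at 1024 tokens.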