From 9e8e92ae25f51133f946abe7d9c0da2688227cce Mon Sep 17 00:00:00 2001 From: Colin Date: Mon, 4 Mar 2024 21:41:46 +0800 Subject: [PATCH] Update trainer to custom data. --- wit/configuration_qwen.py | 5 +- wit/lit_module.py | 2 +- wit/lit_train.py | 101 +++++--------------------------------- wit/modeling_wit.py | 20 ++++---- 4 files changed, 26 insertions(+), 102 deletions(-) diff --git a/wit/configuration_qwen.py b/wit/configuration_qwen.py index dd95d4e..e1725b1 100644 --- a/wit/configuration_qwen.py +++ b/wit/configuration_qwen.py @@ -7,8 +7,8 @@ class QWenConfig: def __init__(self): self.vocab_size = 4096 - self.hidden_size = 1024 # 1024 2048 - self.num_hidden_layers = 12 # 12 24 + self.hidden_size = 128 # 1024 2048 + self.num_hidden_layers = 6 # 12 24 self.num_attention_heads = 8 # 8 16 self.emb_dropout_prob = 0.0 self.attn_dropout_prob = 0.0 @@ -20,7 +20,6 @@ class QWenConfig: self.bf16 = False self.fp16 = False self.fp32 = False - self.kv_channels = 128 self.rotary_pct = 1.0 self.rotary_emb_base = 10000 self.use_dynamic_ntk = True diff --git a/wit/lit_module.py b/wit/lit_module.py index be86bf1..055b2aa 100644 --- a/wit/lit_module.py +++ b/wit/lit_module.py @@ -61,7 +61,7 @@ class LitModule(pl.LightningModule): self.metric_loss.update(loss) - label_mask = labels != -100 + label_mask = labels != 0 self.metric_accuracy.update(logits[label_mask], labels[label_mask]) def on_validation_epoch_end(self) -> None: diff --git a/wit/lit_train.py b/wit/lit_train.py index 63f7296..a958a97 100644 --- a/wit/lit_train.py +++ b/wit/lit_train.py @@ -6,7 +6,8 @@ from typing import Dict, Tuple import datasets import pytorch_lightning as pl import torch -from torch.utils.data import ConcatDataset, DataLoader, Dataset +from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split, Subset + from transformers import ( BatchEncoding, DefaultDataCollator, @@ -22,8 +23,6 @@ learning_rate = 0.0001 use_tril_attention_mask = None precision = "32-true" # "precision:bf16-mixed,16-mixed,32-true" tokenizer_name_or_path = None -dataset_name = ["/home/colin/develop/dataset/liwu/MNBVC/wiki"] -dataset_name = ["/home/colin/develop/dataset/liwu/MNBVC/wiki/20230198/58.jsonl.gz"] train_batch_size = 256 val_batch_size = 16 num_proc = 8 @@ -34,11 +33,14 @@ seed = 42 class SpecialDataset(Dataset): - def __init__(self, start, end, size=65536): + def __init__(self, start=1, end=4096, size=65536): self.size = size self.features = [] a = torch.randint(start, end, [size]) - self.data = torch.stack([a, a * 2, a * 3, a * 4]).permute(1, 0) + b = torch.randint(start, end, [size]) + c = torch.randint(start, end, [size]) + d = torch.randint(start, end, [size]) + self.data = torch.stack([a, b, c, d, ((a + b + c + d) / 4).long()]).permute(1, 0) def __len__(self): return self.size @@ -47,73 +49,12 @@ class SpecialDataset(Dataset): output = {} data = self.data[idx] output["input_ids"] = data - output["labels"] = data + output["labels"] = data.clone() + output["labels"][:4] = 0 output["token_type_ids"] = torch.zeros(data.shape) return output -def split_raw_dataset( - raw_dataset: datasets.DatasetDict, -) -> Tuple[datasets.Dataset, datasets.Dataset]: - if "validation" in raw_dataset: - train_dataset, val_dataset = raw_dataset["train"], raw_dataset["validation"] - else: - raw_dataset = raw_dataset["train"].train_test_split(test_size=0.05, seed=seed) - train_dataset, val_dataset = raw_dataset["train"], raw_dataset["test"] - return train_dataset, val_dataset - - -def process_dataset(dataset: datasets.Dataset, tokenizer: PreTrainedTokenizer) -> datasets.Dataset: - def group_texts(examples: Dict[str, list], block_size: int = 512) -> BatchEncoding: - concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} - total_length = len(concatenated_examples[list(examples.keys())[0]]) - total_length = (total_length // block_size) * block_size - result = { - k: [t[i : i + block_size] for i in range(0, total_length, block_size)] - for k, t in concatenated_examples.items() - } - result["labels"] = result["input_ids"].copy() - result = BatchEncoding(result) - return result - - def format_inputs(examples): - p = examples["段落"] - mergeLine = "" - for line in p: - mergeLine += line["内容"] + "\n" - return {"text": mergeLine} - - def tokenize_inputs( - examples: Dict[str, list], - tokenizer: PreTrainedTokenizer, - column_name: str = "text", - ) -> BatchEncoding: - logits = tokenizer(examples[column_name], return_attention_mask=False) - return logits - - dataset_column_names = list(dataset.features) - dataset = dataset.map( - partial(format_inputs), - batched=False, - num_proc=num_proc, - remove_columns=dataset_column_names, - ) - dataset_column_names = list(dataset.features) - dataset = dataset.map( - partial(tokenize_inputs, tokenizer=tokenizer), - batched=True, - num_proc=num_proc, - remove_columns=dataset_column_names, - ) - dataset = dataset.map( - partial(group_texts, block_size=tokenizer.model_max_length), - batched=True, - num_proc=num_proc, - ) - - return dataset - - if __name__ == "__main__": if tokenizer_name_or_path is None: tokenizer_name_or_path = model_name @@ -125,26 +66,11 @@ if __name__ == "__main__": lit_module = LitModule(model_dir, learning_rate, use_tril_attention_mask) tokenizer = QWenTokenizer("./wit_b64.tiktoken", "./wit_char.tiktoken") - train_dataset_list = [] - val_dataset_list = [] - for dn in dataset_name: - datanames = dn.split(".") - if datanames[-1] == "gz" and datanames[-2] == "jsonl": - raw_dataset = datasets.load_dataset("json", data_files=dn) - elif datanames[-1] == "json": - raw_dataset = datasets.load_dataset("json", data_files=dn) - else: - raw_dataset = datasets.load_dataset(dn) - train_dataset, val_dataset = split_raw_dataset(raw_dataset) - train_dataset = process_dataset(train_dataset, tokenizer) - val_dataset = process_dataset(val_dataset, tokenizer) - train_dataset_list.append(train_dataset) - val_dataset_list.append(val_dataset) - train_dataset = ConcatDataset(train_dataset_list) - val_dataset = ConcatDataset(val_dataset_list) - train_dataset = SpecialDataset(0, 1000, 65536) - val_dataset = SpecialDataset(1000, 1024, 1024) + raw_dataset = SpecialDataset() + train_idx, val_idx = random_split(list(range(len(raw_dataset))), [0.95, 0.05]) + train_dataset = Subset(raw_dataset, train_idx.indices) + val_dataset = Subset(raw_dataset, val_idx.indices) train_dataloader = DataLoader( train_dataset, @@ -160,7 +86,6 @@ if __name__ == "__main__": num_workers=num_proc, collate_fn=DefaultDataCollator(), persistent_workers=True, - shuffle=True, ) torch.set_float32_matmul_precision("medium") diff --git a/wit/modeling_wit.py b/wit/modeling_wit.py index f01a725..1122362 100644 --- a/wit/modeling_wit.py +++ b/wit/modeling_wit.py @@ -49,9 +49,8 @@ class QWenAttention(nn.Module): self.split_size = config.hidden_size self.num_heads = config.num_attention_heads self.head_dim = self.hidden_size // self.num_heads - self.projection_size = config.kv_channels * config.num_attention_heads - self.c_attn = nn.Linear(config.hidden_size, 3 * self.projection_size) - self.c_proj = nn.Linear(config.hidden_size, self.projection_size, bias=not config.no_bias) + self.c_attn = nn.Linear(config.hidden_size, 3 * self.hidden_size) + self.c_proj = nn.Linear(config.hidden_size, self.hidden_size, bias=not config.no_bias) self.attn_dropout = nn.Dropout(config.attn_dropout_prob) self.index = index @@ -96,17 +95,15 @@ class QWenModel(nn.Module): super().__init__() self.wte = nn.Embedding(config.vocab_size, config.hidden_size) self.drop = nn.Dropout(config.emb_dropout_prob) - dim = config.kv_channels + self.dim = config.hidden_size // config.num_attention_heads self.h = nn.ModuleList([QWenBlock(config, i) for i in range(config.num_hidden_layers)]) self.ln_f = RMSNorm( config.hidden_size, eps=config.layer_norm_epsilon, ) - - self.dim = dim self.base = config.rotary_emb_base - inv_freq = 1.0 / (self.base ** (torch.arange(0, dim, 2).float() / dim)) + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) self._rotary_pos_emb_cache = None self._seq_len_cached = 0 @@ -348,11 +345,14 @@ class QwenRunner: loss = None if labels is not None: labels = labels.to(lm_logits.device) + shift_labels = labels[..., 1:].contiguous().view(-1) shift_logits = lm_logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() + shift_logits = shift_logits.view(-1, shift_logits.size(-1)) + mask = shift_labels != 0 + shift_labels = shift_labels[mask] + shift_logits = shift_logits[mask] loss_fct = CrossEntropyLoss() - loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) - + loss = loss_fct(shift_logits, shift_labels) return lm_logits, loss def prepareInput(self, tokenizer, query, query_assistant, history, system):