From e635ce0df4ad02111582a23f24576d366833ecc4 Mon Sep 17 00:00:00 2001 From: Colin Date: Mon, 17 Feb 2025 19:41:40 +0800 Subject: [PATCH] Refine wit config method. --- .gitignore | 4 +- {wit => dataset}/MNBVC.py | 0 wit/__init__.py | 2 - wit/configuration.py | 38 ++- wit/dataset.py | 42 +++ wit/demo.py | 8 +- wit/inference.py | 3 +- wit/{ => model}/lit_module.py | 8 +- wit/{ => model}/model.safetensors.index.json | 0 wit/{ => model}/modeling_wit.py | 2 +- wit/model/qwen_generation_utils copy.py | 294 +++++++++++++++++++ wit/{ => model}/qwen_generation_utils.py | 0 wit/{ => model}/tokenization_qwen.py | 0 wit/{ => model}/tokenizer_config.json | 0 wit/{ => model}/wit_b64.tiktoken | 0 wit/{ => model}/wit_char.tiktoken | 0 wit/stress.py | 82 ------ wit/train.py | 82 ++---- wit/train_special.py | 79 ----- 19 files changed, 404 insertions(+), 240 deletions(-) rename {wit => dataset}/MNBVC.py (100%) create mode 100644 wit/dataset.py rename wit/{ => model}/lit_module.py (95%) rename wit/{ => model}/model.safetensors.index.json (100%) rename wit/{ => model}/modeling_wit.py (99%) create mode 100644 wit/model/qwen_generation_utils copy.py rename wit/{ => model}/qwen_generation_utils.py (100%) rename wit/{ => model}/tokenization_qwen.py (100%) rename wit/{ => model}/tokenizer_config.json (100%) rename wit/{ => model}/wit_b64.tiktoken (100%) rename wit/{ => model}/wit_char.tiktoken (100%) delete mode 100644 wit/stress.py delete mode 100644 wit/train_special.py diff --git a/.gitignore b/.gitignore index d4c498d..080d3c8 100644 --- a/.gitignore +++ b/.gitignore @@ -9,4 +9,6 @@ checkpoints build log logs -data \ No newline at end of file +data + +mlruns \ No newline at end of file diff --git a/wit/MNBVC.py b/dataset/MNBVC.py similarity index 100% rename from wit/MNBVC.py rename to dataset/MNBVC.py diff --git a/wit/__init__.py b/wit/__init__.py index 59e4a5d..e69de29 100644 --- a/wit/__init__.py +++ b/wit/__init__.py @@ -1,2 +0,0 @@ -from qwen.modeling_qwen import QWenLMHeadModel -from qwen.configuration_qwen import QWenConfig \ No newline at end of file diff --git a/wit/configuration.py b/wit/configuration.py index ba89b7e..2eb6d09 100644 --- a/wit/configuration.py +++ b/wit/configuration.py @@ -1,7 +1,3 @@ -# Copyright (c) Alibaba Cloud. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree.
class ModelConfig: @@ -40,3 +36,37 @@ class ModelConfig: self.top_p = 0.8 self.repetition_penalty = 1.1 self.model_max_length = 8192 + + +class MeaningDatasetConfig: + def __init__(self): + self.level_ratio = 5 + self.level = 5 + self.dataset_level = 3 + self.min_subitem = 2 + self.mask_level = [0, 1, 2] + self.mask_idx = [0, 0, -1] + +class DatasetConfig: + def __init__(self): + self.name = "meaning" + self.meaning = MeaningDatasetConfig() + +class TrainConfig: + def __init__(self): + self.name = "bigger" # current train process name + self.pretrain_model_name = None # "qwen/Qwen-1_8B-Chat" + self.learning_rate = 0.0001 + self.use_tril_attention_mask = None + self.precision = "16-mixed" # "precision:bf16-mixed,16-mixed,32-true" + self.train_batch_size = 4 + self.val_batch_size = 4 + self.num_proc = 8 + self.max_epochs = 1000 + self.strategy = "auto" + self.resume_from_ckpt_path = None + self.seed = 42 + self.dataloader_works = 2 + + self.model_config = ModelConfig() + self.dataset = DatasetConfig() \ No newline at end of file diff --git a/wit/dataset.py b/wit/dataset.py new file mode 100644 index 0000000..a9effde --- /dev/null +++ b/wit/dataset.py @@ -0,0 +1,42 @@ +from meaning_dataset import MeaningDataset, BatchGroupMeaningDataloader +from special_dataset import SpecialDataset +from torch.utils.data import random_split, DataLoader + + +def InitDataset(config): + + train_batch_size = config.train_batch_size + val_batch_size = config.val_batch_size + num_proc = config.num_proc + if config.dataset.name == "special": + raw_dataset = SpecialDataset() + train_dataset, val_dataset = random_split(raw_dataset, [0.95, 0.05]) + train_dataloader = DataLoader( + train_dataset, + batch_size=train_batch_size, + num_workers=num_proc, + persistent_workers=True, + shuffle=True, + ) + val_dataloader = DataLoader( + val_dataset, + batch_size=val_batch_size, + num_workers=num_proc, + persistent_workers=True, + ) + return train_dataloader, val_dataloader + + if config.dataset.name == "meaning": + conf = config.dataset.meaning + vocab = config.model_config.vocab_size + start = vocab * (conf.level_ratio**conf.level) + size = vocab * int((conf.level_ratio**conf.dataset_level)) + raw_dataset = MeaningDataset(start, start + size, vocab, None, conf.level_ratio, conf.min_subitem) + # print(raw_dataset.token_frequency()) + raw_dataset.set_mask(conf.mask_level, conf.mask_idx) + train_dataset, val_dataset = raw_dataset.split(0.9) + train_dataloader = BatchGroupMeaningDataloader(train_dataset, train_batch_size).dataloader( + config.dataloader_works + ) + val_dataloader = BatchGroupMeaningDataloader(val_dataset, val_batch_size).dataloader(config.dataloader_works) + return train_dataloader, val_dataloader diff --git a/wit/demo.py b/wit/demo.py index c389725..d17e48f 100644 --- a/wit/demo.py +++ b/wit/demo.py @@ -2,13 +2,13 @@ import torch import sys from modelscope import snapshot_download -from modeling_wit import QWenLMHeadModel -from modeling_wit import QwenRunner +from wit.model.modeling_wit import QWenLMHeadModel +from wit.model.modeling_wit import QwenRunner from wit.configuration import ModelConfig -from tokenization_qwen import QWenTokenizer +from wit.model.tokenization_qwen import QWenTokenizer -from qwen_generation_utils import ( +from wit.model.qwen_generation_utils import ( make_context, decode_tokens, ) diff --git a/wit/inference.py b/wit/inference.py index 986f284..b947771 100644 --- a/wit/inference.py +++ b/wit/inference.py @@ -9,10 +9,9 @@ import torch from torch.utils.data import ConcatDataset, 
DataLoader, Dataset, random_split, Subset from lit_module import LitModule -from tokenization_qwen import QWenTokenizer +from wit.model.tokenization_qwen import QWenTokenizer from logger import TBLogger -from special_dataset import SpecialDataset from meaning_dataset import MeaningDataset, BatchGroupMeaningDataloader from wit.configuration import ModelConfig diff --git a/wit/lit_module.py b/wit/model/lit_module.py similarity index 95% rename from wit/lit_module.py rename to wit/model/lit_module.py index 8bfeab0..f654c4b 100644 --- a/wit/lit_module.py +++ b/wit/model/lit_module.py @@ -5,10 +5,8 @@ import pytorch_lightning as pl import torch import torchmetrics -from modeling_wit import QWenLMHeadModel -from wit.configuration import ModelConfig - -from transformers import AutoConfig +from model.modeling_wit import QWenLMHeadModel +from configuration import ModelConfig class LitModule(pl.LightningModule): @@ -63,7 +61,7 @@ class LitModule(pl.LightningModule): logits = logits.contiguous().view(-1, logits.size(-1)) labels = batch["labels"][..., 1:] labels = labels.contiguous().view(-1) - if batch["mask"] != None: + if "mask" in batch and batch["mask"] != None: label_mask = batch["mask"][..., 1:] label_mask = label_mask.contiguous().view(-1) logits = logits[label_mask] diff --git a/wit/model.safetensors.index.json b/wit/model/model.safetensors.index.json similarity index 100% rename from wit/model.safetensors.index.json rename to wit/model/model.safetensors.index.json diff --git a/wit/modeling_wit.py b/wit/model/modeling_wit.py similarity index 99% rename from wit/modeling_wit.py rename to wit/model/modeling_wit.py index 2a315dc..6281710 100644 --- a/wit/modeling_wit.py +++ b/wit/model/modeling_wit.py @@ -16,7 +16,7 @@ from torch import nn from safetensors.torch import load_file as safe_load_file from safetensors.torch import save_file as safe_save_file -from qwen_generation_utils import ( +from model.qwen_generation_utils import ( make_context, decode_tokens, ) diff --git a/wit/model/qwen_generation_utils copy.py b/wit/model/qwen_generation_utils copy.py new file mode 100644 index 0000000..cc80126 --- /dev/null +++ b/wit/model/qwen_generation_utils copy.py @@ -0,0 +1,294 @@ +# Copyright (c) Alibaba Cloud. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""Generation support.""" + +from typing import Tuple, List, Union, Iterable + +import numpy as np +import torch +import torch.nn.functional as F +from transformers import PreTrainedTokenizer +from transformers import logging +from transformers.generation import LogitsProcessor + +logger = logging.get_logger(__name__) + +# Types. +HistoryType = List[Tuple[str, str]] +TokensType = List[int] +BatchTokensType = List[List[int]] + + +def pad_batch(batch: BatchTokensType, pad_id: int, seq_length: int) -> BatchTokensType: + for tokens in batch: + context_length = len(tokens) + if context_length < seq_length: + tokens.extend([pad_id] * (seq_length - context_length)) + return batch + + +def get_ltor_masks_and_position_ids( + data, + eod_token, + reset_position_ids, + reset_attention_mask, + eod_mask_loss, +): + """Build masks and position id for left to right model.""" + + # Extract batch size and sequence length. + micro_batch_size, seq_length = data.size() + + # Attention mask (lower triangular). 
+ if reset_attention_mask: + att_mask_batch = micro_batch_size + else: + att_mask_batch = 1 + attention_mask = torch.tril(torch.ones((att_mask_batch, seq_length, seq_length), device=data.device)).view( + att_mask_batch, 1, seq_length, seq_length + ) + + # Loss mask. + loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device) + if eod_mask_loss: + loss_mask[data == eod_token] = 0.0 + + # Position ids. + position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device) + position_ids = position_ids.unsqueeze(0).expand_as(data) + # We need to clone as the ids will be modifed based on batch index. + if reset_position_ids: + position_ids = position_ids.clone() + + if reset_position_ids or reset_attention_mask: + # Loop through the batches: + for b in range(micro_batch_size): + # Find indecies where EOD token is. + eod_index = position_ids[b, data[b] == eod_token] + # Detach indecies from positions if going to modify positions. + if reset_position_ids: + eod_index = eod_index.clone() + + # Loop through EOD indecies: + prev_index = 0 + for j in range(eod_index.size()[0]): + i = eod_index[j] + # Mask attention loss. + if reset_attention_mask: + attention_mask[b, 0, (i + 1) :, : (i + 1)] = 0 + # Reset positions. + if reset_position_ids: + position_ids[b, (i + 1) :] -= i + 1 - prev_index + prev_index = i + 1 + + # Convert attention mask to binary: + attention_mask = attention_mask < 0.5 + + return attention_mask, loss_mask, position_ids + + +def get_batch(context_tokens: torch.LongTensor, eod_id: int): + """Generate batch from context tokens.""" + # Move to GPU. + tokens = context_tokens.contiguous().to(context_tokens.device) + # Get the attention mask and postition ids. + attention_mask, _, position_ids = get_ltor_masks_and_position_ids( + tokens, + eod_id, + reset_position_ids=False, + reset_attention_mask=False, + eod_mask_loss=False, + ) + return tokens, attention_mask, position_ids + + +def make_context( + tokenizer: PreTrainedTokenizer, + query: str, + query_assistant: str = "", + history: List[Tuple[str, str]] = None, + system: str = "", + max_window_size: int = 6144, +): + if history is None: + history = [] + + im_start, im_end = "<|im_start|>", "<|im_end|>" + im_start_tokens = [tokenizer.im_start_id] + im_end_tokens = [tokenizer.im_end_id] + nl_tokens = tokenizer.encode("\n") + + def _tokenize_str(role, content): + return f"{role}\n{content}", tokenizer.encode(role, allowed_special=set()) + nl_tokens + tokenizer.encode( + content, allowed_special=set() + ) + + system_text, system_tokens_part = _tokenize_str("system", system) + system_tokens = im_start_tokens + system_tokens_part + im_end_tokens + assistant_tokens = tokenizer.encode(query_assistant, allowed_special=set()) + raw_text = "" + context_tokens = [] + + for turn_query, turn_response in reversed(history): + query_text, query_tokens_part = _tokenize_str("user", turn_query) + query_tokens = im_start_tokens + query_tokens_part + im_end_tokens + response_text, response_tokens_part = _tokenize_str("assistant", turn_response) + response_tokens = im_start_tokens + response_tokens_part + im_end_tokens + + next_context_tokens = nl_tokens + query_tokens + nl_tokens + response_tokens + prev_chat = f"\n{im_start}{query_text}{im_end}\n{im_start}{response_text}{im_end}" + + current_context_size = len(system_tokens) + len(next_context_tokens) + len(context_tokens) + if current_context_size < max_window_size: + context_tokens = next_context_tokens + context_tokens + raw_text = prev_chat + raw_text + else: + break + + 
context_tokens = system_tokens + context_tokens + raw_text = f"{im_start}{system_text}{im_end}" + raw_text + context_tokens += ( + nl_tokens + + im_start_tokens + + _tokenize_str("user", query)[1] + + im_end_tokens + + nl_tokens + + im_start_tokens + + tokenizer.encode("assistant") + + nl_tokens + + assistant_tokens + ) + raw_text += f"\n{im_start}user\n{query}{im_end}\n{im_start}assistant\n{query_assistant}" + + return raw_text, context_tokens + + +def decode_tokens( + tokens: Union[torch.LongTensor, TokensType], + tokenizer: PreTrainedTokenizer, + raw_text_len: int = 0, + context_length: int = 0, + errors: str = "replace", +) -> str: + if torch.is_tensor(tokens): + tokens = tokens.cpu().numpy().tolist() + + end_reason = f"Gen length {len(tokens)}" + eod_token_idx = context_length + for eod_token_idx in range(context_length, len(tokens)): + if tokens[eod_token_idx] in [tokenizer.im_start_id, tokenizer.im_end_id]: + end_reason = f"Gen {tokenizer.decode([tokens[eod_token_idx]])!r}" + break + + decoded = tokenizer.decode(tokens, errors=errors) + + decode_tokens = tokenizer.decode(tokens[:eod_token_idx], errors=errors) + trim_decode_tokens = decode_tokens[raw_text_len:] + trim_decode_tokens = trim_decode_tokens.strip() + + return decoded, trim_decode_tokens, end_reason + + +class StopWordsLogitsProcessor(LogitsProcessor): + """ + :class:`transformers.LogitsProcessor` that enforces that when specified sequences appear, stop geration. + + Args: + stop_words_ids (:obj:`List[List[int]]`): + List of list of token ids of stop ids. In order to get the tokens of the words + that should not appear in the generated text, use :obj:`tokenizer(bad_word, + add_prefix_space=True).input_ids`. + eos_token_id (:obj:`int`): + The id of the `end-of-sequence` token. + """ + + def __init__(self, stop_words_ids: Iterable[Iterable[int]], eos_token_id: int): + if not isinstance(stop_words_ids, List) or len(stop_words_ids) == 0: + raise ValueError(f"`stop_words_ids` has to be a non-emtpy list, but is {stop_words_ids}.") + if any(not isinstance(bad_word_ids, list) for bad_word_ids in stop_words_ids): + raise ValueError(f"`stop_words_ids` has to be a list of lists, but is {stop_words_ids}.") + if any( + any((not isinstance(token_id, (int, np.integer)) or token_id < 0) for token_id in stop_word_ids) + for stop_word_ids in stop_words_ids + ): + raise ValueError( + f"Each list in `stop_words_ids` has to be a list of positive integers, but is {stop_words_ids}." 
+ ) + + self.stop_words_ids = list(filter(lambda bad_token_seq: bad_token_seq != [eos_token_id], stop_words_ids)) + self.eos_token_id = eos_token_id + for stop_token_seq in self.stop_words_ids: + assert len(stop_token_seq) > 0, "Stop words token sequences {} cannot have an empty list".format( + stop_words_ids + ) + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + stopped_samples = self._calc_stopped_samples(input_ids) + for i, should_stop in enumerate(stopped_samples): + if should_stop: + scores[i, self.eos_token_id] = float(2**15) + return scores + + def _tokens_match(self, prev_tokens: torch.LongTensor, tokens: List[int]) -> bool: + if len(tokens) == 0: + # if bad word tokens is just one token always ban it + return True + elif len(tokens) > len(prev_tokens): + # if bad word tokens are longer then prev input_ids they can't be equal + return False + elif prev_tokens[-len(tokens) :].tolist() == tokens: + # if tokens match + return True + else: + return False + + def _calc_stopped_samples(self, prev_input_ids: Iterable[int]) -> Iterable[int]: + stopped_samples = [] + for prev_input_ids_slice in prev_input_ids: + match = False + for stop_token_seq in self.stop_words_ids: + if self._tokens_match(prev_input_ids_slice, stop_token_seq): + # if tokens do not match continue + match = True + break + stopped_samples.append(match) + + return stopped_samples + + +def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")): + """This function has been mostly taken from huggingface conversational + ai code at + https://medium.com/huggingface/how-to-build-a-state-of-the-art- + conversational-ai-with-transfer-learning-2d818ac26313""" + + if top_k > 0: + # Remove all tokens with a probability less than the + # last token of the top-k + indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] + logits[indices_to_remove] = filter_value + + if top_p > 0.0: + # Cconvert to 1D + sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1) + cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) + + # Remove tokens with cumulative probability above the threshold + sorted_indices_to_remove = cumulative_probs > top_p + # Shift the indices to the right to keep also the first token + # above the threshold + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() + sorted_indices_to_remove[..., 0] = 0 + for i in range(sorted_indices.size(0)): + indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]] + logits[i][indices_to_remove] = filter_value + + return logits + + +def switch(val1, val2, boolean): + boolean = boolean.type_as(val1) + return (1 - boolean) * val1 + boolean * val2 diff --git a/wit/qwen_generation_utils.py b/wit/model/qwen_generation_utils.py similarity index 100% rename from wit/qwen_generation_utils.py rename to wit/model/qwen_generation_utils.py diff --git a/wit/tokenization_qwen.py b/wit/model/tokenization_qwen.py similarity index 100% rename from wit/tokenization_qwen.py rename to wit/model/tokenization_qwen.py diff --git a/wit/tokenizer_config.json b/wit/model/tokenizer_config.json similarity index 100% rename from wit/tokenizer_config.json rename to wit/model/tokenizer_config.json diff --git a/wit/wit_b64.tiktoken b/wit/model/wit_b64.tiktoken similarity index 100% rename from wit/wit_b64.tiktoken rename to wit/model/wit_b64.tiktoken diff --git a/wit/wit_char.tiktoken b/wit/model/wit_char.tiktoken similarity index 100% rename from 
wit/wit_char.tiktoken rename to wit/model/wit_char.tiktoken diff --git a/wit/stress.py b/wit/stress.py deleted file mode 100644 index c41f5fe..0000000 --- a/wit/stress.py +++ /dev/null @@ -1,82 +0,0 @@ -import pytorch_lightning as pl -import torch -from torch.utils.data import DataLoader, Dataset, random_split - -from lit_module import LitModule -from logger import TBLogger - -from wit.configuration import ModelConfig - -pretrain_model_name = None # "qwen/Qwen-1_8B-Chat" -learning_rate = 0.0001 -use_tril_attention_mask = None -precision = "32-true" # "precision:bf16-mixed,16-mixed,32-true" -train_batch_size = 4 -val_batch_size = 8 -num_proc = 8 -max_epochs = 1000 -strategy = "auto" -resume_from_ckpt_path = None -seed = 42 - - -class StressDataset(Dataset): - def __init__(self, start=1, end=128, size=32768): # 1048576 32768 - self.size = size - self.features = [] - self.data = torch.randint(start, end, [size, 2048]).long() - - def __len__(self): - return self.size - - def __getitem__(self, idx): - output = {} - data = self.data[idx] - output["input_ids"] = data - output["labels"] = data.clone() - output["token_type_ids"] = torch.zeros(data.shape) - return output - - -if __name__ == "__main__": - torch.manual_seed(seed) - - config = ModelConfig() - config.vocab_size = 4096 - config.hidden_size = 1024 # 128 1024 2048 32 - config.num_hidden_layers = 6 # 6 12 24 3 - config.num_attention_heads = 8 # 8 8 16 - - lit_module = LitModule(pretrain_model_name, learning_rate, config, use_tril_attention_mask) - - raw_dataset = StressDataset() - train_dataset, val_dataset = random_split(raw_dataset, [0.95, 0.05]) - - train_dataloader = DataLoader( - train_dataset, - batch_size=train_batch_size, - num_workers=num_proc, - persistent_workers=True, - shuffle=True, - ) - val_dataloader = DataLoader( - val_dataset, - batch_size=val_batch_size, - num_workers=num_proc, - persistent_workers=True, - ) - - lit_trainer = pl.Trainer( - accelerator="gpu", - devices=2, - precision=precision, - logger=TBLogger("./", default_hp_metric=False), - strategy=strategy, - max_epochs=max_epochs, - ) - lit_trainer.fit( - lit_module, - train_dataloaders=train_dataloader, - val_dataloaders=val_dataloader, - ckpt_path=resume_from_ckpt_path, - ) diff --git a/wit/train.py b/wit/train.py index 3e9716a..88a157d 100644 --- a/wit/train.py +++ b/wit/train.py @@ -1,83 +1,45 @@ import pytorch_lightning as pl import torch -from lit_module import LitModule -from tokenization_qwen import QWenTokenizer -from logger import TBLogger +from model.lit_module import LitModule +from wit.model.tokenization_qwen import QWenTokenizer +from logger import MLFLogger -from meaning_dataset import MeaningDataset, BatchGroupMeaningDataloader -from wit.configuration import ModelConfig - -pretrain_model_name = None # "qwen/Qwen-1_8B-Chat" -learning_rate = 0.0001 -use_tril_attention_mask = None -precision = "32-true" # "precision:bf16-mixed,16-mixed,32-true" -train_batch_size = 1 -val_batch_size = 1 -num_proc = 8 -max_epochs = 1000 -strategy = "auto" -resume_from_ckpt_path = None -seed = 42 -dataloader_works = 2 - -vocab_size = 256 -level_ratio = 5 -level = 5 -dataset_level = 3 -min_subitem = 2 - -hidden_size = 128 # 128 1024 2048 32 -num_attention_heads = 16 # 8 8 16 -num_hidden_layers = 6 # 6 12 24 3 - -mask_level = [0, 1, 2] -mask_idx = [0, 0, -1] - -# name = "vocab_ratio_level_data_hidden_head_layer" -# name = "mask_level_idx" -name = "bigger" - -ver = f"{vocab_size}" + "_" + f"{level_ratio}" + "_" + f"{level}" + "_" + f"{min_subitem}" + "_" + f"{dataset_level}" 
-ver = ver + "_" + f"{hidden_size}" + "_" + f"{num_attention_heads}" + "_" + f"{num_hidden_layers}" -ver = ver + "_" + f"{mask_level}" + "_" + f"{mask_idx}" +import configuration +import dataset as ds if __name__ == "__main__": - torch.manual_seed(seed) - config = ModelConfig() - config.vocab_size = vocab_size - config.hidden_size = hidden_size - config.num_hidden_layers = num_hidden_layers - config.num_attention_heads = num_attention_heads + train_config = configuration.TrainConfig() + config = train_config.model_config - lit_module = LitModule(pretrain_model_name, learning_rate, config, use_tril_attention_mask) - tokenizer = QWenTokenizer("./wit_b64.tiktoken", "./wit_char.tiktoken") + torch.manual_seed(train_config.seed) - start = vocab_size * (level_ratio**level) - size = vocab_size * int((level_ratio**dataset_level)) + config.vocab_size = 256 + config.hidden_size = 128 # 128 1024 2048 32 + config.num_hidden_layers = 6 # 6 12 24 3 + config.num_attention_heads = 16 # 8 8 16 - raw_dataset = MeaningDataset(start, start + size, vocab_size, None, level_ratio, min_subitem) - # print(raw_dataset.token_frequency()) - raw_dataset.set_mask(mask_level, mask_idx) - train_dataset, val_dataset = raw_dataset.split(0.9) - train_dataloader = BatchGroupMeaningDataloader(train_dataset, train_batch_size).dataloader(dataloader_works) - val_dataloader = BatchGroupMeaningDataloader(val_dataset, val_batch_size).dataloader(dataloader_works) + lit_module = LitModule( + train_config.pretrain_model_name, train_config.learning_rate, config, train_config.use_tril_attention_mask + ) + tokenizer = QWenTokenizer("./model/wit_b64.tiktoken", "./model/wit_char.tiktoken") + train_dataloader, val_dataloader = ds.InitDataset(train_config) # for i in range(len(train_dataloader)): # print(train_dataloader.print_mapping(i)) torch.set_float32_matmul_precision("medium") lit_trainer = pl.Trainer( accelerator="cuda", - precision=precision, - logger=TBLogger("./log/", name=name, version=ver, default_hp_metric=False), - strategy=strategy, - max_epochs=max_epochs, + precision=train_config.precision, + logger=MLFLogger("./log/", run_name=train_config.name), + strategy=train_config.strategy, + max_epochs=train_config.max_epochs, ) lit_trainer.fit( lit_module, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader, - ckpt_path=resume_from_ckpt_path, + ckpt_path=train_config.resume_from_ckpt_path, ) diff --git a/wit/train_special.py b/wit/train_special.py deleted file mode 100644 index 65d3775..0000000 --- a/wit/train_special.py +++ /dev/null @@ -1,79 +0,0 @@ -import argparse -from functools import partial -from itertools import chain -from typing import Dict, Tuple - -import datasets -import pytorch_lightning as pl -import torch -from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split, Subset - -from lit_module import LitModule -from tokenization_qwen import QWenTokenizer -from logger import TBLogger - -from special_dataset import SpecialDataset -from meaning_dataset import MeaningDataset -from wit.configuration import ModelConfig - -pretrain_model_name = None # "qwen/Qwen-1_8B-Chat" -learning_rate = 0.0001 -use_tril_attention_mask = None -precision = "32-true" # "precision:bf16-mixed,16-mixed,32-true" -train_batch_size = 128 -val_batch_size = 128 -num_proc = 8 -max_epochs = 1000 -strategy = "auto" -resume_from_ckpt_path = None -seed = 42 -vocab_size = 256 - - -if __name__ == "__main__": - torch.manual_seed(seed) - - config = ModelConfig() - config.vocab_size = vocab_size - config.hidden_size = 128 # 128 
1024 2048 32 - config.num_hidden_layers = 3 # 6 12 24 3 - config.num_attention_heads = 8 # 8 8 16 - - lit_module = LitModule(pretrain_model_name, learning_rate, config, use_tril_attention_mask) - tokenizer = QWenTokenizer("./wit_b64.tiktoken", "./wit_char.tiktoken") - - raw_dataset = SpecialDataset() - train_dataset, val_dataset = random_split(raw_dataset, [0.95, 0.05]) - it = iter(train_dataset) - print("data samples:") - for i in range(10): - print(next(it)["input_ids"].numpy().tolist()) - - train_dataloader = DataLoader( - train_dataset, - batch_size=train_batch_size, - num_workers=num_proc, - persistent_workers=True, - shuffle=True, - ) - val_dataloader = DataLoader( - val_dataset, - batch_size=val_batch_size, - num_workers=num_proc, - persistent_workers=True, - ) - - torch.set_float32_matmul_precision("medium") - lit_trainer = pl.Trainer( - accelerator="gpu", - precision=precision, - logger=TBLogger("./", default_hp_metric=False), - strategy=strategy, - max_epochs=max_epochs, - ) - lit_trainer.fit( - lit_module, - train_dataloaders=train_dataloader, - val_dataloaders=val_dataloader, - ckpt_path=resume_from_ckpt_path, - )
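
Usage sketch (not part of the patch, and not authoritative): the snippet below only restates how the config-driven entry point introduced here fits together, mirroring the new wit/train.py and the TrainConfig / DatasetConfig / MeaningDatasetConfig classes added to wit/configuration.py. It assumes wit/ is the current working directory so that configuration, dataset, and model.lit_module resolve the same way they do in train.py; the field values shown are illustrative and taken from defaults visible in the patch.

# Sketch of driving training entirely through TrainConfig, as the new wit/train.py does.
import pytorch_lightning as pl
import torch

import configuration
import dataset as ds
from model.lit_module import LitModule

train_config = configuration.TrainConfig()
train_config.name = "bigger"                    # run name, used by the logger in train.py
train_config.dataset.name = "meaning"           # "meaning" or "special" selects the branch in dataset.InitDataset
train_config.dataset.meaning.mask_level = [0, 1, 2]
train_config.dataset.meaning.mask_idx = [0, 0, -1]

model_config = train_config.model_config        # ModelConfig instance owned by TrainConfig
model_config.vocab_size = 256
model_config.hidden_size = 128
model_config.num_hidden_layers = 6
model_config.num_attention_heads = 16

torch.manual_seed(train_config.seed)
lit_module = LitModule(
    train_config.pretrain_model_name,
    train_config.learning_rate,
    model_config,
    train_config.use_tril_attention_mask,
)
train_dataloader, val_dataloader = ds.InitDataset(train_config)

torch.set_float32_matmul_precision("medium")
trainer = pl.Trainer(
    accelerator="cuda",
    precision=train_config.precision,
    strategy=train_config.strategy,
    max_epochs=train_config.max_epochs,
    # train.py additionally passes logger=MLFLogger("./log/", run_name=train_config.name)
)
trainer.fit(
    lit_module,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
    ckpt_path=train_config.resume_from_ckpt_path,
)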