Refine wit config method.
This commit is contained in:
parent cdee69bf54
commit e635ce0df4
				|  | @ -9,4 +9,6 @@ checkpoints | |||
| build | ||||
| log | ||||
| logs | ||||
| data | ||||
| data | ||||
| 
 | ||||
| mlruns | ||||
|  | @ -1,2 +0,0 @@ | |||
| from qwen.modeling_qwen import QWenLMHeadModel | ||||
| from qwen.configuration_qwen import QWenConfig | ||||
|  | @ -1,7 +1,3 @@ | |||
| # Copyright (c) Alibaba Cloud. | ||||
| # | ||||
| # This source code is licensed under the license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
| 
 | ||||
| 
 | ||||
| class ModelConfig: | ||||
|  | @ -40,3 +36,37 @@ class ModelConfig: | |||
|         self.top_p = 0.8 | ||||
|         self.repetition_penalty = 1.1 | ||||
|         self.model_max_length = 8192 | ||||
| 
 | ||||
| 
 | ||||
| class MeaningDatasetConfig: | ||||
|     def __init__(self): | ||||
|         self.level_ratio = 5 | ||||
|         self.level = 5 | ||||
|         self.dataset_level = 3 | ||||
|         self.min_subitem = 2 | ||||
|         self.mask_level = [0, 1, 2] | ||||
|         self.mask_idx = [0, 0, -1] | ||||
| 
 | ||||
| class DatasetConfig: | ||||
|     def __init__(self): | ||||
|         self.name = "meaning" | ||||
|         self.meaning = MeaningDatasetConfig() | ||||
| 
 | ||||
| class TrainConfig: | ||||
|     def __init__(self): | ||||
|         self.name = "bigger" # current train process name  | ||||
|         self.pretrain_model_name = None  # "qwen/Qwen-1_8B-Chat" | ||||
|         self.learning_rate = 0.0001 | ||||
|         self.use_tril_attention_mask = None | ||||
|         self.precision = "16-mixed"  # "precision:bf16-mixed,16-mixed,32-true" | ||||
|         self.train_batch_size = 4 | ||||
|         self.val_batch_size = 4 | ||||
|         self.num_proc = 8 | ||||
|         self.max_epochs = 1000 | ||||
|         self.strategy = "auto" | ||||
|         self.resume_from_ckpt_path = None | ||||
|         self.seed = 42 | ||||
|         self.dataloader_works = 2 | ||||
| 
 | ||||
|         self.model_config = ModelConfig() | ||||
|         self.dataset = DatasetConfig() | ||||
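A minimal, hedged sketch of how these nested config objects are meant to be used (field names are taken from the classes above; the `configuration` module name matches the import in wit/train.py further down):

from configuration import TrainConfig

train_config = TrainConfig()
train_config.train_batch_size = 8                  # top-level training knob
train_config.model_config.vocab_size = 256         # nested ModelConfig field
train_config.dataset.meaning.mask_level = [0, 1]   # nested MeaningDatasetConfig field

print(train_config.dataset.name)                   # "meaning"
print(train_config.model_config.model_max_length)  # 8192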
|  | @ -0,0 +1,42 @@ | |||
| from meaning_dataset import MeaningDataset, BatchGroupMeaningDataloader | ||||
| from special_dataset import SpecialDataset | ||||
| from torch.utils.data import random_split, DataLoader | ||||
| 
 | ||||
| 
 | ||||
| def InitDataset(config): | ||||
| 
 | ||||
|     train_batch_size = config.train_batch_size | ||||
|     val_batch_size = config.val_batch_size | ||||
|     num_proc = config.num_proc | ||||
|     if config.dataset.name == "special": | ||||
|         raw_dataset = SpecialDataset() | ||||
|         train_dataset, val_dataset = random_split(raw_dataset, [0.95, 0.05]) | ||||
|         train_dataloader = DataLoader( | ||||
|             train_dataset, | ||||
|             batch_size=train_batch_size, | ||||
|             num_workers=num_proc, | ||||
|             persistent_workers=True, | ||||
|             shuffle=True, | ||||
|         ) | ||||
|         val_dataloader = DataLoader( | ||||
|             val_dataset, | ||||
|             batch_size=val_batch_size, | ||||
|             num_workers=num_proc, | ||||
|             persistent_workers=True, | ||||
|         ) | ||||
|         return train_dataloader, val_dataloader | ||||
| 
 | ||||
|     if config.dataset.name == "meaning": | ||||
|         conf = config.dataset.meaning | ||||
|         vocab = config.model_config.vocab_size | ||||
|         start = vocab * (conf.level_ratio**conf.level) | ||||
|         size = vocab * int((conf.level_ratio**conf.dataset_level)) | ||||
|         raw_dataset = MeaningDataset(start, start + size, vocab, None, conf.level_ratio, conf.min_subitem) | ||||
|         # print(raw_dataset.token_frequency()) | ||||
|         raw_dataset.set_mask(conf.mask_level, conf.mask_idx) | ||||
|         train_dataset, val_dataset = raw_dataset.split(0.9) | ||||
|         train_dataloader = BatchGroupMeaningDataloader(train_dataset, train_batch_size).dataloader( | ||||
|             config.dataloader_works | ||||
|         ) | ||||
|         val_dataloader = BatchGroupMeaningDataloader(val_dataset, val_batch_size).dataloader(config.dataloader_works) | ||||
|         return train_dataloader, val_dataloader | ||||
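A short usage sketch of InitDataset, mirroring how wit/train.py below calls it (the "special" branch uses a 95/5 random split, the "meaning" branch MeaningDataset.split(0.9)):

import configuration
import dataset as ds

train_config = configuration.TrainConfig()
train_config.dataset.name = "meaning"   # or "special" for the SpecialDataset path

train_dataloader, val_dataloader = ds.InitDataset(train_config)
print(len(train_dataloader), len(val_dataloader))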
|  | @ -2,13 +2,13 @@ import torch | |||
| import sys | ||||
| from modelscope import snapshot_download | ||||
| 
 | ||||
| from modeling_wit import QWenLMHeadModel | ||||
| from modeling_wit import QwenRunner | ||||
| from wit.model.modeling_wit import QWenLMHeadModel | ||||
| from wit.model.modeling_wit import QwenRunner | ||||
| from wit.configuration import ModelConfig | ||||
| from tokenization_qwen import QWenTokenizer | ||||
| from wit.model.tokenization_qwen import QWenTokenizer | ||||
| 
 | ||||
| 
 | ||||
| from qwen_generation_utils import ( | ||||
| from wit.model.qwen_generation_utils import ( | ||||
|     make_context, | ||||
|     decode_tokens, | ||||
| ) | ||||
|  |  | |||
|  | @ -9,10 +9,9 @@ import torch | |||
| from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split, Subset | ||||
| 
 | ||||
| from lit_module import LitModule | ||||
| from tokenization_qwen import QWenTokenizer | ||||
| from wit.model.tokenization_qwen import QWenTokenizer | ||||
| from logger import TBLogger | ||||
| 
 | ||||
| from special_dataset import SpecialDataset | ||||
| from meaning_dataset import MeaningDataset, BatchGroupMeaningDataloader | ||||
| from wit.configuration import ModelConfig | ||||
| 
 | ||||
|  |  | |||
|  | @ -5,10 +5,8 @@ import pytorch_lightning as pl | |||
| import torch | ||||
| import torchmetrics | ||||
| 
 | ||||
| from modeling_wit import QWenLMHeadModel | ||||
| from wit.configuration import ModelConfig | ||||
| 
 | ||||
| from transformers import AutoConfig | ||||
| from model.modeling_wit import QWenLMHeadModel | ||||
| from configuration import ModelConfig | ||||
| 
 | ||||
| 
 | ||||
| class LitModule(pl.LightningModule): | ||||
|  | @ -63,7 +61,7 @@ class LitModule(pl.LightningModule): | |||
|         logits = logits.contiguous().view(-1, logits.size(-1)) | ||||
|         labels = batch["labels"][..., 1:] | ||||
|         labels = labels.contiguous().view(-1) | ||||
|         if batch["mask"] != None: | ||||
|         if "mask" in batch and batch["mask"] != None: | ||||
|             label_mask = batch["mask"][..., 1:] | ||||
|             label_mask = label_mask.contiguous().view(-1) | ||||
|             logits = logits[label_mask] | ||||
|  | @ -16,7 +16,7 @@ from torch import nn | |||
| from safetensors.torch import load_file as safe_load_file | ||||
| from safetensors.torch import save_file as safe_save_file | ||||
| 
 | ||||
| from qwen_generation_utils import ( | ||||
| from model.qwen_generation_utils import ( | ||||
|     make_context, | ||||
|     decode_tokens, | ||||
| ) | ||||
|  | @ -0,0 +1,294 @@ | |||
| # Copyright (c) Alibaba Cloud. | ||||
| # | ||||
| # This source code is licensed under the license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
| 
 | ||||
| """Generation support.""" | ||||
| 
 | ||||
| from typing import Tuple, List, Union, Iterable | ||||
| 
 | ||||
| import numpy as np | ||||
| import torch | ||||
| import torch.nn.functional as F | ||||
| from transformers import PreTrainedTokenizer | ||||
| from transformers import logging | ||||
| from transformers.generation import LogitsProcessor | ||||
| 
 | ||||
| logger = logging.get_logger(__name__) | ||||
| 
 | ||||
| # Types. | ||||
| HistoryType = List[Tuple[str, str]] | ||||
| TokensType = List[int] | ||||
| BatchTokensType = List[List[int]] | ||||
| 
 | ||||
| 
 | ||||
| def pad_batch(batch: BatchTokensType, pad_id: int, seq_length: int) -> BatchTokensType: | ||||
|     for tokens in batch: | ||||
|         context_length = len(tokens) | ||||
|         if context_length < seq_length: | ||||
|             tokens.extend([pad_id] * (seq_length - context_length)) | ||||
|     return batch | ||||
| 
 | ||||
| 
 | ||||
| def get_ltor_masks_and_position_ids( | ||||
|     data, | ||||
|     eod_token, | ||||
|     reset_position_ids, | ||||
|     reset_attention_mask, | ||||
|     eod_mask_loss, | ||||
| ): | ||||
|     """Build masks and position id for left to right model.""" | ||||
| 
 | ||||
|     # Extract batch size and sequence length. | ||||
|     micro_batch_size, seq_length = data.size() | ||||
| 
 | ||||
|     # Attention mask (lower triangular). | ||||
|     if reset_attention_mask: | ||||
|         att_mask_batch = micro_batch_size | ||||
|     else: | ||||
|         att_mask_batch = 1 | ||||
|     attention_mask = torch.tril(torch.ones((att_mask_batch, seq_length, seq_length), device=data.device)).view( | ||||
|         att_mask_batch, 1, seq_length, seq_length | ||||
|     ) | ||||
| 
 | ||||
|     # Loss mask. | ||||
|     loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device) | ||||
|     if eod_mask_loss: | ||||
|         loss_mask[data == eod_token] = 0.0 | ||||
| 
 | ||||
|     # Position ids. | ||||
|     position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device) | ||||
|     position_ids = position_ids.unsqueeze(0).expand_as(data) | ||||
|     # We need to clone as the ids will be modified based on batch index. | ||||
|     if reset_position_ids: | ||||
|         position_ids = position_ids.clone() | ||||
| 
 | ||||
|     if reset_position_ids or reset_attention_mask: | ||||
|         # Loop through the batches: | ||||
|         for b in range(micro_batch_size): | ||||
|             # Find indices where the EOD token is. | ||||
|             eod_index = position_ids[b, data[b] == eod_token] | ||||
|             # Detach indices from positions if we are going to modify positions. | ||||
|             if reset_position_ids: | ||||
|                 eod_index = eod_index.clone() | ||||
| 
 | ||||
|             # Loop through EOD indices: | ||||
|             prev_index = 0 | ||||
|             for j in range(eod_index.size()[0]): | ||||
|                 i = eod_index[j] | ||||
|                 # Mask attention loss. | ||||
|                 if reset_attention_mask: | ||||
|                     attention_mask[b, 0, (i + 1) :, : (i + 1)] = 0 | ||||
|                 # Reset positions. | ||||
|                 if reset_position_ids: | ||||
|                     position_ids[b, (i + 1) :] -= i + 1 - prev_index | ||||
|                     prev_index = i + 1 | ||||
| 
 | ||||
|     # Convert attention mask to binary: | ||||
|     attention_mask = attention_mask < 0.5 | ||||
| 
 | ||||
|     return attention_mask, loss_mask, position_ids | ||||
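# Illustrative only (not part of this file): shapes and meaning of the returned
# values for a toy batch, assuming token id 0 is the EOD token.
#   data = torch.tensor([[5, 6, 7, 0]])   # (batch=1, seq=4)
#   attn, loss, pos = get_ltor_masks_and_position_ids(
#       data, eod_token=0, reset_position_ids=False,
#       reset_attention_mask=False, eod_mask_loss=True,
#   )
#   attn : (1, 1, 4, 4) bool, True strictly above the diagonal (positions that must NOT be attended)
#   loss : (1, 4) float, 1.0 everywhere except 0.0 at the EOD position
#   pos  : (1, 4) long -> [[0, 1, 2, 3]]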
| 
 | ||||
| 
 | ||||
| def get_batch(context_tokens: torch.LongTensor, eod_id: int): | ||||
|     """Generate batch from context tokens.""" | ||||
|     # Move to GPU. | ||||
|     tokens = context_tokens.contiguous().to(context_tokens.device) | ||||
|     # Get the attention mask and position ids. | ||||
|     attention_mask, _, position_ids = get_ltor_masks_and_position_ids( | ||||
|         tokens, | ||||
|         eod_id, | ||||
|         reset_position_ids=False, | ||||
|         reset_attention_mask=False, | ||||
|         eod_mask_loss=False, | ||||
|     ) | ||||
|     return tokens, attention_mask, position_ids | ||||
| 
 | ||||
| 
 | ||||
| def make_context( | ||||
|     tokenizer: PreTrainedTokenizer, | ||||
|     query: str, | ||||
|     query_assistant: str = "", | ||||
|     history: List[Tuple[str, str]] = None, | ||||
|     system: str = "", | ||||
|     max_window_size: int = 6144, | ||||
| ): | ||||
|     if history is None: | ||||
|         history = [] | ||||
| 
 | ||||
|     im_start, im_end = "<|im_start|>", "<|im_end|>" | ||||
|     im_start_tokens = [tokenizer.im_start_id] | ||||
|     im_end_tokens = [tokenizer.im_end_id] | ||||
|     nl_tokens = tokenizer.encode("\n") | ||||
| 
 | ||||
|     def _tokenize_str(role, content): | ||||
|         return f"{role}\n{content}", tokenizer.encode(role, allowed_special=set()) + nl_tokens + tokenizer.encode( | ||||
|             content, allowed_special=set() | ||||
|         ) | ||||
| 
 | ||||
|     system_text, system_tokens_part = _tokenize_str("system", system) | ||||
|     system_tokens = im_start_tokens + system_tokens_part + im_end_tokens | ||||
|     assistant_tokens = tokenizer.encode(query_assistant, allowed_special=set()) | ||||
|     raw_text = "" | ||||
|     context_tokens = [] | ||||
| 
 | ||||
|     for turn_query, turn_response in reversed(history): | ||||
|         query_text, query_tokens_part = _tokenize_str("user", turn_query) | ||||
|         query_tokens = im_start_tokens + query_tokens_part + im_end_tokens | ||||
|         response_text, response_tokens_part = _tokenize_str("assistant", turn_response) | ||||
|         response_tokens = im_start_tokens + response_tokens_part + im_end_tokens | ||||
| 
 | ||||
|         next_context_tokens = nl_tokens + query_tokens + nl_tokens + response_tokens | ||||
|         prev_chat = f"\n{im_start}{query_text}{im_end}\n{im_start}{response_text}{im_end}" | ||||
| 
 | ||||
|         current_context_size = len(system_tokens) + len(next_context_tokens) + len(context_tokens) | ||||
|         if current_context_size < max_window_size: | ||||
|             context_tokens = next_context_tokens + context_tokens | ||||
|             raw_text = prev_chat + raw_text | ||||
|         else: | ||||
|             break | ||||
| 
 | ||||
|     context_tokens = system_tokens + context_tokens | ||||
|     raw_text = f"{im_start}{system_text}{im_end}" + raw_text | ||||
|     context_tokens += ( | ||||
|         nl_tokens | ||||
|         + im_start_tokens | ||||
|         + _tokenize_str("user", query)[1] | ||||
|         + im_end_tokens | ||||
|         + nl_tokens | ||||
|         + im_start_tokens | ||||
|         + tokenizer.encode("assistant") | ||||
|         + nl_tokens | ||||
|         + assistant_tokens | ||||
|     ) | ||||
|     raw_text += f"\n{im_start}user\n{query}{im_end}\n{im_start}assistant\n{query_assistant}" | ||||
| 
 | ||||
|     return raw_text, context_tokens | ||||
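# Illustrative only (not part of this file): what make_context produces for a
# single-turn prompt; `tokenizer` is assumed to be a QWenTokenizer providing
# im_start_id / im_end_id.
#
#   raw_text, context_tokens = make_context(tokenizer, "Hi", system="You are a helpful assistant.")
#
# raw_text is the ChatML string
#   <|im_start|>system
#   You are a helpful assistant.<|im_end|>
#   <|im_start|>user
#   Hi<|im_end|>
#   <|im_start|>assistant
# and context_tokens is the matching token-id list (system block, user block,
# then the "assistant\n" prefix plus any query_assistant tokens).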
| 
 | ||||
| 
 | ||||
| def decode_tokens( | ||||
|     tokens: Union[torch.LongTensor, TokensType], | ||||
|     tokenizer: PreTrainedTokenizer, | ||||
|     raw_text_len: int = 0, | ||||
|     context_length: int = 0, | ||||
|     errors: str = "replace", | ||||
| ) -> str: | ||||
|     if torch.is_tensor(tokens): | ||||
|         tokens = tokens.cpu().numpy().tolist() | ||||
| 
 | ||||
|     end_reason = f"Gen length {len(tokens)}" | ||||
|     eod_token_idx = context_length | ||||
|     for eod_token_idx in range(context_length, len(tokens)): | ||||
|         if tokens[eod_token_idx] in [tokenizer.im_start_id, tokenizer.im_end_id]: | ||||
|             end_reason = f"Gen {tokenizer.decode([tokens[eod_token_idx]])!r}" | ||||
|             break | ||||
| 
 | ||||
|     decoded = tokenizer.decode(tokens, errors=errors) | ||||
| 
 | ||||
|     decode_tokens = tokenizer.decode(tokens[:eod_token_idx], errors=errors) | ||||
|     trim_decode_tokens = decode_tokens[raw_text_len:] | ||||
|     trim_decode_tokens = trim_decode_tokens.strip() | ||||
| 
 | ||||
|     return decoded, trim_decode_tokens, end_reason | ||||
| 
 | ||||
| 
 | ||||
| class StopWordsLogitsProcessor(LogitsProcessor): | ||||
|     """ | ||||
|     :class:`transformers.LogitsProcessor` that enforces that generation stops when the specified stop sequences appear. | ||||
| 
 | ||||
|     Args: | ||||
|         stop_words_ids (:obj:`List[List[int]]`): | ||||
|             List of lists of token ids for the stop sequences. To get the token ids | ||||
|             of the words that should stop generation, use :obj:`tokenizer(stop_word, | ||||
|             add_prefix_space=True).input_ids`. | ||||
|         eos_token_id (:obj:`int`): | ||||
|             The id of the `end-of-sequence` token. | ||||
|     """ | ||||
| 
 | ||||
|     def __init__(self, stop_words_ids: Iterable[Iterable[int]], eos_token_id: int): | ||||
|         if not isinstance(stop_words_ids, List) or len(stop_words_ids) == 0: | ||||
|             raise ValueError(f"`stop_words_ids` has to be a non-empty list, but is {stop_words_ids}.") | ||||
|         if any(not isinstance(bad_word_ids, list) for bad_word_ids in stop_words_ids): | ||||
|             raise ValueError(f"`stop_words_ids` has to be a list of lists, but is {stop_words_ids}.") | ||||
|         if any( | ||||
|             any((not isinstance(token_id, (int, np.integer)) or token_id < 0) for token_id in stop_word_ids) | ||||
|             for stop_word_ids in stop_words_ids | ||||
|         ): | ||||
|             raise ValueError( | ||||
|                 f"Each list in `stop_words_ids` has to be a list of positive integers, but is {stop_words_ids}." | ||||
|             ) | ||||
| 
 | ||||
|         self.stop_words_ids = list(filter(lambda bad_token_seq: bad_token_seq != [eos_token_id], stop_words_ids)) | ||||
|         self.eos_token_id = eos_token_id | ||||
|         for stop_token_seq in self.stop_words_ids: | ||||
|             assert len(stop_token_seq) > 0, "Stop words token sequences {} cannot have an empty list".format( | ||||
|                 stop_words_ids | ||||
|             ) | ||||
| 
 | ||||
|     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: | ||||
|         stopped_samples = self._calc_stopped_samples(input_ids) | ||||
|         for i, should_stop in enumerate(stopped_samples): | ||||
|             if should_stop: | ||||
|                 scores[i, self.eos_token_id] = float(2**15) | ||||
|         return scores | ||||
| 
 | ||||
|     def _tokens_match(self, prev_tokens: torch.LongTensor, tokens: List[int]) -> bool: | ||||
|         if len(tokens) == 0: | ||||
|             # an empty stop sequence always counts as a match | ||||
|             return True | ||||
|         elif len(tokens) > len(prev_tokens): | ||||
|             # if the stop sequence is longer than prev_tokens, it cannot match | ||||
|             return False | ||||
|         elif prev_tokens[-len(tokens) :].tolist() == tokens: | ||||
|             # if tokens match | ||||
|             return True | ||||
|         else: | ||||
|             return False | ||||
| 
 | ||||
|     def _calc_stopped_samples(self, prev_input_ids: Iterable[int]) -> Iterable[int]: | ||||
|         stopped_samples = [] | ||||
|         for prev_input_ids_slice in prev_input_ids: | ||||
|             match = False | ||||
|             for stop_token_seq in self.stop_words_ids: | ||||
|                 if self._tokens_match(prev_input_ids_slice, stop_token_seq): | ||||
|                     # a stop sequence matched; mark this sample as stopped | ||||
|                     match = True | ||||
|                     break | ||||
|             stopped_samples.append(match) | ||||
| 
 | ||||
|         return stopped_samples | ||||
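# Usage sketch (illustrative, not part of this module): boost the EOS logit once a
# stop sequence has appeared. `tokenizer`, `input_ids` and `scores` are assumed to
# come from the surrounding decoding loop; im_start_id / im_end_id are the ChatML ids.
from transformers import LogitsProcessorList

stop_words_ids = [[tokenizer.im_start_id]]  # stop as soon as a new ChatML turn begins
processors = LogitsProcessorList(
    [StopWordsLogitsProcessor(stop_words_ids, eos_token_id=tokenizer.im_end_id)]
)
scores = processors(input_ids, scores)      # apply once per decoding step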
| 
 | ||||
| 
 | ||||
| def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")): | ||||
|     """This function has been mostly taken from huggingface conversational | ||||
|     ai code at | ||||
|         https://medium.com/huggingface/how-to-build-a-state-of-the-art- | ||||
|              conversational-ai-with-transfer-learning-2d818ac26313""" | ||||
| 
 | ||||
|     if top_k > 0: | ||||
|         # Remove all tokens with a probability less than the | ||||
|         # last token of the top-k | ||||
|         indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] | ||||
|         logits[indices_to_remove] = filter_value | ||||
| 
 | ||||
|     if top_p > 0.0: | ||||
|         # Sort the logits in descending order | ||||
|         sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1) | ||||
|         cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) | ||||
| 
 | ||||
|         # Remove tokens with cumulative probability above the threshold | ||||
|         sorted_indices_to_remove = cumulative_probs > top_p | ||||
|         # Shift the indices to the right to keep also the first token | ||||
|         # above the threshold | ||||
|         sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() | ||||
|         sorted_indices_to_remove[..., 0] = 0 | ||||
|         for i in range(sorted_indices.size(0)): | ||||
|             indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]] | ||||
|             logits[i][indices_to_remove] = filter_value | ||||
| 
 | ||||
|     return logits | ||||
| 
 | ||||
| 
 | ||||
| def switch(val1, val2, boolean): | ||||
|     boolean = boolean.type_as(val1) | ||||
|     return (1 - boolean) * val1 + boolean * val2 | ||||
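A small worked example of top_k_logits on a toy batch (values are illustrative; the function filters the tensor in place and returns it, so the .clone() calls keep the two cases independent; the import path is assumed to follow the other modules in this commit):

import torch
import torch.nn.functional as F

from model.qwen_generation_utils import top_k_logits

logits = torch.tensor([[2.0, 1.0, 0.5, 0.2, -1.0]])  # one sample, five candidate tokens

# Top-k: keep the two largest logits, set the rest to -inf.
top2 = top_k_logits(logits.clone(), top_k=2)
# top2 -> [[2.0, 1.0, -inf, -inf, -inf]]

# Top-p (nucleus): keep the smallest prefix of sorted tokens whose cumulative
# probability reaches 0.8; the highest-probability token is always kept.
nucleus = top_k_logits(logits.clone(), top_p=0.8)
probs = F.softmax(nucleus, dim=-1)  # renormalized over the surviving tokens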
|  | @ -1,82 +0,0 @@ | |||
| import pytorch_lightning as pl | ||||
| import torch | ||||
| from torch.utils.data import DataLoader, Dataset, random_split | ||||
| 
 | ||||
| from lit_module import LitModule | ||||
| from logger import TBLogger | ||||
| 
 | ||||
| from wit.configuration import ModelConfig | ||||
| 
 | ||||
| pretrain_model_name = None  # "qwen/Qwen-1_8B-Chat" | ||||
| learning_rate = 0.0001 | ||||
| use_tril_attention_mask = None | ||||
| precision = "32-true"  # "precision:bf16-mixed,16-mixed,32-true" | ||||
| train_batch_size = 4 | ||||
| val_batch_size = 8 | ||||
| num_proc = 8 | ||||
| max_epochs = 1000 | ||||
| strategy = "auto" | ||||
| resume_from_ckpt_path = None | ||||
| seed = 42 | ||||
| 
 | ||||
| 
 | ||||
| class StressDataset(Dataset): | ||||
|     def __init__(self, start=1, end=128, size=32768):  # 1048576 32768 | ||||
|         self.size = size | ||||
|         self.features = [] | ||||
|         self.data = torch.randint(start, end, [size, 2048]).long() | ||||
| 
 | ||||
|     def __len__(self): | ||||
|         return self.size | ||||
| 
 | ||||
|     def __getitem__(self, idx): | ||||
|         output = {} | ||||
|         data = self.data[idx] | ||||
|         output["input_ids"] = data | ||||
|         output["labels"] = data.clone() | ||||
|         output["token_type_ids"] = torch.zeros(data.shape) | ||||
|         return output | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == "__main__": | ||||
|     torch.manual_seed(seed) | ||||
| 
 | ||||
|     config = ModelConfig() | ||||
|     config.vocab_size = 4096 | ||||
|     config.hidden_size = 1024  # 128 1024 2048  32 | ||||
|     config.num_hidden_layers = 6  # 6 12 24  3 | ||||
|     config.num_attention_heads = 8  # 8 8 16 | ||||
| 
 | ||||
|     lit_module = LitModule(pretrain_model_name, learning_rate, config, use_tril_attention_mask) | ||||
| 
 | ||||
|     raw_dataset = StressDataset() | ||||
|     train_dataset, val_dataset = random_split(raw_dataset, [0.95, 0.05]) | ||||
| 
 | ||||
|     train_dataloader = DataLoader( | ||||
|         train_dataset, | ||||
|         batch_size=train_batch_size, | ||||
|         num_workers=num_proc, | ||||
|         persistent_workers=True, | ||||
|         shuffle=True, | ||||
|     ) | ||||
|     val_dataloader = DataLoader( | ||||
|         val_dataset, | ||||
|         batch_size=val_batch_size, | ||||
|         num_workers=num_proc, | ||||
|         persistent_workers=True, | ||||
|     ) | ||||
| 
 | ||||
|     lit_trainer = pl.Trainer( | ||||
|         accelerator="gpu", | ||||
|         devices=2, | ||||
|         precision=precision, | ||||
|         logger=TBLogger("./", default_hp_metric=False), | ||||
|         strategy=strategy, | ||||
|         max_epochs=max_epochs, | ||||
|     ) | ||||
|     lit_trainer.fit( | ||||
|         lit_module, | ||||
|         train_dataloaders=train_dataloader, | ||||
|         val_dataloaders=val_dataloader, | ||||
|         ckpt_path=resume_from_ckpt_path, | ||||
|     ) | ||||
							
								
								
									
wit/train.py
							|  | @ -1,83 +1,45 @@ | |||
| import pytorch_lightning as pl | ||||
| import torch | ||||
| 
 | ||||
| from lit_module import LitModule | ||||
| from tokenization_qwen import QWenTokenizer | ||||
| from logger import TBLogger | ||||
| from model.lit_module import LitModule | ||||
| from wit.model.tokenization_qwen import QWenTokenizer | ||||
| from logger import MLFLogger | ||||
| 
 | ||||
| from meaning_dataset import MeaningDataset, BatchGroupMeaningDataloader | ||||
| from wit.configuration import ModelConfig | ||||
| 
 | ||||
| pretrain_model_name = None  # "qwen/Qwen-1_8B-Chat" | ||||
| learning_rate = 0.0001 | ||||
| use_tril_attention_mask = None | ||||
| precision = "32-true"  # "precision:bf16-mixed,16-mixed,32-true" | ||||
| train_batch_size = 1 | ||||
| val_batch_size = 1 | ||||
| num_proc = 8 | ||||
| max_epochs = 1000 | ||||
| strategy = "auto" | ||||
| resume_from_ckpt_path = None | ||||
| seed = 42 | ||||
| dataloader_works = 2 | ||||
| 
 | ||||
| vocab_size = 256 | ||||
| level_ratio = 5 | ||||
| level = 5 | ||||
| dataset_level = 3 | ||||
| min_subitem = 2 | ||||
| 
 | ||||
| hidden_size = 128  # 128 1024 2048  32 | ||||
| num_attention_heads = 16  # 8 8 16 | ||||
| num_hidden_layers = 6  # 6 12 24  3 | ||||
| 
 | ||||
| mask_level = [0, 1, 2] | ||||
| mask_idx = [0, 0, -1] | ||||
| 
 | ||||
| # name = "vocab_ratio_level_data_hidden_head_layer" | ||||
| # name = "mask_level_idx" | ||||
| name = "bigger" | ||||
| 
 | ||||
| ver = f"{vocab_size}" + "_" + f"{level_ratio}" + "_" + f"{level}" + "_" + f"{min_subitem}" + "_" + f"{dataset_level}" | ||||
| ver = ver + "_" + f"{hidden_size}" + "_" + f"{num_attention_heads}" + "_" + f"{num_hidden_layers}" | ||||
| ver = ver + "_" + f"{mask_level}" + "_" + f"{mask_idx}" | ||||
| import configuration | ||||
| import dataset as ds | ||||
| 
 | ||||
| if __name__ == "__main__": | ||||
|     torch.manual_seed(seed) | ||||
| 
 | ||||
|     config = ModelConfig() | ||||
|     config.vocab_size = vocab_size | ||||
|     config.hidden_size = hidden_size | ||||
|     config.num_hidden_layers = num_hidden_layers | ||||
|     config.num_attention_heads = num_attention_heads | ||||
|     train_config = configuration.TrainConfig() | ||||
|     config = train_config.model_config | ||||
| 
 | ||||
|     lit_module = LitModule(pretrain_model_name, learning_rate, config, use_tril_attention_mask) | ||||
|     tokenizer = QWenTokenizer("./wit_b64.tiktoken", "./wit_char.tiktoken") | ||||
|     torch.manual_seed(train_config.seed) | ||||
| 
 | ||||
|     start = vocab_size * (level_ratio**level) | ||||
|     size = vocab_size * int((level_ratio**dataset_level)) | ||||
|     config.vocab_size = 256 | ||||
|     config.hidden_size = 128  # 128 1024 2048  32 | ||||
|     config.num_hidden_layers = 6  # 6 12 24  3 | ||||
|     config.num_attention_heads = 16  # 8 8 16 | ||||
| 
 | ||||
|     raw_dataset = MeaningDataset(start, start + size, vocab_size, None, level_ratio, min_subitem) | ||||
|     # print(raw_dataset.token_frequency()) | ||||
|     raw_dataset.set_mask(mask_level, mask_idx) | ||||
|     train_dataset, val_dataset = raw_dataset.split(0.9) | ||||
|     train_dataloader = BatchGroupMeaningDataloader(train_dataset, train_batch_size).dataloader(dataloader_works) | ||||
|     val_dataloader = BatchGroupMeaningDataloader(val_dataset, val_batch_size).dataloader(dataloader_works) | ||||
|     lit_module = LitModule( | ||||
|         train_config.pretrain_model_name, train_config.learning_rate, config, train_config.use_tril_attention_mask | ||||
|     ) | ||||
|     tokenizer = QWenTokenizer("./model/wit_b64.tiktoken", "./model/wit_char.tiktoken") | ||||
| 
 | ||||
|     train_dataloader, val_dataloader = ds.InitDataset(train_config) | ||||
|     # for i in range(len(train_dataloader)): | ||||
|     #     print(train_dataloader.print_mapping(i)) | ||||
| 
 | ||||
|     torch.set_float32_matmul_precision("medium") | ||||
|     lit_trainer = pl.Trainer( | ||||
|         accelerator="cuda", | ||||
|         precision=precision, | ||||
|         logger=TBLogger("./log/", name=name, version=ver, default_hp_metric=False), | ||||
|         strategy=strategy, | ||||
|         max_epochs=max_epochs, | ||||
|         precision=train_config.precision, | ||||
|         logger=MLFLogger("./log/", run_name=train_config.name), | ||||
|         strategy=train_config.strategy, | ||||
|         max_epochs=train_config.max_epochs, | ||||
|     ) | ||||
|     lit_trainer.fit( | ||||
|         lit_module, | ||||
|         train_dataloaders=train_dataloader, | ||||
|         val_dataloaders=val_dataloader, | ||||
|         ckpt_path=resume_from_ckpt_path, | ||||
|         ckpt_path=train_config.resume_from_ckpt_path, | ||||
|     ) | ||||
|  |  | |||
|  | @ -1,79 +0,0 @@ | |||
| import argparse | ||||
| from functools import partial | ||||
| from itertools import chain | ||||
| from typing import Dict, Tuple | ||||
| 
 | ||||
| import datasets | ||||
| import pytorch_lightning as pl | ||||
| import torch | ||||
| from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split, Subset | ||||
| 
 | ||||
| from lit_module import LitModule | ||||
| from tokenization_qwen import QWenTokenizer | ||||
| from logger import TBLogger | ||||
| 
 | ||||
| from special_dataset import SpecialDataset | ||||
| from meaning_dataset import MeaningDataset | ||||
| from wit.configuration import ModelConfig | ||||
| 
 | ||||
| pretrain_model_name = None  # "qwen/Qwen-1_8B-Chat" | ||||
| learning_rate = 0.0001 | ||||
| use_tril_attention_mask = None | ||||
| precision = "32-true"  # "precision:bf16-mixed,16-mixed,32-true" | ||||
| train_batch_size = 128 | ||||
| val_batch_size = 128 | ||||
| num_proc = 8 | ||||
| max_epochs = 1000 | ||||
| strategy = "auto" | ||||
| resume_from_ckpt_path = None | ||||
| seed = 42 | ||||
| vocab_size = 256 | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == "__main__": | ||||
|     torch.manual_seed(seed) | ||||
| 
 | ||||
|     config = ModelConfig() | ||||
|     config.vocab_size = vocab_size | ||||
|     config.hidden_size = 128  # 128 1024 2048  32 | ||||
|     config.num_hidden_layers = 3  # 6 12 24  3 | ||||
|     config.num_attention_heads = 8  # 8 8 16 | ||||
| 
 | ||||
|     lit_module = LitModule(pretrain_model_name, learning_rate, config, use_tril_attention_mask) | ||||
|     tokenizer = QWenTokenizer("./wit_b64.tiktoken", "./wit_char.tiktoken") | ||||
| 
 | ||||
|     raw_dataset = SpecialDataset() | ||||
|     train_dataset, val_dataset = random_split(raw_dataset, [0.95, 0.05]) | ||||
|     it = iter(train_dataset) | ||||
|     print("data samples:") | ||||
|     for i in range(10): | ||||
|         print(next(it)["input_ids"].numpy().tolist()) | ||||
| 
 | ||||
|     train_dataloader = DataLoader( | ||||
|         train_dataset, | ||||
|         batch_size=train_batch_size, | ||||
|         num_workers=num_proc, | ||||
|         persistent_workers=True, | ||||
|         shuffle=True, | ||||
|     ) | ||||
|     val_dataloader = DataLoader( | ||||
|         val_dataset, | ||||
|         batch_size=val_batch_size, | ||||
|         num_workers=num_proc, | ||||
|         persistent_workers=True, | ||||
|     ) | ||||
| 
 | ||||
|     torch.set_float32_matmul_precision("medium") | ||||
|     lit_trainer = pl.Trainer( | ||||
|         accelerator="gpu", | ||||
|         precision=precision, | ||||
|         logger=TBLogger("./", default_hp_metric=False), | ||||
|         strategy=strategy, | ||||
|         max_epochs=max_epochs, | ||||
|     ) | ||||
|     lit_trainer.fit( | ||||
|         lit_module, | ||||
|         train_dataloaders=train_dataloader, | ||||
|         val_dataloaders=val_dataloader, | ||||
|         ckpt_path=resume_from_ckpt_path, | ||||
|     ) | ||||