diff --git a/wit/dataset/dataset.py b/wit/dataset/dataset.py
index 1c9a5ee..7431610 100644
--- a/wit/dataset/dataset.py
+++ b/wit/dataset/dataset.py
@@ -41,9 +41,11 @@ def InitDataset(config):
         os.mkdir(path)
     if os.path.exists(trainfile) and os.path.exists(valfile):
         print(f"INFO: Load dataset from {trainfile}")
+        train_dataset = torch.load(trainfile, weights_only=False)
+        train_dataset.set_mask(conf.mask_level, conf.mask_idx)
         print(f"INFO: Load dataset from {valfile}")
-        train_dataset = torch.load(trainfile)
-        val_dataset = torch.load(valfile)
+        val_dataset = torch.load(valfile, weights_only=False)
+        val_dataset.set_mask(conf.mask_level, conf.mask_idx)
         print(f"INFO: Load dataset end")
     else:
         raw_dataset = MeaningDataset(start, start + size, vocab, None, conf.level_ratio, conf.min_subitem)
diff --git a/wit/demo.py b/wit/demo.py
index d17e48f..a5cd2b9 100644
--- a/wit/demo.py
+++ b/wit/demo.py
@@ -25,7 +25,7 @@ model = QWenLMHeadModel(config)
 
 print(model)
 
-tokenizer = QWenTokenizer("./wit_b64.tiktoken", "./wit_char.tiktoken")
+tokenizer = QWenTokenizer("./model/wit_b64.tiktoken", "./model/wit_char.tiktoken")
 
 sys.path.append("..")
 from tools import show
diff --git a/wit/inference.py b/wit/inference.py
index b947771..440b280 100644
--- a/wit/inference.py
+++ b/wit/inference.py
@@ -39,7 +39,7 @@ if __name__ == "__main__":
     config.num_attention_heads = 16  # 8 8 16
 
     lit_module = LitModule(pretrain_model_name, learning_rate, config, use_tril_attention_mask)
-    tokenizer = QWenTokenizer("./wit_b64.tiktoken", "./wit_char.tiktoken")
+    tokenizer = QWenTokenizer("./model/wit_b64.tiktoken", "./model/wit_char.tiktoken")
 
     level_ratio = 2
     start = vocab_size * level_ratio * level_ratio
diff --git a/wit/model/lit_module.py b/wit/model/lit_module.py
index f654c4b..9c5a4f1 100644
--- a/wit/model/lit_module.py
+++ b/wit/model/lit_module.py
@@ -6,22 +6,20 @@ import torch
 import torchmetrics
 
 from model.modeling_wit import QWenLMHeadModel
-from configuration import ModelConfig
+from configuration import ModelConfig, TrainConfig
 
 
 class LitModule(pl.LightningModule):
-    def __init__(
-        self,
-        pretrained_model_dir: str = None,
-        learning_rate: float = 0.0001,
-        config: ModelConfig = None,
-        use_tril_attention_mask: str = False,
-    ):
+    def __init__(self, conf: TrainConfig = None):
+        pretrained_model_dir = conf.pretrain_model_name
+        learning_rate = conf.learning_rate
+        mconf = conf.model_config
+        use_tril_attention_mask = conf.use_tril_attention_mask
         super().__init__()
         self.save_hyperparameters()
-        if config == None:
-            config = ModelConfig()
-        model = QWenLMHeadModel(config)
+        if mconf == None:
+            mconf = ModelConfig()
+        model = QWenLMHeadModel(mconf)
         if pretrained_model_dir != None:
             from modelscope import snapshot_download
diff --git a/wit/model/qwen_generation_utils copy.py b/wit/model/qwen_generation_utils copy.py
deleted file mode 100644
index cc80126..0000000
--- a/wit/model/qwen_generation_utils copy.py
+++ /dev/null
@@ -1,294 +0,0 @@
-# Copyright (c) Alibaba Cloud.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-"""Generation support."""
-
-from typing import Tuple, List, Union, Iterable
-
-import numpy as np
-import torch
-import torch.nn.functional as F
-from transformers import PreTrainedTokenizer
-from transformers import logging
-from transformers.generation import LogitsProcessor
-
-logger = logging.get_logger(__name__)
-
-# Types.
-HistoryType = List[Tuple[str, str]]
-TokensType = List[int]
-BatchTokensType = List[List[int]]
-
-
-def pad_batch(batch: BatchTokensType, pad_id: int, seq_length: int) -> BatchTokensType:
-    for tokens in batch:
-        context_length = len(tokens)
-        if context_length < seq_length:
-            tokens.extend([pad_id] * (seq_length - context_length))
-    return batch
-
-
-def get_ltor_masks_and_position_ids(
-    data,
-    eod_token,
-    reset_position_ids,
-    reset_attention_mask,
-    eod_mask_loss,
-):
-    """Build masks and position id for left to right model."""
-
-    # Extract batch size and sequence length.
-    micro_batch_size, seq_length = data.size()
-
-    # Attention mask (lower triangular).
-    if reset_attention_mask:
-        att_mask_batch = micro_batch_size
-    else:
-        att_mask_batch = 1
-    attention_mask = torch.tril(torch.ones((att_mask_batch, seq_length, seq_length), device=data.device)).view(
-        att_mask_batch, 1, seq_length, seq_length
-    )
-
-    # Loss mask.
-    loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)
-    if eod_mask_loss:
-        loss_mask[data == eod_token] = 0.0
-
-    # Position ids.
-    position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device)
-    position_ids = position_ids.unsqueeze(0).expand_as(data)
-    # We need to clone as the ids will be modifed based on batch index.
-    if reset_position_ids:
-        position_ids = position_ids.clone()
-
-    if reset_position_ids or reset_attention_mask:
-        # Loop through the batches:
-        for b in range(micro_batch_size):
-            # Find indecies where EOD token is.
-            eod_index = position_ids[b, data[b] == eod_token]
-            # Detach indecies from positions if going to modify positions.
-            if reset_position_ids:
-                eod_index = eod_index.clone()
-
-            # Loop through EOD indecies:
-            prev_index = 0
-            for j in range(eod_index.size()[0]):
-                i = eod_index[j]
-                # Mask attention loss.
-                if reset_attention_mask:
-                    attention_mask[b, 0, (i + 1) :, : (i + 1)] = 0
-                # Reset positions.
-                if reset_position_ids:
-                    position_ids[b, (i + 1) :] -= i + 1 - prev_index
-                    prev_index = i + 1
-
-    # Convert attention mask to binary:
-    attention_mask = attention_mask < 0.5
-
-    return attention_mask, loss_mask, position_ids
-
-
-def get_batch(context_tokens: torch.LongTensor, eod_id: int):
-    """Generate batch from context tokens."""
-    # Move to GPU.
-    tokens = context_tokens.contiguous().to(context_tokens.device)
-    # Get the attention mask and postition ids.
-    attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
-        tokens,
-        eod_id,
-        reset_position_ids=False,
-        reset_attention_mask=False,
-        eod_mask_loss=False,
-    )
-    return tokens, attention_mask, position_ids
-
-
-def make_context(
-    tokenizer: PreTrainedTokenizer,
-    query: str,
-    query_assistant: str = "",
-    history: List[Tuple[str, str]] = None,
-    system: str = "",
-    max_window_size: int = 6144,
-):
-    if history is None:
-        history = []
-
-    im_start, im_end = "<|im_start|>", "<|im_end|>"
-    im_start_tokens = [tokenizer.im_start_id]
-    im_end_tokens = [tokenizer.im_end_id]
-    nl_tokens = tokenizer.encode("\n")
-
-    def _tokenize_str(role, content):
-        return f"{role}\n{content}", tokenizer.encode(role, allowed_special=set()) + nl_tokens + tokenizer.encode(
-            content, allowed_special=set()
-        )
-
-    system_text, system_tokens_part = _tokenize_str("system", system)
-    system_tokens = im_start_tokens + system_tokens_part + im_end_tokens
-    assistant_tokens = tokenizer.encode(query_assistant, allowed_special=set())
-    raw_text = ""
-    context_tokens = []
-
-    for turn_query, turn_response in reversed(history):
-        query_text, query_tokens_part = _tokenize_str("user", turn_query)
-        query_tokens = im_start_tokens + query_tokens_part + im_end_tokens
-        response_text, response_tokens_part = _tokenize_str("assistant", turn_response)
-        response_tokens = im_start_tokens + response_tokens_part + im_end_tokens
-
-        next_context_tokens = nl_tokens + query_tokens + nl_tokens + response_tokens
-        prev_chat = f"\n{im_start}{query_text}{im_end}\n{im_start}{response_text}{im_end}"
-
-        current_context_size = len(system_tokens) + len(next_context_tokens) + len(context_tokens)
-        if current_context_size < max_window_size:
-            context_tokens = next_context_tokens + context_tokens
-            raw_text = prev_chat + raw_text
-        else:
-            break
-
-    context_tokens = system_tokens + context_tokens
-    raw_text = f"{im_start}{system_text}{im_end}" + raw_text
-    context_tokens += (
-        nl_tokens
-        + im_start_tokens
-        + _tokenize_str("user", query)[1]
-        + im_end_tokens
-        + nl_tokens
-        + im_start_tokens
-        + tokenizer.encode("assistant")
-        + nl_tokens
-        + assistant_tokens
-    )
-    raw_text += f"\n{im_start}user\n{query}{im_end}\n{im_start}assistant\n{query_assistant}"
-
-    return raw_text, context_tokens
-
-
-def decode_tokens(
-    tokens: Union[torch.LongTensor, TokensType],
-    tokenizer: PreTrainedTokenizer,
-    raw_text_len: int = 0,
-    context_length: int = 0,
-    errors: str = "replace",
-) -> str:
-    if torch.is_tensor(tokens):
-        tokens = tokens.cpu().numpy().tolist()
-
-    end_reason = f"Gen length {len(tokens)}"
-    eod_token_idx = context_length
-    for eod_token_idx in range(context_length, len(tokens)):
-        if tokens[eod_token_idx] in [tokenizer.im_start_id, tokenizer.im_end_id]:
-            end_reason = f"Gen {tokenizer.decode([tokens[eod_token_idx]])!r}"
-            break
-
-    decoded = tokenizer.decode(tokens, errors=errors)
-
-    decode_tokens = tokenizer.decode(tokens[:eod_token_idx], errors=errors)
-    trim_decode_tokens = decode_tokens[raw_text_len:]
-    trim_decode_tokens = trim_decode_tokens.strip()
-
-    return decoded, trim_decode_tokens, end_reason
-
-
-class StopWordsLogitsProcessor(LogitsProcessor):
-    """
-    :class:`transformers.LogitsProcessor` that enforces that when specified sequences appear, stop geration.
-
-    Args:
-        stop_words_ids (:obj:`List[List[int]]`):
-            List of list of token ids of stop ids. In order to get the tokens of the words
-            that should not appear in the generated text, use :obj:`tokenizer(bad_word,
-            add_prefix_space=True).input_ids`.
-        eos_token_id (:obj:`int`):
-            The id of the `end-of-sequence` token.
-    """
-
-    def __init__(self, stop_words_ids: Iterable[Iterable[int]], eos_token_id: int):
-        if not isinstance(stop_words_ids, List) or len(stop_words_ids) == 0:
-            raise ValueError(f"`stop_words_ids` has to be a non-emtpy list, but is {stop_words_ids}.")
-        if any(not isinstance(bad_word_ids, list) for bad_word_ids in stop_words_ids):
-            raise ValueError(f"`stop_words_ids` has to be a list of lists, but is {stop_words_ids}.")
-        if any(
-            any((not isinstance(token_id, (int, np.integer)) or token_id < 0) for token_id in stop_word_ids)
-            for stop_word_ids in stop_words_ids
-        ):
-            raise ValueError(
-                f"Each list in `stop_words_ids` has to be a list of positive integers, but is {stop_words_ids}."
-            )
-
-        self.stop_words_ids = list(filter(lambda bad_token_seq: bad_token_seq != [eos_token_id], stop_words_ids))
-        self.eos_token_id = eos_token_id
-        for stop_token_seq in self.stop_words_ids:
-            assert len(stop_token_seq) > 0, "Stop words token sequences {} cannot have an empty list".format(
-                stop_words_ids
-            )
-
-    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
-        stopped_samples = self._calc_stopped_samples(input_ids)
-        for i, should_stop in enumerate(stopped_samples):
-            if should_stop:
-                scores[i, self.eos_token_id] = float(2**15)
-        return scores
-
-    def _tokens_match(self, prev_tokens: torch.LongTensor, tokens: List[int]) -> bool:
-        if len(tokens) == 0:
-            # if bad word tokens is just one token always ban it
-            return True
-        elif len(tokens) > len(prev_tokens):
-            # if bad word tokens are longer then prev input_ids they can't be equal
-            return False
-        elif prev_tokens[-len(tokens) :].tolist() == tokens:
-            # if tokens match
-            return True
-        else:
-            return False
-
-    def _calc_stopped_samples(self, prev_input_ids: Iterable[int]) -> Iterable[int]:
-        stopped_samples = []
-        for prev_input_ids_slice in prev_input_ids:
-            match = False
-            for stop_token_seq in self.stop_words_ids:
-                if self._tokens_match(prev_input_ids_slice, stop_token_seq):
-                    # if tokens do not match continue
-                    match = True
-                    break
-            stopped_samples.append(match)
-
-        return stopped_samples
-
-
-def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")):
-    """This function has been mostly taken from huggingface conversational
-    ai code at
-        https://medium.com/huggingface/how-to-build-a-state-of-the-art-
-            conversational-ai-with-transfer-learning-2d818ac26313"""
-
-    if top_k > 0:
-        # Remove all tokens with a probability less than the
-        # last token of the top-k
-        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
-        logits[indices_to_remove] = filter_value
-
-    if top_p > 0.0:
-        # Cconvert to 1D
-        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
-        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
-
-        # Remove tokens with cumulative probability above the threshold
-        sorted_indices_to_remove = cumulative_probs > top_p
-        # Shift the indices to the right to keep also the first token
-        # above the threshold
-        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
-        sorted_indices_to_remove[..., 0] = 0
-        for i in range(sorted_indices.size(0)):
-            indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
-            logits[i][indices_to_remove] = filter_value
-
-    return logits
-
-
-def switch(val1, val2, boolean):
-    boolean = boolean.type_as(val1)
-    return (1 - boolean) * val1 + boolean * val2
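
Note: with this change `LitModule.__init__` takes a single `TrainConfig` instead of four positional arguments, so call sites such as the one left unchanged in `wit/inference.py` (`LitModule(pretrain_model_name, learning_rate, config, use_tril_attention_mask)`) would presumably need updating as well. A minimal sketch of new-style construction, assuming `TrainConfig()` can be created with no arguments and exposes the attributes this diff reads (`pretrain_model_name`, `learning_rate`, `model_config`, `use_tril_attention_mask`); the exact defaults are illustrative, not taken from the repo:

```python
from configuration import ModelConfig, TrainConfig
from model.lit_module import LitModule

# Hypothetical caller update; field names follow the attributes read by the
# new LitModule.__init__ in this diff.
conf = TrainConfig()
conf.pretrain_model_name = None        # or a modelscope model id to load pretrained weights
conf.learning_rate = 0.0001
conf.model_config = ModelConfig()      # LitModule falls back to ModelConfig() if this is None
conf.use_tril_attention_mask = False

lit_module = LitModule(conf)
```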