refine code.
This commit is contained in:

parent c92db47135
commit 3feec36059
@@ -41,9 +41,11 @@ def InitDataset(config):
             os.mkdir(path)
         if os.path.exists(trainfile) and os.path.exists(valfile):
             print(f"INFO: Load dataset from {trainfile}")
+            train_dataset = torch.load(trainfile, weights_only=False)
+            train_dataset.set_mask(conf.mask_level, conf.mask_idx)
             print(f"INFO: Load dataset from {valfile}")
-            train_dataset = torch.load(trainfile)
-            val_dataset = torch.load(valfile)
+            val_dataset = torch.load(valfile, weights_only=False)
+            val_dataset.set_mask(conf.mask_level, conf.mask_idx)
             print(f"INFO: Load dataset end")
         else:
             raw_dataset = MeaningDataset(start, start + size, vocab, None, conf.level_ratio, conf.min_subitem)
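Note on the torch.load change above: the dataset files here hold whole dataset objects, and recent PyTorch releases (2.6+) switched the torch.load default to weights_only=True, which refuses to unpickle custom classes such as MeaningDataset; passing weights_only=False restores full-object loading. The added set_mask(conf.mask_level, conf.mask_idx) calls presumably re-apply the configured mask to the freshly loaded datasets. A minimal sketch of the loading pattern, with ToyDataset as a hypothetical stand-in for MeaningDataset:

    import torch
    from torch.utils.data import Dataset

    class ToyDataset(Dataset):
        # hypothetical stand-in for MeaningDataset
        def __init__(self, items):
            self.items = items

        def __len__(self):
            return len(self.items)

        def __getitem__(self, idx):
            return self.items[idx]

    torch.save(ToyDataset(list(range(8))), "toy.pt")  # pickles the full object, not just tensors
    ds = torch.load("toy.pt", weights_only=False)     # required when the file holds arbitrary Python objects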
@@ -25,7 +25,7 @@ model = QWenLMHeadModel(config)
 
 print(model)
 
-tokenizer = QWenTokenizer("./wit_b64.tiktoken", "./wit_char.tiktoken")
+tokenizer = QWenTokenizer("./model/wit_b64.tiktoken", "./model/wit_char.tiktoken")
 
 sys.path.append("..")
 from tools import show
@@ -39,7 +39,7 @@ if __name__ == "__main__":
     config.num_attention_heads = 16  # 8 8 16
 
     lit_module = LitModule(pretrain_model_name, learning_rate, config, use_tril_attention_mask)
-    tokenizer = QWenTokenizer("./wit_b64.tiktoken", "./wit_char.tiktoken")
+    tokenizer = QWenTokenizer("./model/wit_b64.tiktoken", "./model/wit_char.tiktoken")
 
     level_ratio = 2
     start = vocab_size * level_ratio * level_ratio
@@ -6,22 +6,20 @@ import torch
 import torchmetrics
 
 from model.modeling_wit import QWenLMHeadModel
-from configuration import ModelConfig
+from configuration import ModelConfig, TrainConfig
 
 
 class LitModule(pl.LightningModule):
-    def __init__(
-        self,
-        pretrained_model_dir: str = None,
-        learning_rate: float = 0.0001,
-        config: ModelConfig = None,
-        use_tril_attention_mask: str = False,
-    ):
+    def __init__(self, conf: TrainConfig = None):
+        pretrained_model_dir = conf.pretrain_model_name
+        learning_rate = conf.learning_rate
+        mconf = conf.model_config
+        use_tril_attention_mask = conf.use_tril_attention_mask
         super().__init__()
         self.save_hyperparameters()
-        if config == None:
-            config = ModelConfig()
-        model = QWenLMHeadModel(config)
+        if mconf == None:
+            mconf = ModelConfig()
+        model = QWenLMHeadModel(mconf)
         if pretrained_model_dir != None:
             from modelscope import snapshot_download
 
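Usage sketch for the LitModule refactor above: the constructor now takes a single TrainConfig instead of separate arguments. The attribute names below (pretrain_model_name, learning_rate, model_config, use_tril_attention_mask) are exactly the ones read in __init__ in this diff; the values, the import path for LitModule, and the assumption that TrainConfig() constructs with usable defaults are illustrative only:

    from configuration import ModelConfig, TrainConfig
    from lit_module import LitModule      # hypothetical import path

    conf = TrainConfig()                  # assumed to provide sensible defaults
    conf.pretrain_model_name = None       # no pretrained checkpoint
    conf.learning_rate = 1e-4
    conf.model_config = ModelConfig()
    conf.use_tril_attention_mask = True

    lit_module = LitModule(conf)          # single-config constructor introduced in this commit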
@@ -1,294 +0,0 @@
-# Copyright (c) Alibaba Cloud.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-"""Generation support."""
-
-from typing import Tuple, List, Union, Iterable
-
-import numpy as np
-import torch
-import torch.nn.functional as F
-from transformers import PreTrainedTokenizer
-from transformers import logging
-from transformers.generation import LogitsProcessor
-
-logger = logging.get_logger(__name__)
-
-# Types.
-HistoryType = List[Tuple[str, str]]
-TokensType = List[int]
-BatchTokensType = List[List[int]]
-
-
-def pad_batch(batch: BatchTokensType, pad_id: int, seq_length: int) -> BatchTokensType:
-    for tokens in batch:
-        context_length = len(tokens)
-        if context_length < seq_length:
-            tokens.extend([pad_id] * (seq_length - context_length))
-    return batch
-
-
-def get_ltor_masks_and_position_ids(
-    data,
-    eod_token,
-    reset_position_ids,
-    reset_attention_mask,
-    eod_mask_loss,
-):
-    """Build masks and position id for left to right model."""
-
-    # Extract batch size and sequence length.
-    micro_batch_size, seq_length = data.size()
-
-    # Attention mask (lower triangular).
-    if reset_attention_mask:
-        att_mask_batch = micro_batch_size
-    else:
-        att_mask_batch = 1
-    attention_mask = torch.tril(torch.ones((att_mask_batch, seq_length, seq_length), device=data.device)).view(
-        att_mask_batch, 1, seq_length, seq_length
-    )
-
-    # Loss mask.
-    loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)
-    if eod_mask_loss:
-        loss_mask[data == eod_token] = 0.0
-
-    # Position ids.
-    position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device)
-    position_ids = position_ids.unsqueeze(0).expand_as(data)
-    # We need to clone as the ids will be modifed based on batch index.
-    if reset_position_ids:
-        position_ids = position_ids.clone()
-
-    if reset_position_ids or reset_attention_mask:
-        # Loop through the batches:
-        for b in range(micro_batch_size):
-            # Find indecies where EOD token is.
-            eod_index = position_ids[b, data[b] == eod_token]
-            # Detach indecies from positions if going to modify positions.
-            if reset_position_ids:
-                eod_index = eod_index.clone()
-
-            # Loop through EOD indecies:
-            prev_index = 0
-            for j in range(eod_index.size()[0]):
-                i = eod_index[j]
-                # Mask attention loss.
-                if reset_attention_mask:
-                    attention_mask[b, 0, (i + 1) :, : (i + 1)] = 0
-                # Reset positions.
-                if reset_position_ids:
-                    position_ids[b, (i + 1) :] -= i + 1 - prev_index
-                    prev_index = i + 1
-
-    # Convert attention mask to binary:
-    attention_mask = attention_mask < 0.5
-
-    return attention_mask, loss_mask, position_ids
-
-
-def get_batch(context_tokens: torch.LongTensor, eod_id: int):
-    """Generate batch from context tokens."""
-    # Move to GPU.
-    tokens = context_tokens.contiguous().to(context_tokens.device)
-    # Get the attention mask and postition ids.
-    attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
-        tokens,
-        eod_id,
-        reset_position_ids=False,
-        reset_attention_mask=False,
-        eod_mask_loss=False,
-    )
-    return tokens, attention_mask, position_ids
-
-
-def make_context(
-    tokenizer: PreTrainedTokenizer,
-    query: str,
-    query_assistant: str = "",
-    history: List[Tuple[str, str]] = None,
-    system: str = "",
-    max_window_size: int = 6144,
-):
-    if history is None:
-        history = []
-
-    im_start, im_end = "<|im_start|>", "<|im_end|>"
-    im_start_tokens = [tokenizer.im_start_id]
-    im_end_tokens = [tokenizer.im_end_id]
-    nl_tokens = tokenizer.encode("\n")
-
-    def _tokenize_str(role, content):
-        return f"{role}\n{content}", tokenizer.encode(role, allowed_special=set()) + nl_tokens + tokenizer.encode(
-            content, allowed_special=set()
-        )
-
-    system_text, system_tokens_part = _tokenize_str("system", system)
-    system_tokens = im_start_tokens + system_tokens_part + im_end_tokens
-    assistant_tokens = tokenizer.encode(query_assistant, allowed_special=set())
-    raw_text = ""
-    context_tokens = []
-
-    for turn_query, turn_response in reversed(history):
-        query_text, query_tokens_part = _tokenize_str("user", turn_query)
-        query_tokens = im_start_tokens + query_tokens_part + im_end_tokens
-        response_text, response_tokens_part = _tokenize_str("assistant", turn_response)
-        response_tokens = im_start_tokens + response_tokens_part + im_end_tokens
-
-        next_context_tokens = nl_tokens + query_tokens + nl_tokens + response_tokens
-        prev_chat = f"\n{im_start}{query_text}{im_end}\n{im_start}{response_text}{im_end}"
-
-        current_context_size = len(system_tokens) + len(next_context_tokens) + len(context_tokens)
-        if current_context_size < max_window_size:
-            context_tokens = next_context_tokens + context_tokens
-            raw_text = prev_chat + raw_text
-        else:
-            break
-
-    context_tokens = system_tokens + context_tokens
-    raw_text = f"{im_start}{system_text}{im_end}" + raw_text
-    context_tokens += (
-        nl_tokens
-        + im_start_tokens
-        + _tokenize_str("user", query)[1]
-        + im_end_tokens
-        + nl_tokens
-        + im_start_tokens
-        + tokenizer.encode("assistant")
-        + nl_tokens
-        + assistant_tokens
-    )
-    raw_text += f"\n{im_start}user\n{query}{im_end}\n{im_start}assistant\n{query_assistant}"
-
-    return raw_text, context_tokens
-
-
-def decode_tokens(
-    tokens: Union[torch.LongTensor, TokensType],
-    tokenizer: PreTrainedTokenizer,
-    raw_text_len: int = 0,
-    context_length: int = 0,
-    errors: str = "replace",
-) -> str:
-    if torch.is_tensor(tokens):
-        tokens = tokens.cpu().numpy().tolist()
-
-    end_reason = f"Gen length {len(tokens)}"
-    eod_token_idx = context_length
-    for eod_token_idx in range(context_length, len(tokens)):
-        if tokens[eod_token_idx] in [tokenizer.im_start_id, tokenizer.im_end_id]:
-            end_reason = f"Gen {tokenizer.decode([tokens[eod_token_idx]])!r}"
-            break
-
-    decoded = tokenizer.decode(tokens, errors=errors)
-
-    decode_tokens = tokenizer.decode(tokens[:eod_token_idx], errors=errors)
-    trim_decode_tokens = decode_tokens[raw_text_len:]
-    trim_decode_tokens = trim_decode_tokens.strip()
-
-    return decoded, trim_decode_tokens, end_reason
-
-
-class StopWordsLogitsProcessor(LogitsProcessor):
-    """
-    :class:`transformers.LogitsProcessor` that enforces that when specified sequences appear, stop geration.
-
-    Args:
-        stop_words_ids (:obj:`List[List[int]]`):
-            List of list of token ids of stop ids. In order to get the tokens of the words
-            that should not appear in the generated text, use :obj:`tokenizer(bad_word,
-            add_prefix_space=True).input_ids`.
-        eos_token_id (:obj:`int`):
-            The id of the `end-of-sequence` token.
-    """
-
-    def __init__(self, stop_words_ids: Iterable[Iterable[int]], eos_token_id: int):
-        if not isinstance(stop_words_ids, List) or len(stop_words_ids) == 0:
-            raise ValueError(f"`stop_words_ids` has to be a non-emtpy list, but is {stop_words_ids}.")
-        if any(not isinstance(bad_word_ids, list) for bad_word_ids in stop_words_ids):
-            raise ValueError(f"`stop_words_ids` has to be a list of lists, but is {stop_words_ids}.")
-        if any(
-            any((not isinstance(token_id, (int, np.integer)) or token_id < 0) for token_id in stop_word_ids)
-            for stop_word_ids in stop_words_ids
-        ):
-            raise ValueError(
-                f"Each list in `stop_words_ids` has to be a list of positive integers, but is {stop_words_ids}."
-            )
-
-        self.stop_words_ids = list(filter(lambda bad_token_seq: bad_token_seq != [eos_token_id], stop_words_ids))
-        self.eos_token_id = eos_token_id
-        for stop_token_seq in self.stop_words_ids:
-            assert len(stop_token_seq) > 0, "Stop words token sequences {} cannot have an empty list".format(
-                stop_words_ids
-            )
-
-    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
-        stopped_samples = self._calc_stopped_samples(input_ids)
-        for i, should_stop in enumerate(stopped_samples):
-            if should_stop:
-                scores[i, self.eos_token_id] = float(2**15)
-        return scores
-
-    def _tokens_match(self, prev_tokens: torch.LongTensor, tokens: List[int]) -> bool:
-        if len(tokens) == 0:
-            # if bad word tokens is just one token always ban it
-            return True
-        elif len(tokens) > len(prev_tokens):
-            # if bad word tokens are longer then prev input_ids they can't be equal
-            return False
-        elif prev_tokens[-len(tokens) :].tolist() == tokens:
-            # if tokens match
-            return True
-        else:
-            return False
-
-    def _calc_stopped_samples(self, prev_input_ids: Iterable[int]) -> Iterable[int]:
-        stopped_samples = []
-        for prev_input_ids_slice in prev_input_ids:
-            match = False
-            for stop_token_seq in self.stop_words_ids:
-                if self._tokens_match(prev_input_ids_slice, stop_token_seq):
-                    # if tokens do not match continue
-                    match = True
-                    break
-            stopped_samples.append(match)
-
-        return stopped_samples
-
-
-def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")):
-    """This function has been mostly taken from huggingface conversational
-    ai code at
-        https://medium.com/huggingface/how-to-build-a-state-of-the-art-
-             conversational-ai-with-transfer-learning-2d818ac26313"""
-
-    if top_k > 0:
-        # Remove all tokens with a probability less than the
-        # last token of the top-k
-        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
-        logits[indices_to_remove] = filter_value
-
-    if top_p > 0.0:
-        # Cconvert to 1D
-        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
-        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
-
-        # Remove tokens with cumulative probability above the threshold
-        sorted_indices_to_remove = cumulative_probs > top_p
-        # Shift the indices to the right to keep also the first token
-        # above the threshold
-        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
-        sorted_indices_to_remove[..., 0] = 0
-        for i in range(sorted_indices.size(0)):
-            indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
-            logits[i][indices_to_remove] = filter_value
-
-    return logits
-
-
-def switch(val1, val2, boolean):
-    boolean = boolean.type_as(val1)
-    return (1 - boolean) * val1 + boolean * val2