# Copyright (c) Alibaba Cloud.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""Generation support."""

from typing import Optional, Tuple, List, Union, Iterable

import numpy as np
import torch
import torch.nn.functional as F
from transformers import PreTrainedTokenizer
from transformers import logging
from transformers.generation import LogitsProcessor

logger = logging.get_logger(__name__)

# Types.
HistoryType = List[Tuple[str, str]]
TokensType = List[int]
BatchTokensType = List[List[int]]


def make_context(
    tokenizer: PreTrainedTokenizer,
    query: str,
    query_assistant: str = "",
    history: Optional[HistoryType] = None,
    system: str = "",
    max_window_size: int = 6144,
):
    r"""Build a ChatML prompt from a system message, chat history, and query.

    The prompt has the form::

        <|im_start|>system\n{system}<|im_end|>\n
        <|im_start|>user\n{turn_query}<|im_end|>\n
        <|im_start|>assistant\n{turn_response}<|im_end|>\n
        ...
        <|im_start|>user\n{query}<|im_end|>\n
        <|im_start|>assistant\n{query_assistant}

    History turns are added from most recent to oldest and dropped once the
    prompt would exceed ``max_window_size`` tokens. Returns the prompt both
    as text (``raw_text``) and as token ids (``context_tokens``).
    """
    if history is None:
        history = []

    im_start, im_end = "<|im_start|>", "<|im_end|>"
    im_start_tokens = [tokenizer.im_start_id]
    im_end_tokens = [tokenizer.im_end_id]
    nl_tokens = tokenizer.encode("\n")

    def _tokenize_str(role, content):
        # Returns the "{role}\n{content}" text together with its token ids.
        return f"{role}\n{content}", tokenizer.encode(
            role, allowed_special=set()
        ) + nl_tokens + tokenizer.encode(content, allowed_special=set())

    system_text, system_tokens_part = _tokenize_str("system", system)
    system_tokens = im_start_tokens + system_tokens_part + im_end_tokens
    assistant_tokens = tokenizer.encode(query_assistant, allowed_special=set())

    raw_text = ""
    context_tokens = []

    # Walk the history backwards (most recent turn first) and keep prepending
    # turns until the token budget is exhausted, so the newest context wins.
    for turn_query, turn_response in reversed(history):
        query_text, query_tokens_part = _tokenize_str("user", turn_query)
        query_tokens = im_start_tokens + query_tokens_part + im_end_tokens
        response_text, response_tokens_part = _tokenize_str("assistant", turn_response)
        response_tokens = im_start_tokens + response_tokens_part + im_end_tokens

        next_context_tokens = nl_tokens + query_tokens + nl_tokens + response_tokens
        prev_chat = (
            f"\n{im_start}{query_text}{im_end}\n{im_start}{response_text}{im_end}"
        )

        current_context_size = (
            len(system_tokens) + len(next_context_tokens) + len(context_tokens)
        )
        if current_context_size < max_window_size:
            context_tokens = next_context_tokens + context_tokens
            raw_text = prev_chat + raw_text
        else:
            break

    context_tokens = system_tokens + context_tokens
    raw_text = f"{im_start}{system_text}{im_end}" + raw_text

    # Append the current user query and open the assistant turn, optionally
    # pre-filled with `query_assistant`.
    context_tokens += (
        nl_tokens
        + im_start_tokens
        + _tokenize_str("user", query)[1]
        + im_end_tokens
        + nl_tokens
        + im_start_tokens
        + tokenizer.encode("assistant")
        + nl_tokens
        + assistant_tokens
    )
    raw_text += f"\n{im_start}user\n{query}{im_end}\n{im_start}assistant\n{query_assistant}"

    return raw_text, context_tokens


def decode_tokens(
    tokens: Union[torch.LongTensor, TokensType],
    tokenizer: PreTrainedTokenizer,
    raw_text_len: int = 0,
    context_length: int = 0,
    errors: str = "replace",
) -> Tuple[str, str, str]:
    """Decode generated ids into (full text, trimmed response, end reason).

    The response is truncated at the first ``<|im_start|>``/``<|im_end|>``
    token found after ``context_length``, and the first ``raw_text_len``
    characters (the prompt) are stripped from it.
    """
    if torch.is_tensor(tokens):
        tokens = tokens.cpu().numpy().tolist()

    # Find the first ChatML control token emitted after the prompt; if none
    # is found, generation stopped because it hit the length limit.
    end_reason = f"Gen length {len(tokens)}"
    eod_token_idx = context_length
    for eod_token_idx in range(context_length, len(tokens)):
        if tokens[eod_token_idx] in [tokenizer.im_start_id, tokenizer.im_end_id]:
            end_reason = f"Gen {tokenizer.decode([tokens[eod_token_idx]])!r}"
            break

    decoded = tokenizer.decode(tokens, errors=errors)
    decoded_until_stop = tokenizer.decode(tokens[:eod_token_idx], errors=errors)
    trim_decode_tokens = decoded_until_stop[raw_text_len:].strip()
    return decoded, trim_decode_tokens, end_reason
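

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal end-to-end example
# of make_context + generate + decode_tokens. It assumes a Qwen-style chat
# checkpoint whose tokenizer exposes `im_start_id` and `im_end_id`; the
# checkpoint name "Qwen/Qwen-7B-Chat" and the generation settings below are
# illustrative placeholders, not requirements of this module.

if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(
        "Qwen/Qwen-7B-Chat", trust_remote_code=True
    )
    model = AutoModelForCausalLM.from_pretrained(
        "Qwen/Qwen-7B-Chat", device_map="auto", trust_remote_code=True
    ).eval()

    # Build the ChatML prompt and its token ids.
    raw_text, context_tokens = make_context(
        tokenizer,
        "What is the capital of France?",
        system="You are a helpful assistant.",
    )

    input_ids = torch.tensor([context_tokens], device=model.device)
    outputs = model.generate(input_ids, max_new_tokens=128)

    # Strip the prompt and truncate at the first ChatML control token.
    _, response, end_reason = decode_tokens(
        outputs[0],
        tokenizer,
        raw_text_len=len(raw_text),
        context_length=len(context_tokens),
    )
    print(response, f"({end_reason})")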