Witllm/wit/qwen_generation_utils.py

110 lines
3.6 KiB
Python

# Copyright (c) Alibaba Cloud.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Generation support."""
from typing import Tuple, List, Union, Iterable
import numpy as np
import torch
import torch.nn.functional as F
from transformers import PreTrainedTokenizer
from transformers import logging
from transformers.generation import LogitsProcessor
# Module-level logger obtained from the transformers logging helper, keyed by this module's name.
logger = logging.get_logger(__name__)
# Types.
# Conversation history: a list of (str, str) pairs. make_context unpacks each
# pair as (user query, assistant response).
HistoryType = List[Tuple[str, str]]
# A single sequence of token ids.
TokensType = List[int]
# A list of token-id sequences.
BatchTokensType = List[List[int]]
def make_context(
    tokenizer: PreTrainedTokenizer,
    query: str,
    query_assistant: str = "",
    history: List[Tuple[str, str]] = None,
    system: str = "",
    max_window_size: int = 6144,
):
    """Assemble a ChatML-style prompt as both text and token ids.

    The prompt is laid out as: system message, as many of the most recent
    history turns as fit, the current user ``query``, and finally an opened
    assistant turn pre-filled with ``query_assistant``.

    Args:
        tokenizer: Object exposing ``im_start_id``, ``im_end_id`` and
            ``encode(text, allowed_special=...)``.
        query: Current user message.
        query_assistant: Text placed at the start of the assistant turn.
        history: Earlier ``(user, assistant)`` turns, oldest first; ``None``
            is treated as no history.
        system: System message content.
        max_window_size: Token budget applied to the system message plus the
            retained history. The current query and the assistant prefix are
            not counted against it.

    Returns:
        A ``(raw_text, context_tokens)`` tuple: the prompt string and the
        matching list of token ids.
    """
    turns = [] if history is None else history

    start_tag, end_tag = "<|im_start|>", "<|im_end|>"
    start_ids = [tokenizer.im_start_id]
    end_ids = [tokenizer.im_end_id]
    newline_ids = tokenizer.encode("\n")

    def _render(role, content):
        # Role and content are encoded separately with special tokens
        # disallowed, then joined by the newline tokens.
        text = f"{role}\n{content}"
        role_ids = tokenizer.encode(role, allowed_special=set())
        content_ids = tokenizer.encode(content, allowed_special=set())
        return text, role_ids + newline_ids + content_ids

    system_text, system_body = _render("system", system)
    system_ids = start_ids + system_body + end_ids
    system_len = len(system_ids)
    assistant_prefix_ids = tokenizer.encode(query_assistant, allowed_special=set())

    # Walk the history newest-first and stop at the first turn that would
    # reach the window budget, so only a contiguous recent suffix is kept.
    history_text = ""
    history_ids = []
    for past_query, past_reply in reversed(turns):
        user_text, user_body = _render("user", past_query)
        reply_text, reply_body = _render("assistant", past_reply)
        turn_ids = (
            newline_ids
            + start_ids + user_body + end_ids
            + newline_ids
            + start_ids + reply_body + end_ids
        )
        if system_len + len(turn_ids) + len(history_ids) >= max_window_size:
            break
        history_ids = turn_ids + history_ids
        turn_text = (
            f"\n{start_tag}{user_text}{end_tag}"
            f"\n{start_tag}{reply_text}{end_tag}"
        )
        history_text = turn_text + history_text

    query_body = _render("user", query)[1]
    assistant_role_ids = tokenizer.encode("assistant")

    context_tokens = (
        system_ids
        + history_ids
        + newline_ids
        + start_ids
        + query_body
        + end_ids
        + newline_ids
        + start_ids
        + assistant_role_ids
        + newline_ids
        + assistant_prefix_ids
    )
    raw_text = (
        f"{start_tag}{system_text}{end_tag}"
        + history_text
        + f"\n{start_tag}user\n{query}{end_tag}\n{start_tag}assistant\n{query_assistant}"
    )
    return raw_text, context_tokens
def decode_tokens(
    tokens: Union[torch.LongTensor, TokensType],
    tokenizer: PreTrainedTokenizer,
    raw_text_len: int = 0,
    context_length: int = 0,
    errors: str = "replace",
) -> Tuple[str, str, str]:
    """Decode generated token ids and cut the text at the first stop token.

    Args:
        tokens: Token ids as a tensor or a plain list of ints.
        tokenizer: Object exposing ``im_start_id``, ``im_end_id`` and
            ``decode(ids, errors=...)``.
        raw_text_len: Number of leading characters removed from the truncated
            decoded text (typically the length of the prompt text).
        context_length: Index at which the search for a stop token begins;
            stop tokens before this index are ignored.
        errors: Error-handling mode forwarded to ``tokenizer.decode``.

    Returns:
        A ``(decoded, trimmed, end_reason)`` tuple where ``decoded`` is the
        decoding of all tokens, ``trimmed`` is the decoding of the tokens
        before the first stop token with the first ``raw_text_len`` characters
        removed and surrounding whitespace stripped, and ``end_reason``
        describes why generation is considered finished.
    """
    if torch.is_tensor(tokens):
        tokens = tokens.tolist()

    stop_ids = (tokenizer.im_start_id, tokenizer.im_end_id)
    end_reason = f"Gen length {len(tokens)}"

    # Default to keeping every token. The cut index only moves when a stop
    # token is actually found; reusing the loop variable after an exhausted
    # loop would leave it at len(tokens) - 1 and drop the last token.
    eod_token_idx = len(tokens)
    for idx in range(context_length, len(tokens)):
        if tokens[idx] in stop_ids:
            end_reason = f"Gen {tokenizer.decode([tokens[idx]])!r}"
            eod_token_idx = idx
            break

    decoded = tokenizer.decode(tokens, errors=errors)
    truncated_text = tokenizer.decode(tokens[:eod_token_idx], errors=errors)
    trim_decode_tokens = truncated_text[raw_text_len:].strip()
    return decoded, trim_decode_tokens, end_reason