110 lines
3.6 KiB
Python
110 lines
3.6 KiB
Python
|
# Copyright (c) Alibaba Cloud.
|
||
|
#
|
||
|
# This source code is licensed under the license found in the
|
||
|
# LICENSE file in the root directory of this source tree.
|
||
|
|
||
|
"""Generation support."""
|
||
|
|
||
|
from typing import Tuple, List, Union, Iterable
|
||
|
|
||
|
import numpy as np
|
||
|
import torch
|
||
|
import torch.nn.functional as F
|
||
|
from transformers import PreTrainedTokenizer
|
||
|
from transformers import logging
|
||
|
from transformers.generation import LogitsProcessor
|
||
|
|
||
|
# Module-level logger, obtained from the `transformers` logging helper.
logger = logging.get_logger(__name__)
|
||
|
|
||
|
# Types.
|
||
|
# Chat history: a list of 2-tuples of strings.
# NOTE(review): presumably (user query, assistant response) pairs, oldest first - confirm against callers.
HistoryType = List[Tuple[str, str]]
|
||
|
# Token ids of a single sequence.
TokensType = List[int]
|
||
|
# A list of token-id lists; presumably one inner list per batch element - confirm against callers.
BatchTokensType = List[List[int]]
|
||
|
|
||
|
|
||
|
def make_context(
    tokenizer: "PreTrainedTokenizer",
    query: str,
    query_assistant: str = "",
    history: Union[List[Tuple[str, str]], None] = None,
    system: str = "",
    max_window_size: int = 6144,
) -> Tuple[str, List[int]]:
    """Build the chat prompt for ``query`` as both text and token ids.

    Every turn is rendered as ``<|im_start|>{role}\\n{content}<|im_end|>`` and
    turns are separated by a newline. The prompt ends with an open assistant
    turn (``<|im_start|>assistant\\n`` followed by ``query_assistant``) so the
    model continues from there.

    Args:
        tokenizer: Tokenizer exposing ``im_start_id``, ``im_end_id`` and an
            ``encode`` method that accepts ``allowed_special``.
        query: Content of the current user turn.
        query_assistant: Text placed at the start of the assistant's reply.
        history: Earlier ``(user, assistant)`` turns, oldest first. ``None``
            is treated as an empty history.
        system: Content of the system turn (always emitted, even if empty).
        max_window_size: Token budget for the system turn plus history. A
            history turn is kept only while the running total stays strictly
            below this value. The current query and the assistant prefix are
            always appended and are not counted against the budget.

    Returns:
        A ``(raw_text, context_tokens)`` tuple: the prompt as a string and the
        matching list of token ids.
    """
    if history is None:
        history = []

    im_start, im_end = "<|im_start|>", "<|im_end|>"
    im_start_tokens = [tokenizer.im_start_id]
    im_end_tokens = [tokenizer.im_end_id]
    nl_tokens = tokenizer.encode("\n")

    def _tokenize_str(role, content):
        # Role and content are encoded with no special tokens allowed, so
        # marker-looking text inside them is treated as ordinary text.
        text = f"{role}\n{content}"
        tokens = (
            tokenizer.encode(role, allowed_special=set())
            + nl_tokens
            + tokenizer.encode(content, allowed_special=set())
        )
        return text, tokens

    def _closed_turn(role, content):
        # A complete turn: body wrapped in the start/end markers.
        text, tokens = _tokenize_str(role, content)
        return f"{im_start}{text}{im_end}", im_start_tokens + tokens + im_end_tokens

    system_text, system_tokens = _closed_turn("system", system)

    history_text = ""
    history_tokens = []
    used = len(system_tokens)

    # Walk history newest-first so the most recent turns win the budget; the
    # first turn that does not fit ends the walk, dropping everything older.
    for turn_query, turn_response in reversed(history):
        query_text, query_tokens = _closed_turn("user", turn_query)
        response_text, response_tokens = _closed_turn("assistant", turn_response)

        turn_tokens = nl_tokens + query_tokens + nl_tokens + response_tokens
        if used + len(turn_tokens) >= max_window_size:
            break

        history_tokens = turn_tokens + history_tokens
        history_text = f"\n{query_text}\n{response_text}" + history_text
        used += len(turn_tokens)

    user_text, user_tokens = _closed_turn("user", query)
    # The assistant turn is deliberately left open (no end marker). It goes
    # through the same helper as every other role header so that text and
    # tokens are produced by one code path.
    prefix_text, prefix_tokens = _tokenize_str("assistant", query_assistant)

    context_tokens = (
        system_tokens
        + history_tokens
        + nl_tokens
        + user_tokens
        + nl_tokens
        + im_start_tokens
        + prefix_tokens
    )
    raw_text = f"{system_text}{history_text}\n{user_text}\n{im_start}{prefix_text}"

    return raw_text, context_tokens
|
||
|
|
||
|
|
||
|
def decode_tokens(
    tokens: "Union[torch.LongTensor, TokensType]",
    tokenizer: "PreTrainedTokenizer",
    raw_text_len: int = 0,
    context_length: int = 0,
    errors: str = "replace",
) -> Tuple[str, str, str]:
    """Decode token ids and extract the text generated up to the first stop token.

    Args:
        tokens: Token ids as a list, or a tensor (converted to a Python list
            on the CPU before use).
        tokenizer: Tokenizer exposing ``im_start_id``, ``im_end_id`` and a
            ``decode`` method that accepts ``errors``.
        raw_text_len: Number of leading characters removed from the decoded
            text; presumably the length of the prompt text.
        context_length: Index at which the search for a stop token begins;
            presumably the number of prompt tokens.
        errors: Forwarded to ``tokenizer.decode``.

    Returns:
        A ``(decoded, trimmed, end_reason)`` tuple. ``decoded`` is the decoding
        of all of ``tokens``. ``trimmed`` is the decoding of the tokens before
        the first stop token, minus the first ``raw_text_len`` characters and
        surrounding whitespace. ``end_reason`` names the stop token that ended
        the text, or reports the total token count when none was found.
    """
    if torch.is_tensor(tokens):
        tokens = tokens.cpu().numpy().tolist()

    stop_ids = (tokenizer.im_start_id, tokenizer.im_end_id)

    # Default to "no stop token": keep every token. Relying on the loop
    # variable's final value instead would leave it at len(tokens) - 1 and
    # silently drop the last generated token.
    end_reason = f"Gen length {len(tokens)}"
    eod_token_idx = len(tokens)
    for idx in range(context_length, len(tokens)):
        if tokens[idx] in stop_ids:
            end_reason = f"Gen {tokenizer.decode([tokens[idx]])!r}"
            eod_token_idx = idx
            break

    decoded = tokenizer.decode(tokens, errors=errors)

    text_before_stop = tokenizer.decode(tokens[:eod_token_idx], errors=errors)
    trim_decode_tokens = text_before_stop[raw_text_len:].strip()

    return decoded, trim_decode_tokens, end_reason
|