Compare commits: 1b8007e1c3 ... 063f722177 (3 commits)

Author | SHA1 | Date
---|---|---
 | 063f722177 |
 | 7d7b4381f8 |
 | 245d251663 |

Changed file: qwen/demo.py | 47
@@ -22,6 +22,34 @@ config, kwargs = AutoConfig.from_pretrained(
 )
 model = QWenLMHeadModel(config)
 
+print(model)
+
+# QWenLMHeadModel(
+#   (transformer): QWenModel(
+#     (wte): Embedding(151936, 2048)
+#     (drop): Dropout(p=0.0, inplace=False)
+#     (rotary_emb): RotaryEmbedding()
+#     (h): ModuleList(
+#       (0-23): 24 x QWenBlock(
+#         (ln_1): RMSNorm()
+#         (attn): QWenAttention(
+#           (c_attn): Linear(in_features=2048, out_features=6144, bias=True)
+#           (c_proj): Linear(in_features=2048, out_features=2048, bias=False)
+#           (attn_dropout): Dropout(p=0.0, inplace=False)
+#         )
+#         (ln_2): RMSNorm()
+#         (mlp): QWenMLP(
+#           (w1): Linear(in_features=2048, out_features=5504, bias=False)
+#           (w2): Linear(in_features=2048, out_features=5504, bias=False)
+#           (c_proj): Linear(in_features=5504, out_features=2048, bias=False)
+#         )
+#       )
+#     )
+#     (ln_f): RMSNorm()
+#   )
+#   (lm_head): Linear(in_features=2048, out_features=151936, bias=False)
+# )
+
 
 tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
 model = model.from_pretrained(
@@ -36,10 +64,21 @@ model = model.from_pretrained(
 # )
 
 # 第一轮对话
-response, history = model.chat(tokenizer, "你好", history=None)
-print(response)
+response, history, decode_tokens = model.chat(tokenizer, "你好", "莎是现代汉语的男性的名字,出自《诗经》中的“采采卷", history=None)
+print(decode_tokens)
 # 你好!很高兴为你提供帮助。
 
 # 第二轮对话
-response, history = model.chat(tokenizer, "给我讲一个年轻人奋斗创业最终取得成功的故事。", history=history)
-print(response)
+# response, history = model.chat(tokenizer, "给我讲一个年轻人奋斗创业最终取得成功的故事。", history=None)
+# print(response)
+
+
+
+# <|im_start|>system
+# You are a helpful assistant.<|im_end|>
+# <|im_start|>user
+# 你好<|im_end|>
+# <|im_start|>assistant
+# 莎士比亚是头一个使用“你好”这个词的文学家,他在《哈姆雷特》中写道:“你是谁?你在哪儿?
+# ”他的这一段话,通常被认为是最早的使用“你好”这个词的文学记载。这句话在英国语中非常常见,
+# 特别是在正式或礼貌的情况下。<|im_end|><|endoftext|>
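For orientation, a minimal usage sketch of the patched demo flow (everything here is taken from the hunks above; model_dir and the empty assistant prefix are placeholder values):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
    # chat() now takes a query_assistant prefix and also returns the raw decoded text.
    response, history, decode_tokens = model.chat(
        tokenizer,
        "你好",   # user query
        "",       # query_assistant: text pre-filled into the assistant turn
        history=None,
    )
    print(response)       # trimmed assistant reply
    print(decode_tokens)  # full ChatML transcript with <|im_start|>/<|im_end|> markers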
Second changed file in this compare (its file header was not preserved; the following hunks modify QWenLMHeadModel and its imports):

@@ -34,7 +34,6 @@ from qwen_generation_utils import (
     HistoryType,
     make_context,
     decode_tokens,
-    get_stop_words_ids,
     StopWordsLogitsProcessor,
 )
 
@@ -413,21 +412,19 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         self,
         tokenizer: PreTrainedTokenizer,
         query: str,
+        query_assistant: str,
         history: Optional[HistoryType],
         system: str = "You are a helpful assistant.",
-        stop_words_ids: Optional[List[List[int]]] = None,
-        generation_config: Optional[GenerationConfig] = None,
         **kwargs,
     ) -> Tuple[str, HistoryType]:
-        generation_config = generation_config if generation_config is not None else self.generation_config
+        generation_config = self.generation_config
 
         if history is None:
             history = []
         else:
             history = copy.deepcopy(history)
 
-        if stop_words_ids is None:
-            stop_words_ids = []
+        stop_words_ids = []
 
         max_window_size = kwargs.get("max_window_size", None)
         if max_window_size is None:
@@ -435,86 +432,40 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         raw_text, context_tokens = make_context(
             tokenizer,
             query,
+            query_assistant,
             history=history,
             system=system,
-            max_window_size=max_window_size,
-            chat_format=generation_config.chat_format,
+            max_window_size=max_window_size
         )
 
-        stop_words_ids.extend(get_stop_words_ids(generation_config.chat_format, tokenizer))
+        stop_words_ids.extend([[tokenizer.im_end_id], [tokenizer.im_start_id]])
         input_ids = torch.tensor([context_tokens]).to(self.device)
         outputs = self.generate(
             input_ids,
             stop_words_ids=stop_words_ids,
-            generation_config=generation_config,
             **kwargs,
         )
-        response = decode_tokens(
+        decoded, response, end_reason = decode_tokens(
             outputs[0],
             tokenizer,
             raw_text_len=len(raw_text),
             context_length=len(context_tokens),
-            chat_format=generation_config.chat_format,
-            verbose=False,
             errors="replace",
         )
         history.append((query, response))
-        return response, history
+        return response, history, decoded
 
     def generate(
         self,
         inputs: Optional[torch.Tensor] = None,
-        generation_config: Optional[GenerationConfig] = None,
-        logits_processor: Optional[LogitsProcessorList] = None,
-        stopping_criteria: Optional[StoppingCriteriaList] = None,
+        stop_words_ids = [],
         prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
-        assistant_model: Optional["PreTrainedModel"] = None,
         streamer: Optional["BaseStreamer"] = None,
         **kwargs,
     ) -> Union[GenerateOutput, torch.LongTensor]:
-        generation_config = generation_config if generation_config is not None else self.generation_config
+        generation_config = self.generation_config
 
-        # Process stop_words_ids.
-        stop_words_ids = kwargs.pop("stop_words_ids", None)
-        if stop_words_ids is None and generation_config is not None:
-            stop_words_ids = getattr(generation_config, "stop_words_ids", None)
-        if stop_words_ids is None:
-            stop_words_ids = getattr(generation_config, "stop_words_ids", None)
-
-        if stop_words_ids is not None:
-            stop_words_logits_processor = StopWordsLogitsProcessor(
-                stop_words_ids=stop_words_ids,
-                eos_token_id=generation_config.eos_token_id,
-            )
-            if logits_processor is None:
-                logits_processor = LogitsProcessorList([stop_words_logits_processor])
-            else:
-                logits_processor.append(stop_words_logits_processor)
-
-        return self.generate_base(
-            inputs,
-            generation_config=generation_config,
-            logits_processor=logits_processor,
-            stopping_criteria=stopping_criteria,
-            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
-            assistant_model=assistant_model,
-            streamer=streamer,
-            **kwargs,
-        )
-
-    def generate_base(
-        self,
-        inputs: Optional[torch.Tensor] = None,
-        generation_config: Optional[GenerationConfig] = None,
-        logits_processor: Optional[LogitsProcessorList] = None,
-        stopping_criteria: Optional[StoppingCriteriaList] = None,
-        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
-        assistant_model: Optional["PreTrainedModel"] = None,
-        streamer: Optional["BaseStreamer"] = None,
-        negative_prompt_ids: Optional[torch.Tensor] = None,
-        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
-        **kwargs,
-    ) -> Union[GenerateOutput, torch.LongTensor]:
         # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
         self._validate_model_class()
 
@@ -524,8 +475,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         self._validate_model_kwargs(model_kwargs.copy())
 
         # 2. Set generation parameters if not already defined
-        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
-        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
 
         if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
             if model_kwargs.get("attention_mask", None) is None:
@@ -567,7 +517,12 @@ class QWenLMHeadModel(QWenPreTrainedModel):
             generation_config.max_length = generation_config.max_new_tokens + input_ids_length
         self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
 
         # 8. prepare distribution pre_processing samplers
+        stop_words_logits_processor = StopWordsLogitsProcessor(
+            stop_words_ids=stop_words_ids,
+            eos_token_id=generation_config.eos_token_id,
+        )
+        logits_processor = LogitsProcessorList([stop_words_logits_processor])
         logits_processor = self._get_logits_processor(
             generation_config=generation_config,
             input_ids_seq_length=input_ids_length,
@@ -575,18 +530,10 @@ class QWenLMHeadModel(QWenPreTrainedModel):
             prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
             logits_processor=logits_processor,
             model_kwargs=model_kwargs,
-            negative_prompt_ids=negative_prompt_ids,
-            negative_prompt_attention_mask=negative_prompt_attention_mask,
+            negative_prompt_ids=None,
+            negative_prompt_attention_mask=None,
         )
 
-        # 9. prepare stopping criteria
-        stopping_criteria = self._get_stopping_criteria(
-            generation_config=generation_config, stopping_criteria=stopping_criteria
-        )
-        # 10. go into different generation modes
-
-        # 11. prepare logits warper
-        logits_warper = self._get_logits_warper(generation_config)
 
         # 12. expand input_ids with `num_return_sequences` additional sequences per batch
         input_ids, model_kwargs = self._expand_inputs_for_generation(
@@ -597,42 +544,24 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         )
 
         # 13. run sample
-        return self.sample_base(
-            input_ids,
-            logits_processor=logits_processor,
-            logits_warper=logits_warper,
-            stopping_criteria=stopping_criteria,
-            pad_token_id=generation_config.pad_token_id,
-            eos_token_id=generation_config.eos_token_id,
-            output_scores=generation_config.output_scores,
-            streamer=streamer,
-            **model_kwargs,
-        )
-
-    def sample_base(
-        self,
-        input_ids: torch.LongTensor,
-        logits_processor: Optional[LogitsProcessorList] = None,
-        stopping_criteria: Optional[StoppingCriteriaList] = None,
-        logits_warper: Optional[LogitsProcessorList] = None,
-        max_length: Optional[int] = None,
-        pad_token_id: Optional[int] = None,
-        eos_token_id: Optional[Union[int, List[int]]] = None,
-        output_scores: Optional[bool] = None,
-        streamer: Optional["BaseStreamer"] = None,
-        **model_kwargs,
-    ):
-        # init values
-        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
-        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
-
-        logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
+        pad_token_id=generation_config.pad_token_id
+        eos_token_id=generation_config.eos_token_id
+        streamer=streamer
+
+        # init values
+        stopping_criteria = self._get_stopping_criteria(
+            generation_config=generation_config, stopping_criteria=StoppingCriteriaList()
+        )
+
+        logits_warper = self._get_logits_warper(generation_config)
+
         pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
         eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
         if isinstance(eos_token_id, int):
             eos_token_id = [eos_token_id]
         eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
-        output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
 
         # init attention / hidden states / scores tuples
         scores = None
@@ -649,10 +578,10 @@ class QWenLMHeadModel(QWenPreTrainedModel):
             # forward pass to get next token
             outputs = self(**model_inputs)
 
-            next_token_logits = outputs.logits[:, -1, :]
+            next_token_scores = outputs.logits[:, -1, :]
 
             # pre-process distribution
-            next_token_scores = logits_processor(input_ids, next_token_logits)
+            next_token_scores = logits_processor(input_ids, next_token_scores)
             next_token_scores = logits_warper(input_ids, next_token_scores)
 
             # sample
@@ -661,8 +590,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
 
             # finished sentences should have their next token be a padding token
             if eos_token_id is not None:
-                if pad_token_id is None:
-                    raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
                 next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
 
             # update generated ids, model inputs, and length for next step
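Net effect of the hunks above: generate() no longer takes a GenerationConfig or external logits processors; it always uses self.generation_config and expects stop_words_ids explicitly (defaulting to []). A sketch of driving it directly, mirroring what the patched chat() does internally (context_tokens and raw_text come from make_context; all names follow the diff):

    import torch

    # ChatML stop ids, exactly as the patched chat() hard-codes them.
    stop_words_ids = [[tokenizer.im_end_id], [tokenizer.im_start_id]]
    input_ids = torch.tensor([context_tokens]).to(model.device)
    outputs = model.generate(input_ids, stop_words_ids=stop_words_ids)
    decoded, response, end_reason = decode_tokens(
        outputs[0],
        tokenizer,
        raw_text_len=len(raw_text),
        context_length=len(context_tokens),
        errors="replace",
    )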
Third changed file in this compare (the qwen_generation_utils helpers imported above):

@@ -105,197 +105,96 @@ def get_batch(context_tokens: torch.LongTensor, eod_id: int):
     )
     return tokens, attention_mask, position_ids
 
 
-def get_stop_words_ids(chat_format, tokenizer):
-    if chat_format == "raw":
-        stop_words_ids = [tokenizer.encode("Human:"), [tokenizer.eod_id]]
-    elif chat_format == "chatml":
-        stop_words_ids = [[tokenizer.im_end_id], [tokenizer.im_start_id]]
-    else:
-        raise NotImplementedError(f"Unknown chat format {chat_format!r}")
-    return stop_words_ids
-
-
 def make_context(
     tokenizer: PreTrainedTokenizer,
     query: str,
+    query_assistant: str = "",
     history: List[Tuple[str, str]] = None,
     system: str = "",
-    max_window_size: int = 6144,
-    chat_format: str = "chatml",
+    max_window_size: int = 6144
 ):
     if history is None:
         history = []
 
-    if chat_format == "chatml":
-        im_start, im_end = "<|im_start|>", "<|im_end|>"
-        im_start_tokens = [tokenizer.im_start_id]
-        im_end_tokens = [tokenizer.im_end_id]
-        nl_tokens = tokenizer.encode("\n")
-
-        def _tokenize_str(role, content):
-            return f"{role}\n{content}", tokenizer.encode(
-                role, allowed_special=set()
-            ) + nl_tokens + tokenizer.encode(content, allowed_special=set())
-
-        system_text, system_tokens_part = _tokenize_str("system", system)
-        system_tokens = im_start_tokens + system_tokens_part + im_end_tokens
-
-        raw_text = ""
-        context_tokens = []
-
-        for turn_query, turn_response in reversed(history):
-            query_text, query_tokens_part = _tokenize_str("user", turn_query)
-            query_tokens = im_start_tokens + query_tokens_part + im_end_tokens
-            response_text, response_tokens_part = _tokenize_str(
-                "assistant", turn_response
-            )
-            response_tokens = im_start_tokens + response_tokens_part + im_end_tokens
-
-            next_context_tokens = nl_tokens + query_tokens + nl_tokens + response_tokens
-            prev_chat = (
-                f"\n{im_start}{query_text}{im_end}\n{im_start}{response_text}{im_end}"
-            )
-
-            current_context_size = (
-                len(system_tokens) + len(next_context_tokens) + len(context_tokens)
-            )
-            if current_context_size < max_window_size:
-                context_tokens = next_context_tokens + context_tokens
-                raw_text = prev_chat + raw_text
-            else:
-                break
-
-        context_tokens = system_tokens + context_tokens
-        raw_text = f"{im_start}{system_text}{im_end}" + raw_text
-        context_tokens += (
-            nl_tokens
-            + im_start_tokens
-            + _tokenize_str("user", query)[1]
-            + im_end_tokens
-            + nl_tokens
-            + im_start_tokens
-            + tokenizer.encode("assistant")
-            + nl_tokens
-        )
-        raw_text += f"\n{im_start}user\n{query}{im_end}\n{im_start}assistant\n"
-
-    elif chat_format == "raw":
-        raw_text = query
-        context_tokens = tokenizer.encode(raw_text)
-    else:
-        raise NotImplementedError(f"Unknown chat format {chat_format!r}")
-
-    return raw_text, context_tokens
-
-
-def _decode_default(
-    tokens: List[int],
-    *,
-    stop_words: List[str],
-    eod_words: List[str],
-    tokenizer: PreTrainedTokenizer,
-    raw_text_len: int,
-    verbose: bool = False,
-    return_end_reason: bool = False,
-    errors: str='replace',
-):
-    trim_decode_tokens = tokenizer.decode(tokens, errors=errors)[raw_text_len:]
-    if verbose:
-        print("\nRaw Generate: ", trim_decode_tokens)
-
-    end_reason = f"Gen length {len(tokens)}"
-    for stop_word in stop_words:
-        trim_decode_tokens = trim_decode_tokens.replace(stop_word, "").strip()
-    for eod_word in eod_words:
-        if eod_word in trim_decode_tokens:
-            end_reason = f"Gen {eod_word!r}"
-        trim_decode_tokens = trim_decode_tokens.split(eod_word)[0]
-    trim_decode_tokens = trim_decode_tokens.strip()
-    if verbose:
-        print("\nEnd Reason:", end_reason)
-        print("\nGenerate: ", trim_decode_tokens)
-
-    if return_end_reason:
-        return trim_decode_tokens, end_reason
-    else:
-        return trim_decode_tokens
-
-
-def _decode_chatml(
-    tokens: List[int],
-    *,
-    stop_words: List[str],
-    eod_token_ids: List[int],
-    tokenizer: PreTrainedTokenizer,
-    raw_text_len: int,
-    context_length: int,
-    verbose: bool = False,
-    return_end_reason: bool = False,
-    errors: str='replace'
-):
-    end_reason = f"Gen length {len(tokens)}"
-    eod_token_idx = context_length
-    for eod_token_idx in range(context_length, len(tokens)):
-        if tokens[eod_token_idx] in eod_token_ids:
-            end_reason = f"Gen {tokenizer.decode([tokens[eod_token_idx]])!r}"
-            break
-
-    trim_decode_tokens = tokenizer.decode(tokens[:eod_token_idx], errors=errors)[raw_text_len:]
-    if verbose:
-        print("\nRaw Generate w/o EOD:", tokenizer.decode(tokens, errors=errors)[raw_text_len:])
-        print("\nRaw Generate:", trim_decode_tokens)
-        print("\nEnd Reason:", end_reason)
-    for stop_word in stop_words:
-        trim_decode_tokens = trim_decode_tokens.replace(stop_word, "").strip()
-    trim_decode_tokens = trim_decode_tokens.strip()
-    if verbose:
-        print("\nGenerate:", trim_decode_tokens)
-
-    if return_end_reason:
-        return trim_decode_tokens, end_reason
-    else:
-        return trim_decode_tokens
-
+    im_start, im_end = "<|im_start|>", "<|im_end|>"
+    im_start_tokens = [tokenizer.im_start_id]
+    im_end_tokens = [tokenizer.im_end_id]
+    nl_tokens = tokenizer.encode("\n")
+
+    def _tokenize_str(role, content):
+        return f"{role}\n{content}", tokenizer.encode(
+            role, allowed_special=set()
+        ) + nl_tokens + tokenizer.encode(content, allowed_special=set())
+
+    system_text, system_tokens_part = _tokenize_str("system", system)
+    system_tokens = im_start_tokens + system_tokens_part + im_end_tokens
+    assistant_tokens = tokenizer.encode(query_assistant, allowed_special=set())
+    raw_text = ""
+    context_tokens = []
+
+    for turn_query, turn_response in reversed(history):
+        query_text, query_tokens_part = _tokenize_str("user", turn_query)
+        query_tokens = im_start_tokens + query_tokens_part + im_end_tokens
+        response_text, response_tokens_part = _tokenize_str(
+            "assistant", turn_response
+        )
+        response_tokens = im_start_tokens + response_tokens_part + im_end_tokens
+
+        next_context_tokens = nl_tokens + query_tokens + nl_tokens + response_tokens
+        prev_chat = (
+            f"\n{im_start}{query_text}{im_end}\n{im_start}{response_text}{im_end}"
+        )
+
+        current_context_size = (
+            len(system_tokens) + len(next_context_tokens) + len(context_tokens)
+        )
+        if current_context_size < max_window_size:
+            context_tokens = next_context_tokens + context_tokens
+            raw_text = prev_chat + raw_text
+        else:
+            break
+
+    context_tokens = system_tokens + context_tokens
+    raw_text = f"{im_start}{system_text}{im_end}" + raw_text
+    context_tokens += (
+        nl_tokens
+        + im_start_tokens
+        + _tokenize_str("user", query)[1]
+        + im_end_tokens
+        + nl_tokens
+        + im_start_tokens
+        + tokenizer.encode("assistant")
+        + nl_tokens
+        + assistant_tokens
+    )
+    raw_text += f"\n{im_start}user\n{query}{im_end}\n{im_start}assistant\n{query_assistant}"
+
+    return raw_text, context_tokens
 
 def decode_tokens(
     tokens: Union[torch.LongTensor, TokensType],
     tokenizer: PreTrainedTokenizer,
     raw_text_len: int,
     context_length: int,
-    chat_format: str,
-    verbose: bool = False,
-    return_end_reason: bool = False,
     errors: str="replace",
 ) -> str:
     if torch.is_tensor(tokens):
         tokens = tokens.cpu().numpy().tolist()
 
-    if chat_format == "chatml":
-        return _decode_chatml(
-            tokens,
-            stop_words=[],
-            eod_token_ids=[tokenizer.im_start_id, tokenizer.im_end_id],
-            tokenizer=tokenizer,
-            raw_text_len=raw_text_len,
-            context_length=context_length,
-            verbose=verbose,
-            return_end_reason=return_end_reason,
-            errors=errors,
-        )
-    elif chat_format == "raw":
-        return _decode_default(
-            tokens,
-            stop_words=["<|endoftext|>"],
-            eod_words=["<|endoftext|>"],
-            tokenizer=tokenizer,
-            raw_text_len=raw_text_len,
-            verbose=verbose,
-            return_end_reason=return_end_reason,
-            errors=errors,
-        )
-    else:
-        raise NotImplementedError(f"Unknown chat format {chat_format!r}")
+    end_reason = f"Gen length {len(tokens)}"
+    eod_token_idx = context_length
+    for eod_token_idx in range(context_length, len(tokens)):
+        if tokens[eod_token_idx] in [tokenizer.im_start_id, tokenizer.im_end_id]:
+            end_reason = f"Gen {tokenizer.decode([tokens[eod_token_idx]])!r}"
+            break
+
+    decoded = tokenizer.decode(tokens, errors=errors)
+
+    decode_tokens = tokenizer.decode(tokens[:eod_token_idx], errors=errors)
+    trim_decode_tokens = decode_tokens[raw_text_len:]
+    trim_decode_tokens = trim_decode_tokens.strip()
+
+    return decoded, trim_decode_tokens, end_reason
 
 
 class StopWordsLogitsProcessor(LogitsProcessor):
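Taken together, the rewritten make_context always builds a ChatML prompt and lets the caller pre-fill the assistant turn via query_assistant, while decode_tokens now returns the full decoding, the trimmed reply, and the end reason. A sketch of the raw_text the new make_context produces for the default system prompt (the query and assistant-prefix values are illustrative only):

    # Illustrative prompt; "Hello" stands in for an arbitrary query_assistant prefix.
    raw_text = (
        "<|im_start|>system\nYou are a helpful assistant.<|im_end|>"
        "\n<|im_start|>user\n你好<|im_end|>"
        "\n<|im_start|>assistant\nHello"
    )
    # Generation continues from the pre-filled assistant text; decode_tokens strips
    # the first raw_text_len characters, so only the newly generated reply is returned.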