From 245d251663974da9bdd3600b9690287a71fa5a3f Mon Sep 17 00:00:00 2001
From: Colin
Date: Wed, 10 Jan 2024 11:35:46 +0000
Subject: [PATCH] Refine chat output format.

---
 qwen/demo.py                  |  47 +++++++-
 qwen/modeling_qwen.py         |  13 +-
 qwen/qwen_generation_utils.py | 221 ++++++++++------------------------
 3 files changed, 112 insertions(+), 169 deletions(-)

diff --git a/qwen/demo.py b/qwen/demo.py
index b010609..f770f77 100644
--- a/qwen/demo.py
+++ b/qwen/demo.py
@@ -22,6 +22,34 @@ config, kwargs = AutoConfig.from_pretrained(
 )
 
 model = QWenLMHeadModel(config)
+print(model)
+
+# QWenLMHeadModel(
+#   (transformer): QWenModel(
+#     (wte): Embedding(151936, 2048)
+#     (drop): Dropout(p=0.0, inplace=False)
+#     (rotary_emb): RotaryEmbedding()
+#     (h): ModuleList(
+#       (0-23): 24 x QWenBlock(
+#         (ln_1): RMSNorm()
+#         (attn): QWenAttention(
+#           (c_attn): Linear(in_features=2048, out_features=6144, bias=True)
+#           (c_proj): Linear(in_features=2048, out_features=2048, bias=False)
+#           (attn_dropout): Dropout(p=0.0, inplace=False)
+#         )
+#         (ln_2): RMSNorm()
+#         (mlp): QWenMLP(
+#           (w1): Linear(in_features=2048, out_features=5504, bias=False)
+#           (w2): Linear(in_features=2048, out_features=5504, bias=False)
+#           (c_proj): Linear(in_features=5504, out_features=2048, bias=False)
+#         )
+#       )
+#     )
+#     (ln_f): RMSNorm()
+#   )
+#   (lm_head): Linear(in_features=2048, out_features=151936, bias=False)
+# )
+
 tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
 
 model = model.from_pretrained(
@@ -36,10 +64,21 @@ model = model.from_pretrained(
 # )
 
 # 第一轮对话
-response, history = model.chat(tokenizer, "你好", history=None)
-print(response)
+response, history, decode_tokens = model.chat(tokenizer, "你好", "莎是现代汉语的男性的名字,出自《诗经》中的“采采卷", history=None)
+print(decode_tokens)
 # 你好!很高兴为你提供帮助。
 
 # 第二轮对话
-response, history = model.chat(tokenizer, "给我讲一个年轻人奋斗创业最终取得成功的故事。", history=history)
-print(response)
+# response, history = model.chat(tokenizer, "给我讲一个年轻人奋斗创业最终取得成功的故事。", history=None)
+# print(response)
+
+
+
+# <|im_start|>system
+# You are a helpful assistant.<|im_end|>
+# <|im_start|>user
+# 你好<|im_end|>
+# <|im_start|>assistant
+# 莎士比亚是头一个使用“你好”这个词的文学家,他在《哈姆雷特》中写道:“你是谁?你在哪儿?
+# ”他的这一段话,通常被认为是最早的使用“你好”这个词的文学记载。这句话在英国语中非常常见,
+# 特别是在正式或礼貌的情况下。<|im_end|><|endoftext|>
\ No newline at end of file
diff --git a/qwen/modeling_qwen.py b/qwen/modeling_qwen.py
index fdc871f..c363d1c 100644
--- a/qwen/modeling_qwen.py
+++ b/qwen/modeling_qwen.py
@@ -413,6 +413,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         self,
         tokenizer: PreTrainedTokenizer,
         query: str,
+        query_assistant: str,
         history: Optional[HistoryType],
         system: str = "You are a helpful assistant.",
         stop_words_ids: Optional[List[List[int]]] = None,
@@ -435,13 +436,13 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         raw_text, context_tokens = make_context(
             tokenizer,
             query,
+            query_assistant,
             history=history,
             system=system,
-            max_window_size=max_window_size,
-            chat_format=generation_config.chat_format,
+            max_window_size=max_window_size
         )
-        stop_words_ids.extend(get_stop_words_ids(generation_config.chat_format, tokenizer))
+        stop_words_ids.extend(get_stop_words_ids(tokenizer))
         input_ids = torch.tensor([context_tokens]).to(self.device)
         outputs = self.generate(
             input_ids,
@@ -449,17 +450,15 @@ class QWenLMHeadModel(QWenPreTrainedModel):
             generation_config=generation_config,
             **kwargs,
         )
-        response = decode_tokens(
+        decoded, response, end_reason = decode_tokens(
             outputs[0],
             tokenizer,
             raw_text_len=len(raw_text),
             context_length=len(context_tokens),
-            chat_format=generation_config.chat_format,
-            verbose=False,
             errors="replace",
         )
         history.append((query, response))
-        return response, history
+        return response, history, decoded
 
     def generate(
         self,
diff --git a/qwen/qwen_generation_utils.py b/qwen/qwen_generation_utils.py
index 4e8e1d8..d3dd7f9 100644
--- a/qwen/qwen_generation_utils.py
+++ b/qwen/qwen_generation_utils.py
@@ -106,196 +106,101 @@ def get_batch(context_tokens: torch.LongTensor, eod_id: int):
     return tokens, attention_mask, position_ids
 
 
-def get_stop_words_ids(chat_format, tokenizer):
-    if chat_format == "raw":
-        stop_words_ids = [tokenizer.encode("Human:"), [tokenizer.eod_id]]
-    elif chat_format == "chatml":
-        stop_words_ids = [[tokenizer.im_end_id], [tokenizer.im_start_id]]
-    else:
-        raise NotImplementedError(f"Unknown chat format {chat_format!r}")
+def get_stop_words_ids(tokenizer):
+    stop_words_ids = [[tokenizer.im_end_id], [tokenizer.im_start_id]]
     return stop_words_ids
 
 
 def make_context(
     tokenizer: PreTrainedTokenizer,
     query: str,
+    query_assistant: str = "",
     history: List[Tuple[str, str]] = None,
     system: str = "",
-    max_window_size: int = 6144,
-    chat_format: str = "chatml",
+    max_window_size: int = 6144
 ):
     if history is None:
         history = []
 
-    if chat_format == "chatml":
-        im_start, im_end = "<|im_start|>", "<|im_end|>"
-        im_start_tokens = [tokenizer.im_start_id]
-        im_end_tokens = [tokenizer.im_end_id]
-        nl_tokens = tokenizer.encode("\n")
+    im_start, im_end = "<|im_start|>", "<|im_end|>"
+    im_start_tokens = [tokenizer.im_start_id]
+    im_end_tokens = [tokenizer.im_end_id]
+    nl_tokens = tokenizer.encode("\n")
 
-        def _tokenize_str(role, content):
-            return f"{role}\n{content}", tokenizer.encode(
-                role, allowed_special=set()
-            ) + nl_tokens + tokenizer.encode(content, allowed_special=set())
+    def _tokenize_str(role, content):
+        return f"{role}\n{content}", tokenizer.encode(
+            role, allowed_special=set()
+        ) + nl_tokens + tokenizer.encode(content, allowed_special=set())
 
-        system_text, system_tokens_part = _tokenize_str("system", system)
-        system_tokens = im_start_tokens + system_tokens_part + im_end_tokens
-
-        raw_text = ""
-        context_tokens = []
-
-        for turn_query, turn_response in reversed(history):
-            query_text, query_tokens_part = _tokenize_str("user", turn_query)
-            query_tokens = im_start_tokens + query_tokens_part + im_end_tokens
-            response_text, response_tokens_part = _tokenize_str(
-                "assistant", turn_response
-            )
-            response_tokens = im_start_tokens + response_tokens_part + im_end_tokens
-
-            next_context_tokens = nl_tokens + query_tokens + nl_tokens + response_tokens
-            prev_chat = (
-                f"\n{im_start}{query_text}{im_end}\n{im_start}{response_text}{im_end}"
-            )
-
-            current_context_size = (
-                len(system_tokens) + len(next_context_tokens) + len(context_tokens)
-            )
-            if current_context_size < max_window_size:
-                context_tokens = next_context_tokens + context_tokens
-                raw_text = prev_chat + raw_text
-            else:
-                break
-
-        context_tokens = system_tokens + context_tokens
-        raw_text = f"{im_start}{system_text}{im_end}" + raw_text
-        context_tokens += (
-            nl_tokens
-            + im_start_tokens
-            + _tokenize_str("user", query)[1]
-            + im_end_tokens
-            + nl_tokens
-            + im_start_tokens
-            + tokenizer.encode("assistant")
-            + nl_tokens
-        )
-        raw_text += f"\n{im_start}user\n{query}{im_end}\n{im_start}assistant\n"
-
-    elif chat_format == "raw":
-        raw_text = query
-        context_tokens = tokenizer.encode(raw_text)
-    else:
-        raise NotImplementedError(f"Unknown chat format {chat_format!r}")
+    system_text, system_tokens_part = _tokenize_str("system", system)
+    system_tokens = im_start_tokens + system_tokens_part + im_end_tokens
+    assistant_tokens = tokenizer.encode(query_assistant, allowed_special=set())
+    raw_text = ""
+    context_tokens = []
+
+    for turn_query, turn_response in reversed(history):
+        query_text, query_tokens_part = _tokenize_str("user", turn_query)
+        query_tokens = im_start_tokens + query_tokens_part + im_end_tokens
+        response_text, response_tokens_part = _tokenize_str(
+            "assistant", turn_response
+        )
+        response_tokens = im_start_tokens + response_tokens_part + im_end_tokens
+
+        next_context_tokens = nl_tokens + query_tokens + nl_tokens + response_tokens
+        prev_chat = (
+            f"\n{im_start}{query_text}{im_end}\n{im_start}{response_text}{im_end}"
+        )
+
+        current_context_size = (
+            len(system_tokens) + len(next_context_tokens) + len(context_tokens)
+        )
+        if current_context_size < max_window_size:
+            context_tokens = next_context_tokens + context_tokens
+            raw_text = prev_chat + raw_text
+        else:
+            break
+
+    context_tokens = system_tokens + context_tokens
+    raw_text = f"{im_start}{system_text}{im_end}" + raw_text
+    context_tokens += (
+        nl_tokens
+        + im_start_tokens
+        + _tokenize_str("user", query)[1]
+        + im_end_tokens
+        + nl_tokens
+        + im_start_tokens
+        + tokenizer.encode("assistant")
+        + nl_tokens
+        + assistant_tokens
+    )
+    raw_text += f"\n{im_start}user\n{query}{im_end}\n{im_start}assistant\n{query_assistant}"
 
     return raw_text, context_tokens
 
 
-def _decode_default(
-    tokens: List[int],
-    *,
-    stop_words: List[str],
-    eod_words: List[str],
-    tokenizer: PreTrainedTokenizer,
-    raw_text_len: int,
-    verbose: bool = False,
-    return_end_reason: bool = False,
-    errors: str='replace',
-):
-    trim_decode_tokens = tokenizer.decode(tokens, errors=errors)[raw_text_len:]
-    if verbose:
-        print("\nRaw Generate: ", trim_decode_tokens)
-
-    end_reason = f"Gen length {len(tokens)}"
-    for stop_word in stop_words:
-        trim_decode_tokens = trim_decode_tokens.replace(stop_word, "").strip()
-    for eod_word in eod_words:
-        if eod_word in trim_decode_tokens:
-            end_reason = f"Gen {eod_word!r}"
-        trim_decode_tokens = trim_decode_tokens.split(eod_word)[0]
-    trim_decode_tokens = trim_decode_tokens.strip()
-    if verbose:
-        print("\nEnd Reason:", end_reason)
-        print("\nGenerate: ", trim_decode_tokens)
-
-    if return_end_reason:
-        return trim_decode_tokens, end_reason
-    else:
-        return trim_decode_tokens
-
-
-def _decode_chatml(
-    tokens: List[int],
-    *,
-    stop_words: List[str],
-    eod_token_ids: List[int],
-    tokenizer: PreTrainedTokenizer,
-    raw_text_len: int,
-    context_length: int,
-    verbose: bool = False,
-    return_end_reason: bool = False,
-    errors: str='replace'
-):
-    end_reason = f"Gen length {len(tokens)}"
-    eod_token_idx = context_length
-    for eod_token_idx in range(context_length, len(tokens)):
-        if tokens[eod_token_idx] in eod_token_ids:
-            end_reason = f"Gen {tokenizer.decode([tokens[eod_token_idx]])!r}"
-            break
-
-    trim_decode_tokens = tokenizer.decode(tokens[:eod_token_idx], errors=errors)[raw_text_len:]
-    if verbose:
-        print("\nRaw Generate w/o EOD:", tokenizer.decode(tokens, errors=errors)[raw_text_len:])
-        print("\nRaw Generate:", trim_decode_tokens)
-        print("\nEnd Reason:", end_reason)
-    for stop_word in stop_words:
-        trim_decode_tokens = trim_decode_tokens.replace(stop_word, "").strip()
-    trim_decode_tokens = trim_decode_tokens.strip()
-    if verbose:
-        print("\nGenerate:", trim_decode_tokens)
-
-    if return_end_reason:
-        return trim_decode_tokens, end_reason
-    else:
-        return trim_decode_tokens
-
-
 def decode_tokens(
     tokens: Union[torch.LongTensor, TokensType],
     tokenizer: PreTrainedTokenizer,
     raw_text_len: int,
     context_length: int,
-    chat_format: str,
-    verbose: bool = False,
-    return_end_reason: bool = False,
     errors: str="replace",
 ) -> str:
     if torch.is_tensor(tokens):
         tokens = tokens.cpu().numpy().tolist()
 
-    if chat_format == "chatml":
-        return _decode_chatml(
-            tokens,
-            stop_words=[],
-            eod_token_ids=[tokenizer.im_start_id, tokenizer.im_end_id],
-            tokenizer=tokenizer,
-            raw_text_len=raw_text_len,
-            context_length=context_length,
-            verbose=verbose,
-            return_end_reason=return_end_reason,
-            errors=errors,
-        )
-    elif chat_format == "raw":
-        return _decode_default(
-            tokens,
-            stop_words=["<|endoftext|>"],
-            eod_words=["<|endoftext|>"],
-            tokenizer=tokenizer,
-            raw_text_len=raw_text_len,
-            verbose=verbose,
-            return_end_reason=return_end_reason,
-            errors=errors,
-        )
-    else:
-        raise NotImplementedError(f"Unknown chat format {chat_format!r}")
+    end_reason = f"Gen length {len(tokens)}"
+    eod_token_idx = context_length
+    for eod_token_idx in range(context_length, len(tokens)):
+        if tokens[eod_token_idx] in [tokenizer.im_start_id, tokenizer.im_end_id]:
+            end_reason = f"Gen {tokenizer.decode([tokens[eod_token_idx]])!r}"
+            break
+
+    decoded = tokenizer.decode(tokens, errors=errors)
+
+    decode_tokens = tokenizer.decode(tokens[:eod_token_idx], errors=errors)
+    trim_decode_tokens = decode_tokens[raw_text_len:]
+    trim_decode_tokens = trim_decode_tokens.strip()
+
+    return decoded, trim_decode_tokens, end_reason
 
 
 class StopWordsLogitsProcessor(LogitsProcessor):
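A minimal usage sketch of the revised interface, assuming the model and tokenizer are loaded as in qwen/demo.py above (this snippet is illustrative and not part of the patch). chat() now takes an assistant-prefix string after the user query and returns three values instead of two:

# Hypothetical call, mirroring qwen/demo.py after this patch.
response, history, decoded = model.chat(
    tokenizer,
    "你好",   # query: the user turn
    "",       # query_assistant: optional text used to prefill the assistant turn
    history=None,
)
print(response)  # trimmed reply only, cut at the first <|im_start|>/<|im_end|> after the prompt
print(decoded)   # full decoded sequence, ChatML prompt included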
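For reference, the decode_tokens call site inside QWenLMHeadModel.chat now unpacks three values (arguments exactly as in the patch); end_reason is informational and is not returned by chat():

decoded, response, end_reason = decode_tokens(
    outputs[0],                          # prompt tokens followed by generated tokens
    tokenizer,
    raw_text_len=len(raw_text),          # number of prompt characters to strip from the decoded text
    context_length=len(context_tokens),  # index from which to scan for <|im_start|>/<|im_end|>
    errors="replace",
)
# end_reason is e.g. "Gen '<|im_end|>'" when a ChatML stop token was generated,
# or "Gen length N" when generation ran to the length limit.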