Refine chat output format.
This commit is contained in:
parent
1b8007e1c3
commit
245d251663
47
qwen/demo.py
47
qwen/demo.py
|
@ -22,6 +22,34 @@ config, kwargs = AutoConfig.from_pretrained(
|
|||
)
|
||||
model = QWenLMHeadModel(config)
|
||||
|
||||
print(model)
|
||||
|
||||
# QWenLMHeadModel(
|
||||
# (transformer): QWenModel(
|
||||
# (wte): Embedding(151936, 2048)
|
||||
# (drop): Dropout(p=0.0, inplace=False)
|
||||
# (rotary_emb): RotaryEmbedding()
|
||||
# (h): ModuleList(
|
||||
# (0-23): 24 x QWenBlock(
|
||||
# (ln_1): RMSNorm()
|
||||
# (attn): QWenAttention(
|
||||
# (c_attn): Linear(in_features=2048, out_features=6144, bias=True)
|
||||
# (c_proj): Linear(in_features=2048, out_features=2048, bias=False)
|
||||
# (attn_dropout): Dropout(p=0.0, inplace=False)
|
||||
# )
|
||||
# (ln_2): RMSNorm()
|
||||
# (mlp): QWenMLP(
|
||||
# (w1): Linear(in_features=2048, out_features=5504, bias=False)
|
||||
# (w2): Linear(in_features=2048, out_features=5504, bias=False)
|
||||
# (c_proj): Linear(in_features=5504, out_features=2048, bias=False)
|
||||
# )
|
||||
# )
|
||||
# )
|
||||
# (ln_f): RMSNorm()
|
||||
# )
|
||||
# (lm_head): Linear(in_features=2048, out_features=151936, bias=False)
|
||||
# )
|
||||
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
|
||||
model = model.from_pretrained(
|
||||
|
@ -36,10 +64,21 @@ model = model.from_pretrained(
|
|||
# )
|
||||
|
||||
# 第一轮对话
|
||||
response, history = model.chat(tokenizer, "你好", history=None)
|
||||
print(response)
|
||||
response, history, decode_tokens = model.chat(tokenizer, "你好", "莎是现代汉语的男性的名字,出自《诗经》中的“采采卷", history=None)
|
||||
print(decode_tokens)
|
||||
# 你好!很高兴为你提供帮助。
|
||||
|
||||
# 第二轮对话
|
||||
response, history = model.chat(tokenizer, "给我讲一个年轻人奋斗创业最终取得成功的故事。", history=history)
|
||||
print(response)
|
||||
# response, history = model.chat(tokenizer, "给我讲一个年轻人奋斗创业最终取得成功的故事。", history=None)
|
||||
# print(response)
|
||||
|
||||
|
||||
|
||||
# <|im_start|>system
|
||||
# You are a helpful assistant.<|im_end|>
|
||||
# <|im_start|>user
|
||||
# 你好<|im_end|>
|
||||
# <|im_start|>assistant
|
||||
# 莎士比亚是头一个使用“你好”这个词的文学家,他在《哈姆雷特》中写道:“你是谁?你在哪儿?
|
||||
# ”他的这一段话,通常被认为是最早的使用“你好”这个词的文学记载。这句话在英国语中非常常见,
|
||||
# 特别是在正式或礼貌的情况下。<|im_end|><|endoftext|>
|
|
@ -413,6 +413,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
|||
self,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
query: str,
|
||||
query_assistant: str,
|
||||
history: Optional[HistoryType],
|
||||
system: str = "You are a helpful assistant.",
|
||||
stop_words_ids: Optional[List[List[int]]] = None,
|
||||
|
@ -435,13 +436,13 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
|||
raw_text, context_tokens = make_context(
|
||||
tokenizer,
|
||||
query,
|
||||
query_assistant,
|
||||
history=history,
|
||||
system=system,
|
||||
max_window_size=max_window_size,
|
||||
chat_format=generation_config.chat_format,
|
||||
max_window_size=max_window_size
|
||||
)
|
||||
|
||||
stop_words_ids.extend(get_stop_words_ids(generation_config.chat_format, tokenizer))
|
||||
stop_words_ids.extend(get_stop_words_ids(tokenizer))
|
||||
input_ids = torch.tensor([context_tokens]).to(self.device)
|
||||
outputs = self.generate(
|
||||
input_ids,
|
||||
|
@ -449,17 +450,15 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
|||
generation_config=generation_config,
|
||||
**kwargs,
|
||||
)
|
||||
response = decode_tokens(
|
||||
decoded, response, end_reason = decode_tokens(
|
||||
outputs[0],
|
||||
tokenizer,
|
||||
raw_text_len=len(raw_text),
|
||||
context_length=len(context_tokens),
|
||||
chat_format=generation_config.chat_format,
|
||||
verbose=False,
|
||||
errors="replace",
|
||||
)
|
||||
history.append((query, response))
|
||||
return response, history
|
||||
return response, history, decoded
|
||||
|
||||
def generate(
|
||||
self,
|
||||
|
|
|
@ -106,28 +106,22 @@ def get_batch(context_tokens: torch.LongTensor, eod_id: int):
|
|||
return tokens, attention_mask, position_ids
|
||||
|
||||
|
||||
def get_stop_words_ids(chat_format, tokenizer):
|
||||
if chat_format == "raw":
|
||||
stop_words_ids = [tokenizer.encode("Human:"), [tokenizer.eod_id]]
|
||||
elif chat_format == "chatml":
|
||||
def get_stop_words_ids(tokenizer):
|
||||
stop_words_ids = [[tokenizer.im_end_id], [tokenizer.im_start_id]]
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown chat format {chat_format!r}")
|
||||
return stop_words_ids
|
||||
|
||||
|
||||
def make_context(
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
query: str,
|
||||
query_assistant: str = "",
|
||||
history: List[Tuple[str, str]] = None,
|
||||
system: str = "",
|
||||
max_window_size: int = 6144,
|
||||
chat_format: str = "chatml",
|
||||
max_window_size: int = 6144
|
||||
):
|
||||
if history is None:
|
||||
history = []
|
||||
|
||||
if chat_format == "chatml":
|
||||
im_start, im_end = "<|im_start|>", "<|im_end|>"
|
||||
im_start_tokens = [tokenizer.im_start_id]
|
||||
im_end_tokens = [tokenizer.im_end_id]
|
||||
|
@ -140,7 +134,7 @@ def make_context(
|
|||
|
||||
system_text, system_tokens_part = _tokenize_str("system", system)
|
||||
system_tokens = im_start_tokens + system_tokens_part + im_end_tokens
|
||||
|
||||
assistant_tokens = tokenizer.encode(query_assistant, allowed_special=set())
|
||||
raw_text = ""
|
||||
context_tokens = []
|
||||
|
||||
|
@ -177,125 +171,36 @@ def make_context(
|
|||
+ im_start_tokens
|
||||
+ tokenizer.encode("assistant")
|
||||
+ nl_tokens
|
||||
+ assistant_tokens
|
||||
)
|
||||
raw_text += f"\n{im_start}user\n{query}{im_end}\n{im_start}assistant\n"
|
||||
|
||||
elif chat_format == "raw":
|
||||
raw_text = query
|
||||
context_tokens = tokenizer.encode(raw_text)
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown chat format {chat_format!r}")
|
||||
raw_text += f"\n{im_start}user\n{query}{im_end}\n{im_start}assistant\n{query_assistant}"
|
||||
|
||||
return raw_text, context_tokens
|
||||
|
||||
|
||||
def _decode_default(
|
||||
tokens: List[int],
|
||||
*,
|
||||
stop_words: List[str],
|
||||
eod_words: List[str],
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
raw_text_len: int,
|
||||
verbose: bool = False,
|
||||
return_end_reason: bool = False,
|
||||
errors: str='replace',
|
||||
):
|
||||
trim_decode_tokens = tokenizer.decode(tokens, errors=errors)[raw_text_len:]
|
||||
if verbose:
|
||||
print("\nRaw Generate: ", trim_decode_tokens)
|
||||
|
||||
end_reason = f"Gen length {len(tokens)}"
|
||||
for stop_word in stop_words:
|
||||
trim_decode_tokens = trim_decode_tokens.replace(stop_word, "").strip()
|
||||
for eod_word in eod_words:
|
||||
if eod_word in trim_decode_tokens:
|
||||
end_reason = f"Gen {eod_word!r}"
|
||||
trim_decode_tokens = trim_decode_tokens.split(eod_word)[0]
|
||||
trim_decode_tokens = trim_decode_tokens.strip()
|
||||
if verbose:
|
||||
print("\nEnd Reason:", end_reason)
|
||||
print("\nGenerate: ", trim_decode_tokens)
|
||||
|
||||
if return_end_reason:
|
||||
return trim_decode_tokens, end_reason
|
||||
else:
|
||||
return trim_decode_tokens
|
||||
|
||||
|
||||
def _decode_chatml(
|
||||
tokens: List[int],
|
||||
*,
|
||||
stop_words: List[str],
|
||||
eod_token_ids: List[int],
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
raw_text_len: int,
|
||||
context_length: int,
|
||||
verbose: bool = False,
|
||||
return_end_reason: bool = False,
|
||||
errors: str='replace'
|
||||
):
|
||||
end_reason = f"Gen length {len(tokens)}"
|
||||
eod_token_idx = context_length
|
||||
for eod_token_idx in range(context_length, len(tokens)):
|
||||
if tokens[eod_token_idx] in eod_token_ids:
|
||||
end_reason = f"Gen {tokenizer.decode([tokens[eod_token_idx]])!r}"
|
||||
break
|
||||
|
||||
trim_decode_tokens = tokenizer.decode(tokens[:eod_token_idx], errors=errors)[raw_text_len:]
|
||||
if verbose:
|
||||
print("\nRaw Generate w/o EOD:", tokenizer.decode(tokens, errors=errors)[raw_text_len:])
|
||||
print("\nRaw Generate:", trim_decode_tokens)
|
||||
print("\nEnd Reason:", end_reason)
|
||||
for stop_word in stop_words:
|
||||
trim_decode_tokens = trim_decode_tokens.replace(stop_word, "").strip()
|
||||
trim_decode_tokens = trim_decode_tokens.strip()
|
||||
if verbose:
|
||||
print("\nGenerate:", trim_decode_tokens)
|
||||
|
||||
if return_end_reason:
|
||||
return trim_decode_tokens, end_reason
|
||||
else:
|
||||
return trim_decode_tokens
|
||||
|
||||
|
||||
def decode_tokens(
|
||||
tokens: Union[torch.LongTensor, TokensType],
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
raw_text_len: int,
|
||||
context_length: int,
|
||||
chat_format: str,
|
||||
verbose: bool = False,
|
||||
return_end_reason: bool = False,
|
||||
errors: str="replace",
|
||||
) -> str:
|
||||
if torch.is_tensor(tokens):
|
||||
tokens = tokens.cpu().numpy().tolist()
|
||||
|
||||
if chat_format == "chatml":
|
||||
return _decode_chatml(
|
||||
tokens,
|
||||
stop_words=[],
|
||||
eod_token_ids=[tokenizer.im_start_id, tokenizer.im_end_id],
|
||||
tokenizer=tokenizer,
|
||||
raw_text_len=raw_text_len,
|
||||
context_length=context_length,
|
||||
verbose=verbose,
|
||||
return_end_reason=return_end_reason,
|
||||
errors=errors,
|
||||
)
|
||||
elif chat_format == "raw":
|
||||
return _decode_default(
|
||||
tokens,
|
||||
stop_words=["<|endoftext|>"],
|
||||
eod_words=["<|endoftext|>"],
|
||||
tokenizer=tokenizer,
|
||||
raw_text_len=raw_text_len,
|
||||
verbose=verbose,
|
||||
return_end_reason=return_end_reason,
|
||||
errors=errors,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown chat format {chat_format!r}")
|
||||
end_reason = f"Gen length {len(tokens)}"
|
||||
eod_token_idx = context_length
|
||||
for eod_token_idx in range(context_length, len(tokens)):
|
||||
if tokens[eod_token_idx] in [tokenizer.im_start_id, tokenizer.im_end_id]:
|
||||
end_reason = f"Gen {tokenizer.decode([tokens[eod_token_idx]])!r}"
|
||||
break
|
||||
|
||||
decoded = tokenizer.decode(tokens, errors=errors)
|
||||
|
||||
decode_tokens = tokenizer.decode(tokens[:eod_token_idx], errors=errors)
|
||||
trim_decode_tokens = decode_tokens[raw_text_len:]
|
||||
trim_decode_tokens = trim_decode_tokens.strip()
|
||||
|
||||
return decoded, trim_decode_tokens, end_reason
|
||||
|
||||
|
||||
class StopWordsLogitsProcessor(LogitsProcessor):
|
||||
|
|
Loading…
Reference in New Issue