Refine model of qwen.

This commit is contained in:
Colin 2024-01-24 21:22:03 +08:00
parent 11af10e710
commit 19491d1f4a
2 changed files with 14 additions and 22 deletions

View File

@ -204,7 +204,6 @@ class QwenRunner:
pad_token_id = qwen.config.pad_token_id pad_token_id = qwen.config.pad_token_id
unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device) unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
this_peer_finished = False
while True: while True:
outputs = self.forwardQWen(input_ids) outputs = self.forwardQWen(input_ids)
next_token_scores = outputs[:, -1, :] next_token_scores = outputs[:, -1, :]
@ -218,11 +217,7 @@ class QwenRunner:
unfinished_sequences = unfinished_sequences.mul( unfinished_sequences = unfinished_sequences.mul(
next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0) next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
) )
if unfinished_sequences.max() == 0: if unfinished_sequences.max() == 0:
this_peer_finished = True
if this_peer_finished:
break break
decoded, response, end_reason = decode_tokens( decoded, response, end_reason = decode_tokens(
@ -233,7 +228,7 @@ class QwenRunner:
errors="replace", errors="replace",
) )
history.append((query, response)) history.append((query, response))
return response, history, decoded return input_ids[0].cpu().tolist(), history, decoded
def _rotate_half(self, x): def _rotate_half(self, x):
x = rearrange(x, "... (j d) -> ... j d", j=2) x = rearrange(x, "... (j d) -> ... j d", j=2)

View File

@ -52,7 +52,7 @@ def Dump_tokens_list(model):
context_length=0, context_length=0,
errors="replace", errors="replace",
) )
tokens.append(str(token) + ": " + decoded) tokens.append(str(token).zfill(7) + ": " + repr(decoded))
show.DumpListToFile(tokens, "./temp/qwen_token_list.txt") show.DumpListToFile(tokens, "./temp/qwen_token_list.txt")
@ -86,7 +86,6 @@ def DumpQK(query, key, causal_mask, index):
class ResearchRunner(QwenRunner): class ResearchRunner(QwenRunner):
def __init__(self, model): def __init__(self, model):
super().__init__(model) super().__init__(model)
self.tokenDecode = []
def attention(self, attention, query, key, value, causal_mask): def attention(self, attention, query, key, value, causal_mask):
query = query.permute(0, 2, 1, 3) query = query.permute(0, 2, 1, 3)
@ -100,20 +99,19 @@ class ResearchRunner(QwenRunner):
attn_output = attention.c_proj(context_layer) attn_output = attention.c_proj(context_layer)
return attn_output return attn_output
def sample(self, next_token_scores):
next_tokens = super().sample(next_token_scores)
decoded, response, end_reason = decode_tokens(
next_tokens,
tokenizer,
)
self.tokenDecode.append(decoded)
return next_tokens
runner = ResearchRunner(model) runner = ResearchRunner(model)
# 第一轮对话 # 第一轮对话
response, history, decode_tokens = runner.Chat(tokenizer, "东南亚国家日本的首都是什么市", "日本的首都是") output_ids, history, decoded = runner.Chat(tokenizer, "东南亚国家日本的首都是什么市", "日本的首都是")
print(decode_tokens) print(decoded)
tokens = []
for i, token in enumerate(output_ids):
de = tokenizer.decode([token])
de = str(i).zfill(3) + " : " + repr(de)
tokens.append(de)
print(de)
# <|im_start|>system # <|im_start|>system
# You are a helpful assistant.<|im_end|> # You are a helpful assistant.<|im_end|>
# <|im_start|>user # <|im_start|>user
@ -122,8 +120,7 @@ print(decode_tokens)
# 日本的首都东京。<|im_end|> # 日本的首都东京。<|im_end|>
# <|endoftext|> # <|endoftext|>
show.DumpListToFile(runner.tokenDecode, "./temp/token_decode_list.txt") show.DumpListToFile(tokens, "./temp/token_decode_list.txt")
if decoded.split("\n")[-2] != """日本的首都东京。<|im_end|>""":
if decode_tokens.split("\n")[-2] != """日本的首都东京。<|im_end|>""":
raise () raise ()