Refine model of qwen.
This commit is contained in:
parent
11af10e710
commit
a4fafd460f
|
@ -204,7 +204,6 @@ class QwenRunner:
|
||||||
pad_token_id = qwen.config.pad_token_id
|
pad_token_id = qwen.config.pad_token_id
|
||||||
|
|
||||||
unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
|
unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
|
||||||
this_peer_finished = False
|
|
||||||
while True:
|
while True:
|
||||||
outputs = self.forwardQWen(input_ids)
|
outputs = self.forwardQWen(input_ids)
|
||||||
next_token_scores = outputs[:, -1, :]
|
next_token_scores = outputs[:, -1, :]
|
||||||
|
@ -218,11 +217,7 @@ class QwenRunner:
|
||||||
unfinished_sequences = unfinished_sequences.mul(
|
unfinished_sequences = unfinished_sequences.mul(
|
||||||
next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
|
next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
|
||||||
)
|
)
|
||||||
|
|
||||||
if unfinished_sequences.max() == 0:
|
if unfinished_sequences.max() == 0:
|
||||||
this_peer_finished = True
|
|
||||||
|
|
||||||
if this_peer_finished:
|
|
||||||
break
|
break
|
||||||
|
|
||||||
decoded, response, end_reason = decode_tokens(
|
decoded, response, end_reason = decode_tokens(
|
||||||
|
@ -233,7 +228,7 @@ class QwenRunner:
|
||||||
errors="replace",
|
errors="replace",
|
||||||
)
|
)
|
||||||
history.append((query, response))
|
history.append((query, response))
|
||||||
return response, history, decoded
|
return input_ids[0].cpu().tolist(), history, decoded
|
||||||
|
|
||||||
def _rotate_half(self, x):
|
def _rotate_half(self, x):
|
||||||
x = rearrange(x, "... (j d) -> ... j d", j=2)
|
x = rearrange(x, "... (j d) -> ... j d", j=2)
|
||||||
|
|
|
@ -86,7 +86,6 @@ def DumpQK(query, key, causal_mask, index):
|
||||||
class ResearchRunner(QwenRunner):
|
class ResearchRunner(QwenRunner):
|
||||||
def __init__(self, model):
|
def __init__(self, model):
|
||||||
super().__init__(model)
|
super().__init__(model)
|
||||||
self.tokenDecode = []
|
|
||||||
|
|
||||||
def attention(self, attention, query, key, value, causal_mask):
|
def attention(self, attention, query, key, value, causal_mask):
|
||||||
query = query.permute(0, 2, 1, 3)
|
query = query.permute(0, 2, 1, 3)
|
||||||
|
@ -100,20 +99,19 @@ class ResearchRunner(QwenRunner):
|
||||||
attn_output = attention.c_proj(context_layer)
|
attn_output = attention.c_proj(context_layer)
|
||||||
return attn_output
|
return attn_output
|
||||||
|
|
||||||
def sample(self, next_token_scores):
|
|
||||||
next_tokens = super().sample(next_token_scores)
|
|
||||||
decoded, response, end_reason = decode_tokens(
|
|
||||||
next_tokens,
|
|
||||||
tokenizer,
|
|
||||||
)
|
|
||||||
self.tokenDecode.append(decoded)
|
|
||||||
return next_tokens
|
|
||||||
|
|
||||||
runner = ResearchRunner(model)
|
runner = ResearchRunner(model)
|
||||||
|
|
||||||
# 第一轮对话
|
# 第一轮对话
|
||||||
response, history, decode_tokens = runner.Chat(tokenizer, "东南亚国家日本的首都是什么市", "日本的首都是")
|
output_ids, history, decoded = runner.Chat(tokenizer, "东南亚国家日本的首都是什么市", "日本的首都是")
|
||||||
print(decode_tokens)
|
print(decoded)
|
||||||
|
|
||||||
|
tokens = []
|
||||||
|
for i, token in enumerate(output_ids):
|
||||||
|
de = tokenizer.decode([token])
|
||||||
|
de = str(i).zfill(3) + " : " + repr(de)
|
||||||
|
tokens.append(de)
|
||||||
|
print(de)
|
||||||
|
|
||||||
# <|im_start|>system
|
# <|im_start|>system
|
||||||
# You are a helpful assistant.<|im_end|>
|
# You are a helpful assistant.<|im_end|>
|
||||||
# <|im_start|>user
|
# <|im_start|>user
|
||||||
|
@ -122,8 +120,7 @@ print(decode_tokens)
|
||||||
# 日本的首都东京。<|im_end|>
|
# 日本的首都东京。<|im_end|>
|
||||||
# <|endoftext|>
|
# <|endoftext|>
|
||||||
|
|
||||||
show.DumpListToFile(runner.tokenDecode, "./temp/token_decode_list.txt")
|
show.DumpListToFile(tokens, "./temp/token_decode_list.txt")
|
||||||
|
|
||||||
|
if decoded.split("\n")[-2] != """日本的首都东京。<|im_end|>""":
|
||||||
if decode_tokens.split("\n")[-2] != """日本的首都东京。<|im_end|>""":
|
|
||||||
raise ()
|
raise ()
|
||||||
|
|
Loading…
Reference in New Issue