diff --git a/qwen/modeling_qwen.py b/qwen/modeling_qwen.py index 7d83af9..303f10e 100644 --- a/qwen/modeling_qwen.py +++ b/qwen/modeling_qwen.py @@ -204,7 +204,6 @@ class QwenRunner: pad_token_id = qwen.config.pad_token_id unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device) - this_peer_finished = False while True: outputs = self.forwardQWen(input_ids) next_token_scores = outputs[:, -1, :] @@ -218,11 +217,7 @@ class QwenRunner: unfinished_sequences = unfinished_sequences.mul( next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0) ) - if unfinished_sequences.max() == 0: - this_peer_finished = True - - if this_peer_finished: break decoded, response, end_reason = decode_tokens( @@ -233,7 +228,7 @@ class QwenRunner: errors="replace", ) history.append((query, response)) - return response, history, decoded + return input_ids[0].cpu().tolist(), history, decoded def _rotate_half(self, x): x = rearrange(x, "... (j d) -> ... j d", j=2) diff --git a/qwen/research_attention.py b/qwen/research_attention.py index be8ee7e..a4432fe 100644 --- a/qwen/research_attention.py +++ b/qwen/research_attention.py @@ -86,7 +86,6 @@ def DumpQK(query, key, causal_mask, index): class ResearchRunner(QwenRunner): def __init__(self, model): super().__init__(model) - self.tokenDecode = [] def attention(self, attention, query, key, value, causal_mask): query = query.permute(0, 2, 1, 3) @@ -100,20 +99,19 @@ class ResearchRunner(QwenRunner): attn_output = attention.c_proj(context_layer) return attn_output - def sample(self, next_token_scores): - next_tokens = super().sample(next_token_scores) - decoded, response, end_reason = decode_tokens( - next_tokens, - tokenizer, - ) - self.tokenDecode.append(decoded) - return next_tokens - runner = ResearchRunner(model) # 第一轮对话 -response, history, decode_tokens = runner.Chat(tokenizer, "东南亚国家日本的首都是什么市", "日本的首都是") -print(decode_tokens) +output_ids, history, decoded = runner.Chat(tokenizer, "东南亚国家日本的首都是什么市", "日本的首都是") +print(decoded) + +tokens = [] +for i, token in enumerate(output_ids): + de = tokenizer.decode([token]) + de = str(i).zfill(3) + " : " + repr(de) + tokens.append(de) + print(de) + # <|im_start|>system # You are a helpful assistant.<|im_end|> # <|im_start|>user @@ -122,8 +120,7 @@ print(decode_tokens) # 日本的首都东京。<|im_end|> # <|endoftext|> -show.DumpListToFile(runner.tokenDecode, "./temp/token_decode_list.txt") +show.DumpListToFile(tokens, "./temp/token_decode_list.txt") - -if decode_tokens.split("\n")[-2] != """日本的首都东京。<|im_end|>""": +if decoded.split("\n")[-2] != """日本的首都东京。<|im_end|>""": raise ()