diff --git a/Readme.md b/Readme.md
index ba325c4..0742e56 100644
--- a/Readme.md
+++ b/Readme.md
@@ -11,7 +11,8 @@
 for:
     rotary_pos_emb -> [6, 1, 32, 2]  32: positional-encoding dim  2: cos+sin
     hidden_states = inputs_embeds
-    for layers : GLMBlock(hidden_states, rotary_pos_emb)
+    for layers :
+        GLMBlock(hidden_states, rotary_pos_emb)
     hidden_states = RMSNorm(hidden_states)  # final_layernorm -> [6, 1, 4096]
     hidden_states = hidden_states[-1:]  take the last sequence position -> [1, 1, 4096]
     lm_logits = Linear(hidden_states) -> [1, 1, 65024]
@@ -22,7 +23,7 @@
 for:
     if next_tokens == eos_token_id  inference is finished, exit the loop
-    input_ids = torch.cat([input_ids, next_tokens) -> [1, 7]  1: batch_num
+    input_ids = torch.cat([input_ids, next_tokens]) -> [1, 7]  1: batch_num
 response = tokenizer.decode(outputs)
diff --git a/chatglm/modeling_chatglm.py b/chatglm/modeling_chatglm.py
index 9729dd3..29e519c 100644
--- a/chatglm/modeling_chatglm.py
+++ b/chatglm/modeling_chatglm.py
@@ -436,6 +436,7 @@ class ChatGLMModel(nn.Module):
         input_ids,
         position_ids: Optional[torch.Tensor] = None,
         output_hidden_states: Optional[bool] = None,
+        tokenizer=None,
     ):
         output_hidden_states = (
             output_hidden_states
@@ -449,21 +450,16 @@ class ChatGLMModel(nn.Module):
         rotary_pos_emb = rotary_pos_emb[position_ids]
         rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
 
-        hidden_states = self.encoder(inputs_embeds, rotary_pos_emb)
-        hidden_states = hidden_states[-1:]
+        hidden_states_en = self.encoder(inputs_embeds, rotary_pos_emb)
+        hidden_states = hidden_states_en[-1:]
         lm_logits = self.output_layer(hidden_states)
-
-        # for i in range(16):
-        #     show.DumpTensorToImage(
-        #         self.output_layer.weight[
-        #             int(i * (65024 / 16)) : int((i + 1) * (65024 / 16)), :
-        #         ],
-        #         "generated/output_layer_weight_slice" + str(i) + ".png",
-        #     )
-
         lm_logits = lm_logits.transpose(0, 1).contiguous()
-        return lm_logits
+        next_token_logits = lm_logits[:, -1, :]
+        probs = nn.functional.softmax(next_token_logits, dim=-1)
+        next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+
+        return next_tokens
 
 
 class ChatGLMForConditionalGeneration(nn.Module):
@@ -573,7 +569,7 @@ class ChatGLMForConditionalGeneration(nn.Module):
             tokenizer,
         )
 
-        outputs = outputs.tolist()[0][len(inputs["input_ids"][0]) : -1]
+        outputs = outputs.tolist()[0][:]
         response = tokenizer.decode(outputs)
         history.append({"role": role, "content": query})
         return response, history
@@ -604,18 +600,11 @@ class ChatGLMForConditionalGeneration(nn.Module):
         )
 
         model_inputs = {"input_ids": input_ids_in, "position_ids": position_ids_in}
-        logits = self.transformer(
+        next_tokens = self.transformer(
             **model_inputs,
             output_hidden_states=output_hidden_states,
+            tokenizer=tokenizer,
         )
-        next_token_logits = logits[:, -1, :]
-        probs = nn.functional.softmax(next_token_logits, dim=-1)
-        next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
-
-        # response = tokenizer.decode(next_tokens)
-        # name = "generated/next_tokens" + str(token_count) + "_" + response + "_.png"
-        # show.DumpTensorToImage(next_token_logits[0], name)
-        # token_count = token_count + 1
 
         # finished sentences should add a padding token to next
         pad_token = pad_token_id * isFinished
diff --git a/demo.py b/demo.py
index 6279dd9..2efbe61 100644
--- a/demo.py
+++ b/demo.py
@@ -45,7 +45,7 @@
 glm = glm.eval()
 query = "你好"
 response, history = glm.chat(tokenizer, query, history=[])
 print(response)
-if response[1:] != " 你好!有什么可以帮助您的吗":
+if response.split("\n")[-1] != " 你好!有什么可以帮助您的吗?":
     raise ()
 # query = "colin"
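The core of this patch is moving next-token sampling out of the generation loop and into ChatGLMModel.forward, so the model now returns sampled token ids instead of raw logits. The sampling step itself can be exercised in isolation; below is a minimal, self-contained sketch with dummy logits (the shapes and names mirror the patch, nothing here is taken verbatim from the repo):

import torch
import torch.nn as nn

# Stand-in for lm_logits = self.output_layer(hidden_states) after the
# transpose: shape [batch=1, seq=1, vocab=65024], as in the patch.
lm_logits = torch.randn(1, 1, 65024)

next_token_logits = lm_logits[:, -1, :]                           # [1, 65024]
probs = nn.functional.softmax(next_token_logits, dim=-1)          # distribution over the vocab
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)  # one sampled id -> shape [1]

print(next_tokens)  # e.g. tensor([53042])

Because torch.multinomial samples from the distribution rather than taking an argmax, repeated runs over the same logits can return different ids; that is the intended sampling behavior here.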
diff --git a/test_tokenizer.py b/test_tokenizer.py
index b05fb3e..63f59a9 100644
--- a/test_tokenizer.py
+++ b/test_tokenizer.py
@@ -33,4 +33,20 @@
 for i in range(64798):
     token.append(str(i) + " : " + tokenizer.decode(i))
 show.DumpListToFile(token, "generated/token.log")
 
+# print("=======================")
+# for i in range(hidden_states_en.shape[0]):
+#     hidden_states = hidden_states_en[i : i + 1]
+#     lm_logits = self.output_layer(hidden_states)
+#     lm_logits = lm_logits.transpose(0, 1).contiguous()
+
+#     next_token_logits = lm_logits[:, -1, :]
+#     probss = nn.functional.softmax(next_token_logits, dim=-1)
+#     next_t = torch.multinomial(probss, num_samples=1).squeeze(1)
+
+#     response = tokenizer.decode(next_t)
+
+#     print(response)
+#     # name = "generated/next_tokens" + str(token_count) + "_" + response + "_.png"
+#     # show.DumpTensorToImage(next_token_logits[0], name)
+#     # token_count = token_count + 1
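For context, the outer decode loop sketched at the top of the Readme can be reproduced end to end with a stubbed model. A sketch under stated assumptions: fake_forward, the starting token ids, and eos_token_id = 2 are illustrative stand-ins rather than the repo's real values, and the [:, None] reshape is added so torch.cat sees matching ranks (the Readme pseudocode elides it):

import torch

eos_token_id = 2  # illustrative id, not ChatGLM's actual eos token

def fake_forward(input_ids):
    # Stub for ChatGLMModel.forward after this patch: returns one
    # sampled token id per batch element, shape [batch].
    probs = torch.softmax(torch.randn(input_ids.shape[0], 100), dim=-1)
    return torch.multinomial(probs, num_samples=1).squeeze(1)

input_ids = torch.tensor([[5, 6, 7, 8, 9, 10]])  # [1, 6], as in the Readme walkthrough
for _ in range(32):  # cap the loop in case eos never comes up
    next_tokens = fake_forward(input_ids)
    if next_tokens.item() == eos_token_id:
        break  # inference is finished, exit the loop
    # append the sampled token: [1, 6] -> [1, 7] -> ...  (1 is batch_num)
    input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)

print(input_ids.shape)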