This commit is contained in:
Colin 2023-12-27 19:58:52 +08:00
parent 0cee40dbb0
commit b3df9e423c
4 changed files with 31 additions and 25 deletions

View File

@ -11,7 +11,8 @@ for:
rotary_pos_emb -> [6, 1, 32, 2] 32:pos的编码维度 2:cos+sin rotary_pos_emb -> [6, 1, 32, 2] 32:pos的编码维度 2:cos+sin
hidden_states = inputs_embeds hidden_states = inputs_embeds
for layers : GLMBlock(hidden_states, rotary_pos_emb) for layers :
GLMBlock(hidden_states, rotary_pos_emb)
hidden_states = RMSNorm(hidden_states) # final_layernorm -> [6, 1, 4096] hidden_states = RMSNorm(hidden_states) # final_layernorm -> [6, 1, 4096]
hidden_states = hidden_states[-1:] 截取最后一个sequence -> [1, 1, 4096] hidden_states = hidden_states[-1:] 截取最后一个sequence -> [1, 1, 4096]
lm_logits = Linear(hidden_states) -> [1, 1, 65024] lm_logits = Linear(hidden_states) -> [1, 1, 65024]
@ -22,7 +23,7 @@ for:
if next_tokens == eos_token_id 推理结束退出循环 if next_tokens == eos_token_id 推理结束退出循环
input_ids = torch.cat([input_ids, next_tokens) -> [1, 7] 1:batch_num input_ids = torch.cat([input_ids, next_tokens]) -> [1, 7] 1:batch_num
response = tokenizer.decode(outputs) response = tokenizer.decode(outputs)

View File

@ -436,6 +436,7 @@ class ChatGLMModel(nn.Module):
input_ids, input_ids,
position_ids: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
tokenizer=None,
): ):
output_hidden_states = ( output_hidden_states = (
output_hidden_states output_hidden_states
@ -449,21 +450,16 @@ class ChatGLMModel(nn.Module):
rotary_pos_emb = rotary_pos_emb[position_ids] rotary_pos_emb = rotary_pos_emb[position_ids]
rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous() rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
hidden_states = self.encoder(inputs_embeds, rotary_pos_emb) hidden_states_en = self.encoder(inputs_embeds, rotary_pos_emb)
hidden_states = hidden_states[-1:] hidden_states = hidden_states_en[-1:]
lm_logits = self.output_layer(hidden_states) lm_logits = self.output_layer(hidden_states)
# for i in range(16):
# show.DumpTensorToImage(
# self.output_layer.weight[
# int(i * (65024 / 16)) : int((i + 1) * (65024 / 16)), :
# ],
# "generated/output_layer_weight_slice" + str(i) + ".png",
# )
lm_logits = lm_logits.transpose(0, 1).contiguous() lm_logits = lm_logits.transpose(0, 1).contiguous()
return lm_logits next_token_logits = lm_logits[:, -1, :]
probs = nn.functional.softmax(next_token_logits, dim=-1)
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
return next_tokens
class ChatGLMForConditionalGeneration(nn.Module): class ChatGLMForConditionalGeneration(nn.Module):
@ -573,7 +569,7 @@ class ChatGLMForConditionalGeneration(nn.Module):
tokenizer, tokenizer,
) )
outputs = outputs.tolist()[0][len(inputs["input_ids"][0]) : -1] outputs = outputs.tolist()[0][:]
response = tokenizer.decode(outputs) response = tokenizer.decode(outputs)
history.append({"role": role, "content": query}) history.append({"role": role, "content": query})
return response, history return response, history
@ -604,18 +600,11 @@ class ChatGLMForConditionalGeneration(nn.Module):
) )
model_inputs = {"input_ids": input_ids_in, "position_ids": position_ids_in} model_inputs = {"input_ids": input_ids_in, "position_ids": position_ids_in}
logits = self.transformer( next_tokens = self.transformer(
**model_inputs, **model_inputs,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
tokenizer=tokenizer,
) )
next_token_logits = logits[:, -1, :]
probs = nn.functional.softmax(next_token_logits, dim=-1)
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
# response = tokenizer.decode(next_tokens)
# name = "generated/next_tokens" + str(token_count) + "_" + response + "_.png"
# show.DumpTensorToImage(next_token_logits[0], name)
# token_count = token_count + 1
# finished sentences should add a padding token to next # finished sentences should add a padding token to next
pad_token = pad_token_id * isFinished pad_token = pad_token_id * isFinished

View File

@ -45,7 +45,7 @@ glm = glm.eval()
query = "你好" query = "你好"
response, history = glm.chat(tokenizer, query, history=[]) response, history = glm.chat(tokenizer, query, history=[])
print(response) print(response)
if response[1:] != " 你好!有什么可以帮助您的吗": if response.split("\n")[-1] != " 你好!有什么可以帮助您的吗":
raise () raise ()
# query = "colin" # query = "colin"

View File

@ -33,4 +33,20 @@ for i in range(64798):
token.append(str(i) + " : " + tokenizer.decode(i)) token.append(str(i) + " : " + tokenizer.decode(i))
show.DumpListToFile(token, "generated/token.log") show.DumpListToFile(token, "generated/token.log")
# print("=======================")
# for i in range(hidden_states_en.shape[0]):
# hidden_states = hidden_states_en[i : i + 1]
# lm_logits = self.output_layer(hidden_states)
# lm_logits = lm_logits.transpose(0, 1).contiguous()
# next_token_logits = lm_logits[:, -1, :]
# probss = nn.functional.softmax(next_token_logits, dim=-1)
# next_t = torch.multinomial(probss, num_samples=1).squeeze(1)
# response = tokenizer.decode(next_t)
# print(response)
# # name = "generated/next_tokens" + str(token_count) + "_" + response + "_.png"
# # show.DumpTensorToImage(next_token_logits[0], name)
# # token_count = token_count + 1