This commit is contained in:
Colin 2023-12-27 19:58:52 +08:00
parent 0cee40dbb0
commit b3df9e423c
4 changed files with 31 additions and 25 deletions

View File

@ -11,7 +11,8 @@ for:
rotary_pos_emb -> [6, 1, 32, 2] 32:pos的编码维度 2:cos+sin
hidden_states = inputs_embeds
for layers : GLMBlock(hidden_states, rotary_pos_emb)
for layers :
GLMBlock(hidden_states, rotary_pos_emb)
hidden_states = RMSNorm(hidden_states) # final_layernorm -> [6, 1, 4096]
hidden_states = hidden_states[-1:] 截取最后一个sequence -> [1, 1, 4096]
lm_logits = Linear(hidden_states) -> [1, 1, 65024]
@ -22,7 +23,7 @@ for:
if next_tokens == eos_token_id 推理结束退出循环
input_ids = torch.cat([input_ids, next_tokens) -> [1, 7] 1:batch_num
input_ids = torch.cat([input_ids, next_tokens]) -> [1, 7] 1:batch_num
response = tokenizer.decode(outputs)

View File

@ -436,6 +436,7 @@ class ChatGLMModel(nn.Module):
input_ids,
position_ids: Optional[torch.Tensor] = None,
output_hidden_states: Optional[bool] = None,
tokenizer=None,
):
output_hidden_states = (
output_hidden_states
@ -449,21 +450,16 @@ class ChatGLMModel(nn.Module):
rotary_pos_emb = rotary_pos_emb[position_ids]
rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
hidden_states = self.encoder(inputs_embeds, rotary_pos_emb)
hidden_states = hidden_states[-1:]
hidden_states_en = self.encoder(inputs_embeds, rotary_pos_emb)
hidden_states = hidden_states_en[-1:]
lm_logits = self.output_layer(hidden_states)
# for i in range(16):
# show.DumpTensorToImage(
# self.output_layer.weight[
# int(i * (65024 / 16)) : int((i + 1) * (65024 / 16)), :
# ],
# "generated/output_layer_weight_slice" + str(i) + ".png",
# )
lm_logits = lm_logits.transpose(0, 1).contiguous()
return lm_logits
next_token_logits = lm_logits[:, -1, :]
probs = nn.functional.softmax(next_token_logits, dim=-1)
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
return next_tokens
class ChatGLMForConditionalGeneration(nn.Module):
@ -573,7 +569,7 @@ class ChatGLMForConditionalGeneration(nn.Module):
tokenizer,
)
outputs = outputs.tolist()[0][len(inputs["input_ids"][0]) : -1]
outputs = outputs.tolist()[0][:]
response = tokenizer.decode(outputs)
history.append({"role": role, "content": query})
return response, history
@ -604,18 +600,11 @@ class ChatGLMForConditionalGeneration(nn.Module):
)
model_inputs = {"input_ids": input_ids_in, "position_ids": position_ids_in}
logits = self.transformer(
next_tokens = self.transformer(
**model_inputs,
output_hidden_states=output_hidden_states,
tokenizer=tokenizer,
)
next_token_logits = logits[:, -1, :]
probs = nn.functional.softmax(next_token_logits, dim=-1)
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
# response = tokenizer.decode(next_tokens)
# name = "generated/next_tokens" + str(token_count) + "_" + response + "_.png"
# show.DumpTensorToImage(next_token_logits[0], name)
# token_count = token_count + 1
# finished sentences should add a padding token to next
pad_token = pad_token_id * isFinished

View File

@ -45,7 +45,7 @@ glm = glm.eval()
query = "你好"
response, history = glm.chat(tokenizer, query, history=[])
print(response)
if response[1:] != " 你好!有什么可以帮助您的吗":
if response.split("\n")[-1] != " 你好!有什么可以帮助您的吗":
raise ()
# query = "colin"

View File

@ -33,4 +33,20 @@ for i in range(64798):
token.append(str(i) + " : " + tokenizer.decode(i))
show.DumpListToFile(token, "generated/token.log")
# print("=======================")
# for i in range(hidden_states_en.shape[0]):
# hidden_states = hidden_states_en[i : i + 1]
# lm_logits = self.output_layer(hidden_states)
# lm_logits = lm_logits.transpose(0, 1).contiguous()
# next_token_logits = lm_logits[:, -1, :]
# probss = nn.functional.softmax(next_token_logits, dim=-1)
# next_t = torch.multinomial(probss, num_samples=1).squeeze(1)
# response = tokenizer.decode(next_t)
# print(response)
# # name = "generated/next_tokens" + str(token_count) + "_" + response + "_.png"
# # show.DumpTensorToImage(next_token_logits[0], name)
# # token_count = token_count + 1