Update readme.
This commit is contained in:
parent
9c19c9f285
commit
ebe48f8efc
Binary file not shown.
After Width: | Height: | Size: 2.4 KiB |
16
Readme.md
16
Readme.md
|
@ -12,16 +12,24 @@ for
|
|||
|
||||
hidden_states = inputs_embeds
|
||||
for layers : GLMBlock(hidden_states, rotary_pos_emb)
|
||||
hidden_states = self.final_layernorm(hidden_states)
|
||||
hidden_states = hidden_states[-1:]
|
||||
hidden_states = RMSNorm(hidden_states)
|
||||
hidden_states = hidden_states[-1:] 截取最后一个sequence
|
||||
lm_logits = self.output_layer(hidden_states)
|
||||
lm_logits = lm_logits.transpose(0, 1).contiguous() -> [1, 1, 65024]
|
||||
|
||||
probs = softmax(lm_logits) -> [1, 65024]
|
||||
next_tokens = torch.multinomial(probs, num_samples=1) 采样 -> [1] 1:batch_num
|
||||
next_tokens = torch.multinomial(probs, num_samples=1) 采样 -> [1] 1:batch_num
|
||||
|
||||
if next_tokens == eos_token_id 推理结束退出循环
|
||||
|
||||
input_ids = torch.cat([input_ids, next_tokens) -> [1, 7] 1:batch_num
|
||||
|
||||
response = tokenizer.decode(outputs)
|
||||
response = tokenizer.decode(outputs)
|
||||
|
||||
## RMSNorm
|
||||
|
||||
hidden_states -> [6, 1, 4096] 4096:hidden_size
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True) -> [6, 1, 1]
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.eps) 平方根倒数
|
||||
self.weight -> [4096]
|
||||
return (self.weight * hidden_states) -> [6, 1, 4096]
|
|
@ -68,6 +68,7 @@ class RMSNorm(torch.nn.Module):
|
|||
input_dtype = hidden_states.dtype
|
||||
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
|
||||
# show.DumpTensorToImage(self.weight, "RMSNorm_weight.png")
|
||||
return (self.weight * hidden_states).to(input_dtype)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue