Update readme.
This commit is contained in:
parent
9c19c9f285
commit
ebe48f8efc
Binary file not shown.
After Width: | Height: | Size: 2.4 KiB |
14
Readme.md
14
Readme.md
|
@ -12,16 +12,24 @@ for
|
||||||
|
|
||||||
hidden_states = inputs_embeds
|
hidden_states = inputs_embeds
|
||||||
for layers : GLMBlock(hidden_states, rotary_pos_emb)
|
for layers : GLMBlock(hidden_states, rotary_pos_emb)
|
||||||
hidden_states = self.final_layernorm(hidden_states)
|
hidden_states = RMSNorm(hidden_states)
|
||||||
hidden_states = hidden_states[-1:]
|
hidden_states = hidden_states[-1:] 截取最后一个sequence
|
||||||
lm_logits = self.output_layer(hidden_states)
|
lm_logits = self.output_layer(hidden_states)
|
||||||
lm_logits = lm_logits.transpose(0, 1).contiguous() -> [1, 1, 65024]
|
lm_logits = lm_logits.transpose(0, 1).contiguous() -> [1, 1, 65024]
|
||||||
|
|
||||||
probs = softmax(lm_logits) -> [1, 65024]
|
probs = softmax(lm_logits) -> [1, 65024]
|
||||||
next_tokens = torch.multinomial(probs, num_samples=1) 采样 -> [1] 1:batch_num
|
next_tokens = torch.multinomial(probs, num_samples=1) 采样 -> [1] 1:batch_num
|
||||||
|
|
||||||
if next_tokens == eos_token_id 推理结束退出循环
|
if next_tokens == eos_token_id 推理结束退出循环
|
||||||
|
|
||||||
input_ids = torch.cat([input_ids, next_tokens]) -> [1, 7] 1:batch_num
|
input_ids = torch.cat([input_ids, next_tokens]) -> [1, 7] 1:batch_num
|
||||||
|
|
||||||
response = tokenizer.decode(outputs)
|
response = tokenizer.decode(outputs)
|
||||||
|
|
||||||
|
## RMSNorm
|
||||||
|
|
||||||
|
hidden_states -> [6, 1, 4096] 4096:hidden_size
|
||||||
|
variance = hidden_states.pow(2).mean(-1, keepdim=True) -> [6, 1, 1]
|
||||||
|
hidden_states = hidden_states * torch.rsqrt(variance + self.eps) 平方根倒数
|
||||||
|
self.weight -> [4096]
|
||||||
|
return (self.weight * hidden_states) -> [6, 1, 4096]
|
|
@ -68,6 +68,7 @@ class RMSNorm(torch.nn.Module):
|
||||||
input_dtype = hidden_states.dtype
|
input_dtype = hidden_states.dtype
|
||||||
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
||||||
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
|
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
|
||||||
|
# show.DumpTensorToImage(self.weight, "RMSNorm_weight.png")
|
||||||
return (self.weight * hidden_states).to(input_dtype)
|
return (self.weight * hidden_states).to(input_dtype)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -29,3 +29,8 @@ print(x.prod(0))
|
||||||
print()
|
print()
|
||||||
print(x.unsqueeze(1).shape)
|
print(x.unsqueeze(1).shape)
|
||||||
print(x.unsqueeze(1).squeeze(1).shape)
|
print(x.unsqueeze(1).squeeze(1).shape)
|
||||||
|
|
||||||
|
x = torch.tensor([[1, 2], [3, 4]]).to(float)
|
||||||
|
print(x.mean(1))
|
||||||
|
print(x.mean(0))
|
||||||
|
print(x.mean(0, keepdim=True))
|
||||||
|
|
Loading…
Reference in New Issue