Update readme.

This commit is contained in:
Colin 2023-12-22 20:01:09 +08:00
parent 9c19c9f285
commit ebe48f8efc
4 changed files with 18 additions and 4 deletions

BIN
RMSNorm_weight.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.4 KiB

View File

@@ -12,8 +12,8 @@ for
hidden_states = inputs_embeds hidden_states = inputs_embeds
for layers : GLMBlock(hidden_states, rotary_pos_emb) for layers : GLMBlock(hidden_states, rotary_pos_emb)
hidden_states = self.final_layernorm(hidden_states) hidden_states = RMSNorm(hidden_states)
hidden_states = hidden_states[-1:] hidden_states = hidden_states[-1:] 截取最后一个sequence
lm_logits = self.output_layer(hidden_states) lm_logits = self.output_layer(hidden_states)
lm_logits = lm_logits.transpose(0, 1).contiguous() -> [1, 1, 65024] lm_logits = lm_logits.transpose(0, 1).contiguous() -> [1, 1, 65024]
@@ -25,3 +25,11 @@ for
input_ids = torch.cat([input_ids, next_tokens]) -> [1, 7] 1:batch_num input_ids = torch.cat([input_ids, next_tokens]) -> [1, 7] 1:batch_num
response = tokenizer.decode(outputs) response = tokenizer.decode(outputs)
## RMSNorm
hidden_states -> [6, 1, 4096] 4096:hidden_size
variance = hidden_states.pow(2).mean(-1, keepdim=True) -> [6, 1, 1]
hidden_states = hidden_states * torch.rsqrt(variance + self.eps) 平方根倒数
self.weight -> [4096]
return (self.weight * hidden_states) -> [6, 1, 4096]

View File

@@ -68,6 +68,7 @@ class RMSNorm(torch.nn.Module):
input_dtype = hidden_states.dtype input_dtype = hidden_states.dtype
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.eps) hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
# show.DumpTensorToImage(self.weight, "RMSNorm_weight.png")
return (self.weight * hidden_states).to(input_dtype) return (self.weight * hidden_states).to(input_dtype)

View File

@@ -29,3 +29,8 @@ print(x.prod(0))
print() print()
print(x.unsqueeze(1).shape) print(x.unsqueeze(1).shape)
print(x.unsqueeze(1).squeeze(1).shape) print(x.unsqueeze(1).squeeze(1).shape)
x = torch.tensor([[1, 2], [3, 4]]).to(float)
print(x.mean(1))
print(x.mean(0))
print(x.mean(0, keepdim=True))