Update readme.
This commit is contained in:
		
							parent
							
								
									9c19c9f285
								
							
						
					
					
						commit
						ebe48f8efc
					
				
										
											Binary file not shown.
										
									
								
							| 
		 After Width: | Height: | Size: 2.4 KiB  | 
							
								
								
									
										14
									
								
								Readme.md
								
								
								
								
							
							
						
						
									
										14
									
								
								Readme.md
								
								
								
								
							| 
						 | 
					@ -12,16 +12,24 @@ for
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  hidden_states = inputs_embeds
 | 
					  hidden_states = inputs_embeds
 | 
				
			||||||
  for layers : GLMBlock(hidden_states, rotary_pos_emb)
 | 
					  for layers : GLMBlock(hidden_states, rotary_pos_emb)
 | 
				
			||||||
  hidden_states = self.final_layernorm(hidden_states)
 | 
					  hidden_states = RMSNorm(hidden_states)
 | 
				
			||||||
  hidden_states = hidden_states[-1:]
 | 
					  hidden_states = hidden_states[-1:] 截取最后一个sequence
 | 
				
			||||||
  lm_logits = self.output_layer(hidden_states)
 | 
					  lm_logits = self.output_layer(hidden_states)
 | 
				
			||||||
  lm_logits = lm_logits.transpose(0, 1).contiguous()  -> [1, 1, 65024]
 | 
					  lm_logits = lm_logits.transpose(0, 1).contiguous()  -> [1, 1, 65024]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  probs = softmax(lm_logits) -> [1, 65024]
 | 
					  probs = softmax(lm_logits) -> [1, 65024]
 | 
				
			||||||
  next_tokens = torch.multinomial(probs, num_samples=1) 采样  -> [1]  1:batch_num 
 | 
					  next_tokens = torch.multinomial(probs, num_samples=1)  采样  -> [1]  1:batch_num 
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  if next_tokens == eos_token_id 推理结束退出循环
 | 
					  if next_tokens == eos_token_id 推理结束退出循环
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  input_ids = torch.cat([input_ids, next_tokens])  -> [1, 7]  1:batch_num
 | 
					  input_ids = torch.cat([input_ids, next_tokens])  -> [1, 7]  1:batch_num
 | 
				
			||||||
 | 
					
 | 
				
			||||||
response = tokenizer.decode(outputs)
 | 
					response = tokenizer.decode(outputs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					## RMSNorm
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					hidden_states -> [6, 1, 4096]  4096:hidden_size
 | 
				
			||||||
 | 
					variance = hidden_states.pow(2).mean(-1, keepdim=True)  -> [6, 1, 1]
 | 
				
			||||||
 | 
					hidden_states = hidden_states * torch.rsqrt(variance + self.eps) 平方根倒数
 | 
				
			||||||
 | 
					self.weight -> [4096]
 | 
				
			||||||
 | 
					return (self.weight * hidden_states)  -> [6, 1, 4096]
 | 
				
			||||||
| 
						 | 
					@ -68,6 +68,7 @@ class RMSNorm(torch.nn.Module):
 | 
				
			||||||
        input_dtype = hidden_states.dtype
 | 
					        input_dtype = hidden_states.dtype
 | 
				
			||||||
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
 | 
					        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
 | 
				
			||||||
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
 | 
					        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
 | 
				
			||||||
 | 
					        # show.DumpTensorToImage(self.weight, "RMSNorm_weight.png")
 | 
				
			||||||
        return (self.weight * hidden_states).to(input_dtype)
 | 
					        return (self.weight * hidden_states).to(input_dtype)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -29,3 +29,8 @@ print(x.prod(0))
 | 
				
			||||||
print()
 | 
					print()
 | 
				
			||||||
print(x.unsqueeze(1).shape)
 | 
					print(x.unsqueeze(1).shape)
 | 
				
			||||||
print(x.unsqueeze(1).squeeze(1).shape)
 | 
					print(x.unsqueeze(1).squeeze(1).shape)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					x = torch.tensor([[1, 2], [3, 4]]).to(float)
 | 
				
			||||||
 | 
					print(x.mean(1))
 | 
				
			||||||
 | 
					print(x.mean(0))
 | 
				
			||||||
 | 
					print(x.mean(0, keepdim=True))
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue