Update readme.
This commit is contained in:
		
							parent
							
								
									9c19c9f285
								
							
						
					
					
						commit
						ebe48f8efc
					
				
										
											Binary file not shown.
										
									
								
							| 
		 After Width: | Height: | Size: 2.4 KiB  | 
							
								
								
									
										16
									
								
								Readme.md
								
								
								
								
							
							
						
						
									
										16
									
								
								Readme.md
								
								
								
								
							| 
						 | 
				
			
			@ -12,16 +12,24 @@ for
 | 
			
		|||
 | 
			
		||||
  hidden_states = inputs_embeds
 | 
			
		||||
  for layers : GLMBlock(hidden_states, rotary_pos_emb)
 | 
			
		||||
  hidden_states = self.final_layernorm(hidden_states)
 | 
			
		||||
  hidden_states = hidden_states[-1:]
 | 
			
		||||
  hidden_states = RMSNorm(hidden_states)
 | 
			
		||||
  hidden_states = hidden_states[-1:] 截取最后一个sequence
 | 
			
		||||
  lm_logits = self.output_layer(hidden_states)
 | 
			
		||||
  lm_logits = lm_logits.transpose(0, 1).contiguous()  -> [1, 1, 65024]
 | 
			
		||||
 | 
			
		||||
  probs = softmax(lm_logits) -> [1, 65024]
 | 
			
		||||
  next_tokens = torch.multinomial(probs, num_samples=1) 采样  -> [1]  1:batch_num 
 | 
			
		||||
  next_tokens = torch.multinomial(probs, num_samples=1)  采样  -> [1]  1:batch_num 
 | 
			
		||||
 | 
			
		||||
  if next_tokens == eos_token_id 推理结束退出循环
 | 
			
		||||
 | 
			
		||||
  input_ids = torch.cat([input_ids, next_tokens)  -> [1, 7]  1:batch_num
 | 
			
		||||
 | 
			
		||||
response = tokenizer.decode(outputs)
 | 
			
		||||
response = tokenizer.decode(outputs)
 | 
			
		||||
 | 
			
		||||
## RMSNorm
 | 
			
		||||
 | 
			
		||||
hidden_states -> [6, 1, 4096]  4096:hidden_size
 | 
			
		||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)  -> [6, 1, 1]
 | 
			
		||||
hidden_states = hidden_states * torch.rsqrt(variance + self.eps) 平方根倒数
 | 
			
		||||
self.weight -> [4096]
 | 
			
		||||
return (self.weight * hidden_states)  -> [6, 1, 4096]
 | 
			
		||||
| 
						 | 
				
			
			@ -68,6 +68,7 @@ class RMSNorm(torch.nn.Module):
 | 
			
		|||
        input_dtype = hidden_states.dtype
 | 
			
		||||
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
 | 
			
		||||
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
 | 
			
		||||
        # show.DumpTensorToImage(self.weight, "RMSNorm_weight.png")
 | 
			
		||||
        return (self.weight * hidden_states).to(input_dtype)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue