Refine seed config.
This commit is contained in:
		
							parent
							
								
									e3493163f3
								
							
						
					
					
						commit
						7faf629d45
					
				| 
						 | 
				
			
			@ -19,6 +19,8 @@ if __name__ == "__main__":
 | 
			
		|||
    conf = qwen.config
 | 
			
		||||
    torch.manual_seed(conf.seed)
 | 
			
		||||
    np.random.seed(conf.seed)
 | 
			
		||||
    torch.cuda.manual_seed_all(conf.seed)
 | 
			
		||||
 | 
			
		||||
    runner = ModelRunner(qwen.llm)
 | 
			
		||||
 | 
			
		||||
    # batch = torch.tensor([[11, 0, 3, 7, 15, 8, 10, 7, 14, 13, 1, 12, 13]], dtype=torch.int64)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -172,16 +172,17 @@ class QWenLMHeadModel(nn.Module):
 | 
			
		|||
            new_shape = attn_output.size()[:-2] + (block.attn.num_heads * block.attn.head_dim,)
 | 
			
		||||
            context_layer = attn_output.view(new_shape)
 | 
			
		||||
            attn_outputs = block.attn.c_proj(context_layer)
 | 
			
		||||
            hidden_states = attn_outputs + hidden_states
 | 
			
		||||
 | 
			
		||||
            # RMSNorm
 | 
			
		||||
            layernorm_output = block.ln_2(hidden_states)
 | 
			
		||||
 | 
			
		||||
            # mlp
 | 
			
		||||
            layernorm_input = attn_outputs + hidden_states
 | 
			
		||||
            layernorm_output = block.ln_2(layernorm_input)
 | 
			
		||||
            a1 = block.mlp.w1(layernorm_output)
 | 
			
		||||
            a2 = block.mlp.w2(layernorm_output)
 | 
			
		||||
            intermediate_parallel = a1 * F.silu(a2)
 | 
			
		||||
            mlp_output = block.mlp.c_proj(intermediate_parallel)
 | 
			
		||||
 | 
			
		||||
            hidden_states = layernorm_input + mlp_output
 | 
			
		||||
            hidden_states = hidden_states + mlp_output
 | 
			
		||||
 | 
			
		||||
        hidden_states = transfm.ln_f(hidden_states)
 | 
			
		||||
        hidden_states = hidden_states.view(output_shape)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -43,6 +43,8 @@ if __name__ == "__main__":
 | 
			
		|||
 | 
			
		||||
    torch.manual_seed(conf.seed)
 | 
			
		||||
    np.random.seed(conf.seed)
 | 
			
		||||
    torch.cuda.manual_seed_all(conf.seed)
 | 
			
		||||
 | 
			
		||||
    model = QWenLMHeadModel(conf.model_config)
 | 
			
		||||
    # model = RWKVLMHeadModel(conf.model_config)
 | 
			
		||||
    qwen = LightModule(conf, model)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue