Refine seed config.
This commit is contained in:
parent
e3493163f3
commit
7faf629d45
|
@ -19,6 +19,8 @@ if __name__ == "__main__":
|
|||
conf = qwen.config
|
||||
torch.manual_seed(conf.seed)
|
||||
np.random.seed(conf.seed)
|
||||
torch.cuda.manual_seed_all(conf.seed)
|
||||
|
||||
runner = ModelRunner(qwen.llm)
|
||||
|
||||
# batch = torch.tensor([[11, 0, 3, 7, 15, 8, 10, 7, 14, 13, 1, 12, 13]], dtype=torch.int64)
|
||||
|
|
|
@ -172,16 +172,17 @@ class QWenLMHeadModel(nn.Module):
|
|||
new_shape = attn_output.size()[:-2] + (block.attn.num_heads * block.attn.head_dim,)
|
||||
context_layer = attn_output.view(new_shape)
|
||||
attn_outputs = block.attn.c_proj(context_layer)
|
||||
hidden_states = attn_outputs + hidden_states
|
||||
|
||||
# RMSNorm
|
||||
layernorm_output = block.ln_2(hidden_states)
|
||||
|
||||
# mlp
|
||||
layernorm_input = attn_outputs + hidden_states
|
||||
layernorm_output = block.ln_2(layernorm_input)
|
||||
a1 = block.mlp.w1(layernorm_output)
|
||||
a2 = block.mlp.w2(layernorm_output)
|
||||
intermediate_parallel = a1 * F.silu(a2)
|
||||
mlp_output = block.mlp.c_proj(intermediate_parallel)
|
||||
|
||||
hidden_states = layernorm_input + mlp_output
|
||||
hidden_states = hidden_states + mlp_output
|
||||
|
||||
hidden_states = transfm.ln_f(hidden_states)
|
||||
hidden_states = hidden_states.view(output_shape)
|
||||
|
|
|
@ -43,6 +43,8 @@ if __name__ == "__main__":
|
|||
|
||||
torch.manual_seed(conf.seed)
|
||||
np.random.seed(conf.seed)
|
||||
torch.cuda.manual_seed_all(conf.seed)
|
||||
|
||||
model = QWenLMHeadModel(conf.model_config)
|
||||
# model = RWKVLMHeadModel(conf.model_config)
|
||||
qwen = LightModule(conf, model)
|
||||
|
|
Loading…
Reference in New Issue