Refine seed config.

This commit is contained in:
Colin 2025-03-14 17:38:24 +08:00
parent e3493163f3
commit 7faf629d45
3 changed files with 9 additions and 4 deletions

View File

@ -19,6 +19,8 @@ if __name__ == "__main__":
conf = qwen.config
torch.manual_seed(conf.seed)
np.random.seed(conf.seed)
torch.cuda.manual_seed_all(conf.seed)
runner = ModelRunner(qwen.llm)
# batch = torch.tensor([[11, 0, 3, 7, 15, 8, 10, 7, 14, 13, 1, 12, 13]], dtype=torch.int64)

View File

@ -172,16 +172,17 @@ class QWenLMHeadModel(nn.Module):
new_shape = attn_output.size()[:-2] + (block.attn.num_heads * block.attn.head_dim,)
context_layer = attn_output.view(new_shape)
attn_outputs = block.attn.c_proj(context_layer)
hidden_states = attn_outputs + hidden_states
# RMSNorm
layernorm_output = block.ln_2(hidden_states)
# mlp
layernorm_input = attn_outputs + hidden_states
layernorm_output = block.ln_2(layernorm_input)
a1 = block.mlp.w1(layernorm_output)
a2 = block.mlp.w2(layernorm_output)
intermediate_parallel = a1 * F.silu(a2)
mlp_output = block.mlp.c_proj(intermediate_parallel)
hidden_states = layernorm_input + mlp_output
hidden_states = hidden_states + mlp_output
hidden_states = transfm.ln_f(hidden_states)
hidden_states = hidden_states.view(output_shape)

View File

@ -43,6 +43,8 @@ if __name__ == "__main__":
torch.manual_seed(conf.seed)
np.random.seed(conf.seed)
torch.cuda.manual_seed_all(conf.seed)
model = QWenLMHeadModel(conf.model_config)
# model = RWKVLMHeadModel(conf.model_config)
qwen = LightModule(conf, model)