Refine seed config.
parent e3493163f3
commit 7faf629d45
@@ -19,6 +19,8 @@ if __name__ == "__main__":
     conf = qwen.config
     torch.manual_seed(conf.seed)
     np.random.seed(conf.seed)
+    torch.cuda.manual_seed_all(conf.seed)
+
     runner = ModelRunner(qwen.llm)
 
     # batch = torch.tensor([[11, 0, 3, 7, 15, 8, 10, 7, 14, 13, 1, 12, 13]], dtype=torch.int64)
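The same three seeding calls appear again in the second `if __name__ == "__main__":` hunk further down, so they could be collected into one helper. A minimal sketch, assuming only the conf.seed attribute shown in the diff; the seed_everything name is hypothetical and not part of this commit:

import numpy as np
import torch


def seed_everything(seed: int) -> None:
    """Seed the NumPy and PyTorch RNGs (CPU and all visible GPUs) in one place."""
    np.random.seed(seed)
    torch.manual_seed(seed)           # PyTorch CPU RNG
    torch.cuda.manual_seed_all(seed)  # every CUDA device; a no-op when no GPU is present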
@@ -172,16 +172,17 @@ class QWenLMHeadModel(nn.Module):
             new_shape = attn_output.size()[:-2] + (block.attn.num_heads * block.attn.head_dim,)
             context_layer = attn_output.view(new_shape)
             attn_outputs = block.attn.c_proj(context_layer)
+            hidden_states = attn_outputs + hidden_states
 
+            # RMSNorm
+            layernorm_output = block.ln_2(hidden_states)
 
             # mlp
-            layernorm_input = attn_outputs + hidden_states
-            layernorm_output = block.ln_2(layernorm_input)
             a1 = block.mlp.w1(layernorm_output)
             a2 = block.mlp.w2(layernorm_output)
             intermediate_parallel = a1 * F.silu(a2)
             mlp_output = block.mlp.c_proj(intermediate_parallel)
-            hidden_states = layernorm_input + mlp_output
+            hidden_states = hidden_states + mlp_output
 
         hidden_states = transfm.ln_f(hidden_states)
         hidden_states = hidden_states.view(output_shape)
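The "+" side of this hunk restructures the tail of each block so that both residual additions reuse hidden_states directly, dropping the intermediate layernorm_input variable. A minimal, self-contained sketch of that pattern, assuming the module and tensor names visible in the diff; nn.RMSNorm (available in recent PyTorch) and the bias-free Linear layers stand in for the project's own implementations and are assumptions, not the repository's code:

import torch
import torch.nn.functional as F
from torch import nn


class BlockTail(nn.Module):
    """Post-attention part of one block, following the '+' side of the hunk above."""

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.ln_2 = nn.RMSNorm(hidden_size)  # stand-in for the project's RMSNorm
        self.w1 = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.w2 = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.c_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, attn_outputs: torch.Tensor, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = attn_outputs + hidden_states       # residual around attention
        layernorm_output = self.ln_2(hidden_states)        # RMSNorm before the MLP
        a1 = self.w1(layernorm_output)
        a2 = self.w2(layernorm_output)
        mlp_output = self.c_proj(a1 * F.silu(a2))          # SwiGLU-style gated MLP
        return hidden_states + mlp_output                  # residual around the MLP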
@@ -43,6 +43,8 @@ if __name__ == "__main__":
 
     torch.manual_seed(conf.seed)
     np.random.seed(conf.seed)
+    torch.cuda.manual_seed_all(conf.seed)
+
     model = QWenLMHeadModel(conf.model_config)
     # model = RWKVLMHeadModel(conf.model_config)
    qwen = LightModule(conf, model)
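With torch, NumPy, and all CUDA devices now seeded from conf.seed in both scripts, constructing the model twice under the same seed should yield identical parameters. A quick sanity check along those lines; the assert_reproducible helper is hypothetical and assumes only the conf object and model constructor shown in the diff:

import numpy as np
import torch


def assert_reproducible(conf, build):
    """Build the model twice under the same seed and compare every parameter."""

    def seeded_build():
        torch.manual_seed(conf.seed)
        np.random.seed(conf.seed)
        torch.cuda.manual_seed_all(conf.seed)
        return build()

    m1, m2 = seeded_build(), seeded_build()
    for (n1, p1), (n2, p2) in zip(m1.named_parameters(), m2.named_parameters()):
        assert n1 == n2 and torch.equal(p1, p2), f"mismatch in {n1}"


# e.g. assert_reproducible(conf, lambda: QWenLMHeadModel(conf.model_config))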