diff --git a/wit/inference.py b/wit/inference.py
index 48dfbc6..ffba9df 100644
--- a/wit/inference.py
+++ b/wit/inference.py
@@ -19,6 +19,8 @@ if __name__ == "__main__":
     conf = qwen.config
     torch.manual_seed(conf.seed)
     np.random.seed(conf.seed)
+    torch.cuda.manual_seed_all(conf.seed)
+
     runner = ModelRunner(qwen.llm)
 
     # batch = torch.tensor([[11, 0, 3, 7, 15, 8, 10, 7, 14, 13, 1, 12, 13]], dtype=torch.int64)
diff --git a/wit/model/modeling_wit.py b/wit/model/modeling_wit.py
index 6d90e44..13595ab 100644
--- a/wit/model/modeling_wit.py
+++ b/wit/model/modeling_wit.py
@@ -172,16 +172,17 @@ class QWenLMHeadModel(nn.Module):
             new_shape = attn_output.size()[:-2] + (block.attn.num_heads * block.attn.head_dim,)
             context_layer = attn_output.view(new_shape)
             attn_outputs = block.attn.c_proj(context_layer)
+            hidden_states = attn_outputs + hidden_states
+
+            # RMSNorm
+            layernorm_output = block.ln_2(hidden_states)
 
             # mlp
-            layernorm_input = attn_outputs + hidden_states
-            layernorm_output = block.ln_2(layernorm_input)
             a1 = block.mlp.w1(layernorm_output)
             a2 = block.mlp.w2(layernorm_output)
             intermediate_parallel = a1 * F.silu(a2)
             mlp_output = block.mlp.c_proj(intermediate_parallel)
-
-            hidden_states = layernorm_input + mlp_output
+            hidden_states = hidden_states + mlp_output
 
         hidden_states = transfm.ln_f(hidden_states)
         hidden_states = hidden_states.view(output_shape)
diff --git a/wit/train.py b/wit/train.py
index 0b65ca9..4f44fde 100644
--- a/wit/train.py
+++ b/wit/train.py
@@ -43,6 +43,8 @@ if __name__ == "__main__":
     torch.manual_seed(conf.seed)
     np.random.seed(conf.seed)
+    torch.cuda.manual_seed_all(conf.seed)
+
     model = QWenLMHeadModel(conf.model_config)
     # model = RWKVLMHeadModel(conf.model_config)
     qwen = LightModule(conf, model)
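
The modeling_wit.py hunk is a behavior-preserving refactor: the attention residual add is folded into `hidden_states` up front, so that variable always carries the running residual stream and the temporary `layernorm_input` disappears. The inference.py and train.py hunks add `torch.cuda.manual_seed_all(conf.seed)` next to the existing CPU/NumPy seeding; recent PyTorch versions already seed the CUDA generators from `torch.manual_seed`, so the explicit call mainly guards against older behavior and makes the intent visible.

Below is a minimal, self-contained sketch of the residual flow the refactor converges to (pre-norm block: residual add, RMSNorm, SwiGLU MLP, residual add). Names like `ln_2`, `w1`, `w2`, and `c_proj` mirror the diff, but the `RMSNorm` and `SwiGLUMLP` classes here are illustrative stand-ins, not the repo's actual modules:

import torch
import torch.nn as nn
import torch.nn.functional as F

class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        # Scale by the reciprocal root mean square over the last dim.
        rms = x.pow(2).mean(-1, keepdim=True).add(self.eps).rsqrt()
        return x * rms * self.weight

class SwiGLUMLP(nn.Module):
    def __init__(self, dim: int, hidden: int):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden, bias=False)
        self.w2 = nn.Linear(dim, hidden, bias=False)
        self.c_proj = nn.Linear(hidden, dim, bias=False)

    def forward(self, x):
        # SwiGLU: one projection gates the other through SiLU,
        # matching a1 * F.silu(a2) in the diff.
        return self.c_proj(self.w1(x) * F.silu(self.w2(x)))

def mlp_sublayer(ln_2, mlp, hidden_states, attn_outputs):
    # 1. Close the attention sub-layer's residual first, so
    #    hidden_states is always the running residual stream.
    hidden_states = attn_outputs + hidden_states
    # 2. Pre-norm: RMSNorm the residual, then feed it to the MLP.
    mlp_output = mlp(ln_2(hidden_states))
    # 3. Second residual add closes the block.
    return hidden_states + mlp_output

# Smoke test: shapes pass through unchanged.
x = torch.randn(2, 5, 64)
attn = torch.randn(2, 5, 64)
out = mlp_sublayer(RMSNorm(64), SwiGLUMLP(64, 256), x, attn)
assert out.shape == x.shape

Written this way, both residual adds read directly off `hidden_states`, which is why the diff can replace `layernorm_input + mlp_output` with `hidden_states + mlp_output` without changing the computed values.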