diff --git a/wit/inference.py b/wit/inference.py
index b932075..1555494 100644
--- a/wit/inference.py
+++ b/wit/inference.py
@@ -8,11 +8,7 @@ import dataset.dataset as ds
 
 
 if __name__ == "__main__":
-    # checkpoint_path = "log/bigger/version_0/checkpoints/epoch=19-step=98720.ckpt"
-    checkpoint_path = "log/bigger/version_1/checkpoints/epoch=14-step=74040.ckpt"
-    checkpoint_path = "log/bigger/version_3/checkpoints/epoch=46-step=231992.ckpt"
-    checkpoint_path = "log/bigger/version_8/checkpoints/epoch=49-step=246800.ckpt"
-    # checkpoint_path = "log/bigger/version_6/checkpoints/epoch=43-step=217184.ckpt"
+    checkpoint_path = "log/bigger/version_0/checkpoints/epoch=72-step=360328.ckpt"
     qwen = LightModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
     qwen.eval()
 
diff --git a/wit/model/modeling_wit.py b/wit/model/modeling_wit.py
index 13595ab..65fbcab 100644
--- a/wit/model/modeling_wit.py
+++ b/wit/model/modeling_wit.py
@@ -167,7 +167,7 @@ class QWenLMHeadModel(nn.Module):
             v = value.permute(0, 2, 1, 3)
             attn_output = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask).transpose(1, 2)
             if self.hook_attention:
-                self.hook_attention(query, key, causal_mask, index)
+                self.hook_attention(q, k, causal_mask, index)
             attn_output = attn_output.contiguous()
             new_shape = attn_output.size()[:-2] + (block.attn.num_heads * block.attn.head_dim,)
             context_layer = attn_output.view(new_shape)
diff --git a/wit/query_block_output.py b/wit/query_block_output.py
index 6d0c9e0..98b9d96 100644
--- a/wit/query_block_output.py
+++ b/wit/query_block_output.py
@@ -15,10 +15,7 @@ import dataset.dataset as ds
 
 
 if __name__ == "__main__":
-    # checkpoint_path = "log/bigger/version_0/checkpoints/epoch=19-step=98720.ckpt"
-    checkpoint_path = "log/bigger/version_1/checkpoints/epoch=14-step=74040.ckpt"
-    checkpoint_path = "log/bigger/version_3/checkpoints/epoch=46-step=231992.ckpt"
-    checkpoint_path = "log/bigger/version_8/checkpoints/epoch=49-step=246800.ckpt"
+    checkpoint_path = "log/bigger/version_0/checkpoints/epoch=72-step=360328.ckpt"
     qwen = LightModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
     qwen.eval()
 
@@ -44,8 +41,25 @@ if __name__ == "__main__":
 
     qwen.llm.hook_attention = DumpQK
 
-    batch = torch.tensor([[11, 0, 3, 7, 15, 8, 10, 7]], dtype=torch.int64)
-    sorted_logits, sorted_indices = runner.ChatTokens(batch, sample=False)
-    print(sorted_logits.detach().cpu().numpy())
-    print(sorted_indices.detach().cpu().numpy())
+    val = ds.InitValDataset(conf).dataset
+    md = val.meaning_dataset
+    map = md.get_meaning_map()
+    item = md.get_token(0)
+
+    node = map.get_nodetree(md.get_meaning(0))
+    # node.print()
+
+    batch = torch.tensor([item], dtype=torch.int64)
+    sorted_logits, sorted_indices = runner.ChatTokens(batch, sample=False)
+    next_token = sorted_indices.detach().cpu().numpy()[0][0]
+    node.print()
+
+
+
+
+    # batch = torch.tensor([[11, 0, 3, 7, 15, 8, 10, 7]], dtype=torch.int64)
+    # sorted_logits, sorted_indices = runner.ChatTokens(batch, sample=False)
+
+    # print(sorted_logits.detach().cpu().numpy())
+    # print(sorted_indices.detach().cpu().numpy())
 
diff --git a/wit/train.py b/wit/train.py
index 73fa4aa..41280f9 100644
--- a/wit/train.py
+++ b/wit/train.py
@@ -37,8 +37,8 @@ if __name__ == "__main__":
     conf.dataset.meaning.val_mask_idx = [0, 0, -1]
 
     config.vocab_size = 32
-    config.hidden_size = 32 # 128 1024 2048 32
-    config.intermediate_size = config.hidden_size * 4
+    config.hidden_size = 128 # 128 1024 2048 32
+    config.intermediate_size = 256
     config.num_hidden_layers = 3 # 6 12 24 3
     config.num_attention_heads = 8 # 8 8 16