Update wit.

Colin committed 2025-05-22 15:26:43 +08:00
parent f98a951b58
commit 3eb711a97e
4 changed files with 26 additions and 16 deletions

View File

@@ -8,11 +8,7 @@ import dataset.dataset as ds
 if __name__ == "__main__":
-    # checkpoint_path = "log/bigger/version_0/checkpoints/epoch=19-step=98720.ckpt"
-    checkpoint_path = "log/bigger/version_1/checkpoints/epoch=14-step=74040.ckpt"
-    checkpoint_path = "log/bigger/version_3/checkpoints/epoch=46-step=231992.ckpt"
-    checkpoint_path = "log/bigger/version_8/checkpoints/epoch=49-step=246800.ckpt"
-    # checkpoint_path = "log/bigger/version_6/checkpoints/epoch=43-step=217184.ckpt"
+    checkpoint_path = "log/bigger/version_0/checkpoints/epoch=72-step=360328.ckpt"
     qwen = LightModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
     qwen.eval()
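The hunk above replaces the stack of alternate checkpoint_path assignments (where only the last uncommented line took effect) with the single version_0/epoch=72 checkpoint. A hedged sketch of keeping the alternatives around without dead assignments; the dict and its keys are illustrative, not from the repo, while the paths come from the diff itself:

# Illustrative only: collect the candidate checkpoints in one place and pick
# by key, so switching runs is a one-line change instead of reordering lines.
CHECKPOINTS = {
    "v1-e14": "log/bigger/version_1/checkpoints/epoch=14-step=74040.ckpt",
    "v3-e46": "log/bigger/version_3/checkpoints/epoch=46-step=231992.ckpt",
    "v8-e49": "log/bigger/version_8/checkpoints/epoch=49-step=246800.ckpt",
    "v0-e72": "log/bigger/version_0/checkpoints/epoch=72-step=360328.ckpt",
}
checkpoint_path = CHECKPOINTS["v0-e72"]
qwen = LightModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
qwen.eval()  # disable dropout and other train-only behavior for inference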

View File

@@ -167,7 +167,7 @@ class QWenLMHeadModel(nn.Module):
         v = value.permute(0, 2, 1, 3)
         attn_output = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask).transpose(1, 2)
         if self.hook_attention:
-            self.hook_attention(query, key, causal_mask, index)
+            self.hook_attention(q, k, causal_mask, index)
         attn_output = attn_output.contiguous()
         new_shape = attn_output.size()[:-2] + (block.attn.num_heads * block.attn.head_dim,)
         context_layer = attn_output.view(new_shape)
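This fix matters for anyone attaching a hook: the callback now receives the permuted tensors q and k, shaped (batch, heads, seq, head_dim), which are what actually feed F.scaled_dot_product_attention, rather than the pre-permute query/key. A minimal sketch of a DumpQK-style hook under that assumption; only the signature comes from the diff, the body is illustrative:

import math
import torch

def DumpQK(q, k, causal_mask, index):
    # Recompute the attention map the way scaled_dot_product_attention does
    # internally; assumes q/k arrive as (batch, heads, seq, head_dim).
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    if causal_mask is not None:
        if causal_mask.dtype == torch.bool:
            # boolean mask: False positions are disallowed
            scores = scores.masked_fill(~causal_mask, float("-inf"))
        else:
            # float mask: additive bias
            scores = scores + causal_mask
    attn = torch.softmax(scores, dim=-1)
    print(f"layer {index}: attention map {tuple(attn.shape)}")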

View File

@@ -15,10 +15,7 @@ import dataset.dataset as ds
 if __name__ == "__main__":
-    # checkpoint_path = "log/bigger/version_0/checkpoints/epoch=19-step=98720.ckpt"
-    checkpoint_path = "log/bigger/version_1/checkpoints/epoch=14-step=74040.ckpt"
-    checkpoint_path = "log/bigger/version_3/checkpoints/epoch=46-step=231992.ckpt"
-    checkpoint_path = "log/bigger/version_8/checkpoints/epoch=49-step=246800.ckpt"
+    checkpoint_path = "log/bigger/version_0/checkpoints/epoch=72-step=360328.ckpt"
     qwen = LightModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
     qwen.eval()
@@ -44,8 +41,25 @@ if __name__ == "__main__":
     qwen.llm.hook_attention = DumpQK
-    batch = torch.tensor([[11, 0, 3, 7, 15, 8, 10, 7]], dtype=torch.int64)
-    sorted_logits, sorted_indices = runner.ChatTokens(batch, sample=False)
-    print(sorted_logits.detach().cpu().numpy())
-    print(sorted_indices.detach().cpu().numpy())
+    val = ds.InitValDataset(conf).dataset
+    md = val.meaning_dataset
+    map = md.get_meaning_map()
+    item = md.get_token(0)
+    node = map.get_nodetree(md.get_meaning(0))
+    # node.print()
+    batch = torch.tensor([item], dtype=torch.int64)
+    sorted_logits, sorted_indices = runner.ChatTokens(batch, sample=False)
+    next_token = sorted_indices.detach().cpu().numpy()[0][0]
+    node.print()
+    # batch = torch.tensor([[11, 0, 3, 7, 15, 8, 10, 7]], dtype=torch.int64)
+    # sorted_logits, sorted_indices = runner.ChatTokens(batch, sample=False)
+    # print(sorted_logits.detach().cpu().numpy())
+    # print(sorted_indices.detach().cpu().numpy())
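The rewritten block builds its prompt from the validation meaning dataset instead of the hard-coded token list, then takes the greedy top-1 prediction (sample=False). Assuming ChatTokens returns logits and token indices sorted in descending order over the vocabulary, as the [0][0] indexing suggests, the same call can be iterated for multi-token greedy decoding; a hedged sketch under that assumption, not part of the diff:

# Assumes sorted_indices is (batch, vocab) with the top next-token candidate
# at column 0, matching the diff's sorted_indices...[0][0] usage.
tokens = torch.tensor([item], dtype=torch.int64)
for _ in range(4):  # extend the sequence by four greedy steps
    _, sorted_indices = runner.ChatTokens(tokens, sample=False)
    next_token = sorted_indices.detach().cpu()[:, :1].to(tokens.dtype)
    tokens = torch.cat([tokens, next_token], dim=-1)
print(tokens.numpy())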

View File

@@ -37,8 +37,8 @@ if __name__ == "__main__":
     conf.dataset.meaning.val_mask_idx = [0, 0, -1]
     config.vocab_size = 32
-    config.hidden_size = 32 # 128 1024 2048 32
-    config.intermediate_size = config.hidden_size * 4
+    config.hidden_size = 128 # 128 1024 2048 32
+    config.intermediate_size = 256
     config.num_hidden_layers = 3 # 6 12 24 3
     config.num_attention_heads = 8 # 8 8 16
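For scale, this hunk quadruples the model width (hidden_size 32 to 128) while pinning intermediate_size at 256, i.e. 2x hidden, instead of the earlier 4x rule. A back-of-envelope parameter count, assuming a Qwen-style block with q/k/v/o projections and a gate/up/down SwiGLU MLP, and ignoring embeddings, norms, and biases:

hidden_size = 128
intermediate_size = 256
num_layers = 3
attn = 4 * hidden_size * hidden_size       # q, k, v, o projections: 65536
mlp = 3 * hidden_size * intermediate_size  # gate, up, down: 98304
print(attn + mlp, num_layers * (attn + mlp))  # 163840 per layer, 491520 total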