Update wit.
commit 3eb711a97e (parent f98a951b58)
@@ -8,11 +8,7 @@ import dataset.dataset as ds
 
 if __name__ == "__main__":
 
-    # checkpoint_path = "log/bigger/version_0/checkpoints/epoch=19-step=98720.ckpt"
-    checkpoint_path = "log/bigger/version_1/checkpoints/epoch=14-step=74040.ckpt"
-    checkpoint_path = "log/bigger/version_3/checkpoints/epoch=46-step=231992.ckpt"
-    checkpoint_path = "log/bigger/version_8/checkpoints/epoch=49-step=246800.ckpt"
-    # checkpoint_path = "log/bigger/version_6/checkpoints/epoch=43-step=217184.ckpt"
+    checkpoint_path = "log/bigger/version_0/checkpoints/epoch=72-step=360328.ckpt"
 
     qwen = LightModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
     qwen.eval()
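The checkpoint filenames selected above follow PyTorch Lightning's default ModelCheckpoint naming. A minimal sketch of a setup that produces paths like log/bigger/version_N/checkpoints/epoch=72-step=360328.ckpt; the repo's actual Trainer and logger configuration is not part of this diff, so everything below is an assumption for illustration only:

import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint

# Hypothetical sketch: a TensorBoardLogger with save_dir="log" and name="bigger"
# creates log/bigger/version_N/, and the default filename pattern "{epoch}-{step}"
# expands to "epoch=NN-step=NNNNNN.ckpt" inside its checkpoints/ directory.
logger = TensorBoardLogger(save_dir="log", name="bigger")
checkpoint_cb = ModelCheckpoint(filename="{epoch}-{step}", save_top_k=1)
trainer = pl.Trainer(logger=logger, callbacks=[checkpoint_cb], max_epochs=73)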
@@ -167,7 +167,7 @@ class QWenLMHeadModel(nn.Module):
         v = value.permute(0, 2, 1, 3)
         attn_output = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask).transpose(1, 2)
         if self.hook_attention:
-            self.hook_attention(query, key, causal_mask, index)
+            self.hook_attention(q, k, causal_mask, index)
         attn_output = attn_output.contiguous()
         new_shape = attn_output.size()[:-2] + (block.attn.num_heads * block.attn.head_dim,)
         context_layer = attn_output.view(new_shape)
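Note on the hunk above: the hook now receives q and k after the permute, i.e. presumably in (batch, num_heads, seq_len, head_dim) layout, the same tensors F.scaled_dot_product_attention consumes. DumpQK itself is not shown in this diff; the sketch below only illustrates what a hook with the new (q, k, causal_mask, index) signature could do, namely rematerialize the attention weights that scaled_dot_product_attention never exposes:

import math
import torch
import torch.nn.functional as F

def dump_qk_sketch(q, k, causal_mask, index):
    # q, k: (batch, num_heads, seq_len, head_dim); causal_mask is either a bool mask
    # (True = attend) or an additive float mask, as accepted by scaled_dot_product_attention.
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    if causal_mask is not None:
        if causal_mask.dtype == torch.bool:
            scores = scores.masked_fill(~causal_mask, float("-inf"))
        else:
            scores = scores + causal_mask
    attn = F.softmax(scores, dim=-1)
    torch.save(attn.detach().cpu(), f"qk_block_{index}.pt")  # hypothetical dump location

# Smoke test with tensors shaped like the call site above.
q = torch.randn(1, 8, 16, 16)                       # (batch, heads, seq, head_dim)
k = torch.randn(1, 8, 16, 16)
mask = torch.ones(16, 16, dtype=torch.bool).tril()  # lower-triangular causal mask
dump_qk_sketch(q, k, mask, index=0)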
@@ -15,10 +15,7 @@ import dataset.dataset as ds
 
 if __name__ == "__main__":
 
-    # checkpoint_path = "log/bigger/version_0/checkpoints/epoch=19-step=98720.ckpt"
-    checkpoint_path = "log/bigger/version_1/checkpoints/epoch=14-step=74040.ckpt"
-    checkpoint_path = "log/bigger/version_3/checkpoints/epoch=46-step=231992.ckpt"
-    checkpoint_path = "log/bigger/version_8/checkpoints/epoch=49-step=246800.ckpt"
+    checkpoint_path = "log/bigger/version_0/checkpoints/epoch=72-step=360328.ckpt"
 
     qwen = LightModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
     qwen.eval()
@@ -44,8 +41,25 @@ if __name__ == "__main__":
 
     qwen.llm.hook_attention = DumpQK
 
-    batch = torch.tensor([[11, 0, 3, 7, 15, 8, 10, 7]], dtype=torch.int64)
-    sorted_logits, sorted_indices = runner.ChatTokens(batch, sample=False)
 
-    print(sorted_logits.detach().cpu().numpy())
-    print(sorted_indices.detach().cpu().numpy())
+    val = ds.InitValDataset(conf).dataset
+    md = val.meaning_dataset
+    map = md.get_meaning_map()
+    item = md.get_token(0)
+
+    node = map.get_nodetree(md.get_meaning(0))
+    # node.print()
+
+    batch = torch.tensor([item], dtype=torch.int64)
+    sorted_logits, sorted_indices = runner.ChatTokens(batch, sample=False)
+    next_token = sorted_indices.detach().cpu().numpy()[0][0]
+    node.print()
+
+
+
+
+    # batch = torch.tensor([[11, 0, 3, 7, 15, 8, 10, 7]], dtype=torch.int64)
+    # sorted_logits, sorted_indices = runner.ChatTokens(batch, sample=False)
+
+    # print(sorted_logits.detach().cpu().numpy())
+    # print(sorted_indices.detach().cpu().numpy())
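For reference, the sample=False path above is read here as greedy decoding: runner.ChatTokens is assumed (its implementation is not part of this diff) to return next-position logits sorted in descending order together with the matching token ids, so sorted_indices.detach().cpu().numpy()[0][0] is the argmax token. A self-contained illustration of that relationship, with torch.sort standing in for ChatTokens:

import torch

# Stand-in for runner.ChatTokens(batch, sample=False): descending-sorted logits
# plus their token ids for the next position.
logits = torch.randn(1, 32)  # (batch, vocab_size) with vocab_size = 32 as in the config
sorted_logits, sorted_indices = torch.sort(logits, dim=-1, descending=True)

next_token = sorted_indices.detach().cpu().numpy()[0][0]
assert next_token == int(torch.argmax(logits, dim=-1))  # greedy pick == top of the sort
print(next_token)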
@@ -37,8 +37,8 @@ if __name__ == "__main__":
     conf.dataset.meaning.val_mask_idx = [0, 0, -1]
 
     config.vocab_size = 32
-    config.hidden_size = 32 # 128 1024 2048 32
-    config.intermediate_size = config.hidden_size * 4
+    config.hidden_size = 128 # 128 1024 2048 32
+    config.intermediate_size = 256
     config.num_hidden_layers = 3 # 6 12 24 3
     config.num_attention_heads = 8 # 8 8 16
 
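Sizing note on the last hunk: with hidden_size raised from 32 to 128, the old rule intermediate_size = hidden_size * 4 would have given 512; pinning it to 256 keeps the MLP at 2x the hidden size instead. Rough arithmetic below, not taken from the repo and assuming a plain hidden -> intermediate -> hidden MLP:

hidden_size = 128
old_intermediate = hidden_size * 4   # 512 under the previous rule
new_intermediate = 256               # value set in this commit

# Per-layer MLP weight counts under the stated two-matrix assumption.
print(2 * hidden_size * old_intermediate)  # 131072
print(2 * hidden_size * new_intermediate)  # 65536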