Add meaning-frequency query script; refine tree printing.

This commit is contained in:
Colin 2025-02-26 16:55:20 +08:00
parent 3c34d12fba
commit bff65b189d
4 changed files with 66 additions and 23 deletions

View File

@ -183,7 +183,7 @@ class MeaningMap:
self.ms_rank_all[start : start + len],
)
def get_tree(self, meaning): # return meaning all sub items
def get_nodetree(self, meaning): # return meaning all sub items
def get_tree_node(ms_map, meaning, vocab_size, parent, seqlist):
ms = ms_map[meaning]
for m in ms[ms >= 0].tolist():
@ -200,6 +200,22 @@ class MeaningMap:
root.seq_node = seqlist
return root
def get_tree(self, meaning):
    """Return the sub-tree under `meaning` as one flat list.

    Depth-first encoding: a composite meaning (id >= vocab_size) is
    written as [id, -255, <children...>, -1], where -255 marks descent
    into the children and -1 marks the ascent back to the parent.
    Leaf tokens (id < vocab_size) appear as plain values.
    """
    flat = []

    def walk(node):
        children = self.ms_map[node]
        flat.append(node)
        flat.append(-255)  # level down marker
        for child in children[children >= 0].tolist():
            if child < self.vocab_size:
                flat.append(child)  # leaf token
            else:
                walk(child)  # composite meaning: expand recursively
        flat.append(-1)  # level up marker

    walk(meaning)
    return flat
def max_length(self):
    """Return the largest sequence length recorded in `ms_len`."""
    return max(length for length in self.ms_len)
@ -314,6 +330,7 @@ class MeaningDataset(Dataset):
output["labels"] = data.clone()
output["token_type_ids"] = torch.zeros(data.shape)
output["val_mask"] = self.get_seq_mask_tensor(idx_list)
output["meaning"] = [self.seq_meaning[i] for i in idx_list]
return output
def get_token(self, idx): # must equal sequence length
@ -405,17 +422,6 @@ class BatchGroupMeaningDataloader(Dataset):
def __getitem__(self, idx):
return self.meaning_dataset.get_batch(self.indexBatch[idx])
def get_tree(self, idx):
    """Return the meaning tree of every sequence in batch `idx`."""
    trees = []
    for seq_index in self.indexBatch[idx]:
        trees.append(self.meaning_dataset.get_tree(seq_index))
    return trees
def print_tree(self, idx):
    """Render each sequence tree in batch `idx`, separated by divider lines."""
    divider = "--------------------------------------------------------\n"
    parts = [divider]
    for seq_index in self.indexBatch[idx]:
        parts.append(self.meaning_dataset.print_tree(seq_index))
        parts.append(divider)
    return "".join(parts)
def detection_collate(batch):
    """Collate function: the dataset already yields full batches, so just
    unwrap the single pre-batched element DataLoader hands us."""
    first_group = batch[0]
    return first_group
@ -424,7 +430,6 @@ class BatchGroupMeaningDataloader(Dataset):
self, batch_size=1, num_workers=num_workers, collate_fn=BatchGroupMeaningDataloader.detection_collate
)
if __name__ == "__main__":
md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, use_cache=False)

View File

@ -12,7 +12,8 @@ import dataset.node_tree as nt
if __name__ == "__main__":
checkpoint_path = "log/bigger/version_0/checkpoints/epoch=19-step=98720.ckpt"
# checkpoint_path = "log/bigger/version_0/checkpoints/epoch=19-step=98720.ckpt"
checkpoint_path = "log/bigger/version_1/checkpoints/epoch=14-step=74040.ckpt"
qwen = QwenModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
qwen.eval()
@ -21,23 +22,24 @@ if __name__ == "__main__":
np.random.seed(conf.seed)
runner = QwenRunner(qwen.llm)
# batch = torch.tensor([[41]], dtype=torch.int64)
# print(runner.ChatTokens(batch).detach().cpu().numpy()[0])
# batch = torch.tensor([[11, 0, 3, 7, 15, 8, 10, 7, 14, 13, 1, 12, 13]], dtype=torch.int64)
# sorted_logits, sorted_indices = runner.ChatTokens(batch, sample=False)
# print(sorted_logits.detach().cpu().numpy())
# print(sorted_indices.detach().cpu().numpy())
val = ds.InitValDataset(conf).dataset
md = val.meaning_dataset
map = md.get_meaning_map()
item = md.get_token(0)
node = map.get_tree(md.get_meaning(0))
node = map.get_nodetree(md.get_meaning(0))
# node.print()
for i in range(1, len(item)):
itemm = [item[:i]]
batch = torch.tensor([item[:i]], dtype=torch.int64)
next_token = runner.ChatTokens(batch, sample=False).detach().cpu().numpy()[0]
sorted_logits, sorted_indices = runner.ChatTokens(batch, sample=False)
next_token = sorted_indices.detach().cpu().numpy()[0][0]
if item[i] != next_token:
node.set_seq_prop(i, "ERR_" + str(next_token))
print(str(item[i]) + " " + str(next_token) + " ERROR")

View File

@ -202,12 +202,11 @@ class QwenRunner:
outputs, loss = self.forwardQWen(input_ids)
next_token_scores = outputs[:, -1, :]
next_token_scores = self.repetition_penalty(input_ids, next_token_scores)
next_token_scores = self.top_p(next_token_scores)
if sample:
next_token_scores = self.top_p(next_token_scores)
return self.sample(next_token_scores)
else:
sorted_logits, sorted_indices = torch.sort(next_token_scores, descending=True)
return sorted_indices[0]
return torch.sort(next_token_scores, descending=True)
@torch.no_grad()
def Chat(

37
wit/query_meaning_freq.py Normal file
View File

@ -0,0 +1,37 @@
import pytorch_lightning as pl
import torch
from model.qwen_module import QwenModule
from model.modeling_wit import QwenRunner
from model.tokenization_qwen import QWenTokenizer
import numpy as np
import configuration
import dataset.dataset as ds
import dataset.node_tree as nt
if __name__ == "__main__":
    checkpoint_path = "log/bigger/version_1/checkpoints/epoch=14-step=74040.ckpt"
    qwen = QwenModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
    qwen.eval()

    conf = qwen.config
    torch.manual_seed(conf.seed)
    np.random.seed(conf.seed)

    train_dataloader, val_dataloader = ds.InitDataset(conf)
    loader = train_dataloader.dataset
    # Renamed from `map`, which shadowed the builtin.
    meaning_map = loader.meaning_dataset.get_meaning_map()

    # Flatten every meaning seen in the training batches once up front,
    # so each interactive query below is a plain list scan.
    trees = {}
    for batch in loader:
        for m in batch["meaning"]:
            trees[m] = meaning_map.get_tree(m)

    # Interactive query loop: count occurrences of a meaning id across all
    # flattened trees.  Exits cleanly on EOF (Ctrl-D) instead of crashing,
    # and re-prompts on non-integer input instead of raising ValueError.
    while True:
        try:
            m = int(input("input meaning: "))
        except EOFError:
            break
        except ValueError:
            print("please enter an integer meaning id")
            continue
        total = sum(tree.count(m) for tree in trees.values())
        print(f"meaning of {m} count as {total}")