Add query file. Refine print tree.
This commit is contained in:
parent
3c34d12fba
commit
bff65b189d
|
@ -183,7 +183,7 @@ class MeaningMap:
|
||||||
self.ms_rank_all[start : start + len],
|
self.ms_rank_all[start : start + len],
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_tree(self, meaning): # return meaning all sub items
|
def get_nodetree(self, meaning): # return meaning all sub items
|
||||||
def get_tree_node(ms_map, meaning, vocab_size, parent, seqlist):
|
def get_tree_node(ms_map, meaning, vocab_size, parent, seqlist):
|
||||||
ms = ms_map[meaning]
|
ms = ms_map[meaning]
|
||||||
for m in ms[ms >= 0].tolist():
|
for m in ms[ms >= 0].tolist():
|
||||||
|
@ -200,6 +200,22 @@ class MeaningMap:
|
||||||
root.seq_node = seqlist
|
root.seq_node = seqlist
|
||||||
return root
|
return root
|
||||||
|
|
||||||
|
def get_tree(self, meaning):
|
||||||
|
def get_tree_list(ms_map, meaning, mlist):
|
||||||
|
ms = ms_map[meaning]
|
||||||
|
mlist.append(meaning)
|
||||||
|
mlist.append(-255) # level down marker
|
||||||
|
for m in ms[ms >= 0].tolist():
|
||||||
|
if m >= self.vocab_size:
|
||||||
|
get_tree_list(ms_map, m, mlist)
|
||||||
|
else:
|
||||||
|
mlist.append(m)
|
||||||
|
mlist.append(-1) # level up marker
|
||||||
|
|
||||||
|
meaninglist = []
|
||||||
|
get_tree_list(self.ms_map, meaning, meaninglist)
|
||||||
|
return meaninglist
|
||||||
|
|
||||||
def max_length(self):
|
def max_length(self):
|
||||||
return max(self.ms_len)
|
return max(self.ms_len)
|
||||||
|
|
||||||
|
@ -314,6 +330,7 @@ class MeaningDataset(Dataset):
|
||||||
output["labels"] = data.clone()
|
output["labels"] = data.clone()
|
||||||
output["token_type_ids"] = torch.zeros(data.shape)
|
output["token_type_ids"] = torch.zeros(data.shape)
|
||||||
output["val_mask"] = self.get_seq_mask_tensor(idx_list)
|
output["val_mask"] = self.get_seq_mask_tensor(idx_list)
|
||||||
|
output["meaning"] = [self.seq_meaning[i] for i in idx_list]
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def get_token(self, idx): # must equal sequence length
|
def get_token(self, idx): # must equal sequence length
|
||||||
|
@ -405,17 +422,6 @@ class BatchGroupMeaningDataloader(Dataset):
|
||||||
def __getitem__(self, idx):
|
def __getitem__(self, idx):
|
||||||
return self.meaning_dataset.get_batch(self.indexBatch[idx])
|
return self.meaning_dataset.get_batch(self.indexBatch[idx])
|
||||||
|
|
||||||
def get_tree(self, idx):
|
|
||||||
return [self.meaning_dataset.get_tree(i) for i in self.indexBatch[idx]]
|
|
||||||
|
|
||||||
def print_tree(self, idx):
|
|
||||||
idx_list = self.indexBatch[idx]
|
|
||||||
s = "--------------------------------------------------------\n"
|
|
||||||
for i in idx_list:
|
|
||||||
s += self.meaning_dataset.print_tree(i)
|
|
||||||
s += "--------------------------------------------------------\n"
|
|
||||||
return s
|
|
||||||
|
|
||||||
def detection_collate(batch):
|
def detection_collate(batch):
|
||||||
return batch[0]
|
return batch[0]
|
||||||
|
|
||||||
|
@ -424,7 +430,6 @@ class BatchGroupMeaningDataloader(Dataset):
|
||||||
self, batch_size=1, num_workers=num_workers, collate_fn=BatchGroupMeaningDataloader.detection_collate
|
self, batch_size=1, num_workers=num_workers, collate_fn=BatchGroupMeaningDataloader.detection_collate
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
||||||
md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, use_cache=False)
|
md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, use_cache=False)
|
||||||
|
|
|
@ -12,7 +12,8 @@ import dataset.node_tree as nt
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
||||||
checkpoint_path = "log/bigger/version_0/checkpoints/epoch=19-step=98720.ckpt"
|
# checkpoint_path = "log/bigger/version_0/checkpoints/epoch=19-step=98720.ckpt"
|
||||||
|
checkpoint_path = "log/bigger/version_1/checkpoints/epoch=14-step=74040.ckpt"
|
||||||
|
|
||||||
qwen = QwenModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
|
qwen = QwenModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
|
||||||
qwen.eval()
|
qwen.eval()
|
||||||
|
@ -21,23 +22,24 @@ if __name__ == "__main__":
|
||||||
np.random.seed(conf.seed)
|
np.random.seed(conf.seed)
|
||||||
runner = QwenRunner(qwen.llm)
|
runner = QwenRunner(qwen.llm)
|
||||||
|
|
||||||
# batch = torch.tensor([[41]], dtype=torch.int64)
|
# batch = torch.tensor([[11, 0, 3, 7, 15, 8, 10, 7, 14, 13, 1, 12, 13]], dtype=torch.int64)
|
||||||
# print(runner.ChatTokens(batch).detach().cpu().numpy()[0])
|
# sorted_logits, sorted_indices = runner.ChatTokens(batch, sample=False)
|
||||||
|
# print(sorted_logits.detach().cpu().numpy())
|
||||||
|
# print(sorted_indices.detach().cpu().numpy())
|
||||||
|
|
||||||
val = ds.InitValDataset(conf).dataset
|
val = ds.InitValDataset(conf).dataset
|
||||||
md = val.meaning_dataset
|
md = val.meaning_dataset
|
||||||
|
|
||||||
map = md.get_meaning_map()
|
map = md.get_meaning_map()
|
||||||
|
|
||||||
item = md.get_token(0)
|
item = md.get_token(0)
|
||||||
|
|
||||||
node = map.get_tree(md.get_meaning(0))
|
node = map.get_nodetree(md.get_meaning(0))
|
||||||
# node.print()
|
# node.print()
|
||||||
|
|
||||||
for i in range(1, len(item)):
|
for i in range(1, len(item)):
|
||||||
itemm = [item[:i]]
|
itemm = [item[:i]]
|
||||||
batch = torch.tensor([item[:i]], dtype=torch.int64)
|
batch = torch.tensor([item[:i]], dtype=torch.int64)
|
||||||
next_token = runner.ChatTokens(batch, sample=False).detach().cpu().numpy()[0]
|
sorted_logits, sorted_indices = runner.ChatTokens(batch, sample=False)
|
||||||
|
next_token = sorted_indices.detach().cpu().numpy()[0][0]
|
||||||
if item[i] != next_token:
|
if item[i] != next_token:
|
||||||
node.set_seq_prop(i, "ERR_" + str(next_token))
|
node.set_seq_prop(i, "ERR_" + str(next_token))
|
||||||
print(str(item[i]) + " " + str(next_token) + " ERROR")
|
print(str(item[i]) + " " + str(next_token) + " ERROR")
|
||||||
|
|
|
@ -202,12 +202,11 @@ class QwenRunner:
|
||||||
outputs, loss = self.forwardQWen(input_ids)
|
outputs, loss = self.forwardQWen(input_ids)
|
||||||
next_token_scores = outputs[:, -1, :]
|
next_token_scores = outputs[:, -1, :]
|
||||||
next_token_scores = self.repetition_penalty(input_ids, next_token_scores)
|
next_token_scores = self.repetition_penalty(input_ids, next_token_scores)
|
||||||
next_token_scores = self.top_p(next_token_scores)
|
|
||||||
if sample:
|
if sample:
|
||||||
|
next_token_scores = self.top_p(next_token_scores)
|
||||||
return self.sample(next_token_scores)
|
return self.sample(next_token_scores)
|
||||||
else:
|
else:
|
||||||
sorted_logits, sorted_indices = torch.sort(next_token_scores, descending=True)
|
return torch.sort(next_token_scores, descending=True)
|
||||||
return sorted_indices[0]
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def Chat(
|
def Chat(
|
||||||
|
|
|
@ -0,0 +1,37 @@
|
||||||
|
import pytorch_lightning as pl
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from model.qwen_module import QwenModule
|
||||||
|
from model.modeling_wit import QwenRunner
|
||||||
|
from model.tokenization_qwen import QWenTokenizer
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import configuration
|
||||||
|
import dataset.dataset as ds
|
||||||
|
import dataset.node_tree as nt
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
|
||||||
|
checkpoint_path = "log/bigger/version_1/checkpoints/epoch=14-step=74040.ckpt"
|
||||||
|
|
||||||
|
qwen = QwenModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
|
||||||
|
qwen.eval()
|
||||||
|
conf = qwen.config
|
||||||
|
torch.manual_seed(conf.seed)
|
||||||
|
np.random.seed(conf.seed)
|
||||||
|
|
||||||
|
train_dataloader, val_dataloader = ds.InitDataset(conf)
|
||||||
|
|
||||||
|
loader = train_dataloader.dataset
|
||||||
|
|
||||||
|
map = loader.meaning_dataset.get_meaning_map()
|
||||||
|
trees = {}
|
||||||
|
for batch in loader:
|
||||||
|
for m in batch["meaning"]:
|
||||||
|
trees[m] = map.get_tree(m)
|
||||||
|
while True:
|
||||||
|
m = int(input("input meaning: "))
|
||||||
|
total = 0
|
||||||
|
for tree in trees.values():
|
||||||
|
total = total + tree.count(m)
|
||||||
|
print(f"meaning of {m} count as {total}")
|
Loading…
Reference in New Issue