Refine meaning dataset map.

Colin 2025-02-22 16:50:16 +08:00
parent f0469b351c
commit 3a7ce45654
3 changed files with 119 additions and 106 deletions
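
In short: tree handling moves out of MeaningDataset. The per-sequence tree cache, token_frequency, and the dataset-level print_tree are removed; callers now rebuild the map on demand with get_meaning_map() and render trees through the new NodeTree helper in wit/dataset/node_tree.py. A minimal sketch of the new flow follows; it is a hedged reading of the diff, and the import paths (taken from the test script's `import dataset.dataset as ds` / `import dataset.node_tree as nt`) are assumptions, including whether MeaningDataset is exported from dataset.dataset:

import dataset.dataset as ds
import dataset.node_tree as nt

md = ds.MeaningDataset(100000, 115200, vocab_size=1024, size=1024, min_subitem=2, use_cache=False)
map = md.get_meaning_map()              # MeaningMap is rebuilt on demand, no longer cached per sample
tree = map.get_tree(md.get_meaning(0))  # nested dict of sub-meanings; ids < vocab_size are leaf tokens
nt.NodeTree(tree).print()               # tree rendering now lives in dataset/node_tree.py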

View File

@ -183,66 +183,18 @@ class MeaningMap:
)
def get_tree(self, meaning): # return all sub-items of this meaning
tree = {}
ms = self.ms_map[meaning]
for m in ms[ms > 0].tolist():
tree[m] = self.get_tree(m) if m >= self.vocab_size else m
return tree
def get_tree_dict(ms_map, vocab_size, meaning): # return all sub-items of this meaning
tree = {}
ms = ms_map[meaning]
for m in ms[ms > 0].tolist():
tree[m] = get_tree_dict(ms_map, vocab_size, m) if m >= vocab_size else m  # ids < vocab_size are leaf tokens
return tree
return get_tree_dict(self.ms_map, self.vocab_size, meaning)
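# For reference, a hypothetical result of get_tree for vocab_size=1024 (the ids are illustrative only):
#   map.get_tree(100003) -> {2048: {513: 513, 7: 7}, 42: 42}
# Keys are sub-item ids: ids >= vocab_size expand recursively, smaller ids map to themselves as leaf tokens.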
def max_length(self):
return max(self.ms_len)
def print_tree(tree):
def get_tree_node(tree, parent=None):
for key, value in tree.items():
pn = Node(str(key), parent)
if isinstance(value, dict):
get_tree_node(value, pn)
assert isinstance(tree, dict)
root = Node("root")
get_tree_node(tree, root)
treestr = ""
for pre, fill, node in RenderTree(root.children[0] if len(tree) == 1 else root):
treestr += f"{pre}{node.name}\n"
return treestr
def get_tree_indexed_str(tree, data, prefix):
if isinstance(tree, list):
base = ""
qlen = 0
for i, t in enumerate(tree):
s, l = MeaningMap.get_tree_indexed_str(t, data[i], "")
base += s
qlen += l
return (base, qlen)
else:
if isinstance(tree, dict):
base = ""
qlen = 0
last_is_dict = None
for key, value in tree.items():
new_prefix = (len(str(key)) + 2) * " " + prefix
dict_string, l = MeaningMap.get_tree_indexed_str(value, data[qlen:], new_prefix)
if dict_string:
base += "\n" + prefix + str(key) + ": " + dict_string
last_is_dict = True
else:
base += "\n" + prefix + str(data[qlen]) + " " if last_is_dict else str(data[qlen]) + " "
last_is_dict = False
qlen += l
return (base, qlen)
return (None, 1)
def token_frequency(tree, freq):
if isinstance(tree, dict):
for key, value in tree.items():
if key in freq:
freq[key] = freq[key] + 1
else:
freq[key] = 1
MeaningMap.token_frequency(value, freq)
class MeaningDataset(Dataset):
@ -259,17 +211,23 @@ class MeaningDataset(Dataset):
use_cache=True,
):
np.random.seed(seed)
map = MeaningMap(end, vocab_size, max_subitem, min_subitem, use_cache=use_cache)
np.random.seed(seed)
self.start = start
self.end = end
self.vocab_size = vocab_size
self.max_subitem = max_subitem
self.min_subitem = min_subitem
self.use_cache = use_cache
self.min_seq_len = min_seq_len
print("Build MeaningDataset from MeaningMap.")
self.val_mask_level = None
self.val_mask_idx = None
self.tree = []
self.seq = []
self.level = []
self.rank_idx = []
self.rank_all = []
self.seq_meaning = []
map = self.get_meaning_map()
self.m_height = map.ms_height
self.m_weight = map.ms_weight
if size:
@ -281,7 +239,6 @@ class MeaningDataset(Dataset):
for m in meanings:
d, l, i, a = map.get_sequence(m)
if len(d) >= min_seq_len:
self.tree.append({m: map.get_tree(m)})
self.seq.append(d)
self.level.append(l)
self.seq_meaning.append(m)
@ -327,6 +284,9 @@ class MeaningDataset(Dataset):
def len(self):
return len(self.seq)
def get_meaning_map(self):
return MeaningMap(self.end, self.vocab_size, self.max_subitem, self.min_subitem, self.use_cache)
def set_mask(self, level=None, idx=None):
if self.val_mask_level is not None and self.val_mask_idx is not None:
assert len(self.val_mask_level) > 0, "len must be > 0"
@ -346,26 +306,17 @@ class MeaningDataset(Dataset):
output["input_ids"] = data
output["labels"] = data.clone()
output["token_type_ids"] = torch.zeros(data.shape)
output["tree"] = [self.tree[i] for i in idx_list]
output["val_mask"] = self.get_seq_mask_tensor(idx_list)
return output
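# Shape sketch of the batch dict assembled above (shapes and dtypes follow whatever get_batch stacks into `data`):
#   output["input_ids"]       tensor of shape [batch, seq_len]
#   output["labels"]          clone of input_ids
#   output["token_type_ids"]  zeros with the same shape
#   output["val_mask"]        mask built by get_seq_mask_tensor(idx_list)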
def get_token(self, idx): # returns the full token sequence; its length equals the stored sequence length
return self.seq[idx]
def get_tree(self, idx):
return self.tree[idx]
def print_tree(self, idx):
tokens = self.seq[idx]
tree = self.get_tree(idx)
s = str(tokens) + "\n"
s += MeaningMap.print_tree(tree)
return s
def get_meaning(self, idx):
return self.seq_meaning[idx]
def copy(self, start, end):
new = copy.deepcopy(self)
new.tree = new.tree[start:end]
new.seq = new.seq[start:end]
new.level = new.level[start:end]
new.rank_idx = new.rank_idx[start:end]
@ -380,12 +331,6 @@ class MeaningDataset(Dataset):
middle = int(l * ratio)
return self.copy(0, middle), self.copy(middle, l)
def token_frequency(self):
freq = {}
for t in self.tree:
MeaningMap.token_frequency(t, freq)
return freq
def get_seq_mask(self, idx, level, index):
# assert index < 15, "index must < 15"
# assert level < 8, "level must < 8"
@ -410,12 +355,12 @@ class MeaningDataset(Dataset):
class BatchGroupMeaningDataloader(Dataset):
def __init__(self, dataset: MeaningDataset, batch_size, shuffle=True, drop_last=True):
self.dataset = dataset
def __init__(self, meaning_dataset: MeaningDataset, batch_size, shuffle=True, drop_last=True):
self.meaning_dataset = meaning_dataset
self.batch_size = batch_size
self.drop_last = drop_last
seq_len = [len(s) for s in dataset.seq]
seq_len = [len(s) for s in meaning_dataset.seq]
unique, counts = np.unique(seq_len, return_counts=True)
gl = {}
for u in unique:
@ -451,16 +396,16 @@ class BatchGroupMeaningDataloader(Dataset):
return len(self.indexBatch)
def __getitem__(self, idx):
return self.dataset.get_batch(self.indexBatch[idx])
return self.meaning_dataset.get_batch(self.indexBatch[idx])
def get_tree(self, idx):
return [self.dataset.get_tree(i) for i in self.indexBatch[idx]]
return [self.meaning_dataset.get_tree(i) for i in self.indexBatch[idx]]
def print_tree(self, idx):
idx_list = self.indexBatch[idx]
s = "--------------------------------------------------------\n"
for i in idx_list:
s += self.dataset.print_tree(i)
s += self.meaning_dataset.print_tree(i)
s += "--------------------------------------------------------\n"
return s
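# Minimal usage sketch (values are illustrative; the __main__ block below exercises the real setup):
#   md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, min_subitem=2)
#   dl = BatchGroupMeaningDataloader(md, batch_size=2)
#   batch = dl[0]   # each index yields one ready-made batch dict (input_ids / labels / token_type_ids / val_mask)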
@ -556,7 +501,6 @@ if __name__ == "__main__":
),
)
), "False"
freq = md.token_frequency()
md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, min_subitem=2, use_cache=False)
md.set_mask([0, 1], [0, -1])
@ -572,21 +516,3 @@ if __name__ == "__main__":
print(m)
ne2 = next(it)
ne3 = next(it)
# map1 = dl.get_tree(0)
# map2 = dl.get_tree(1)
# print(dl.print_tree(0))
# dl = DataLoader(
# train,
# num_workers=1,
# persistent_workers=True,
# shuffle=False,
# )
# it = iter(dl)
# ne1 = next(it)
# ne2 = next(it)
# ne3 = next(it)
# for i in range(10):
# print(next(it)["input_ids"].numpy().tolist())

wit/dataset/node_tree.py Normal file (83 lines added)
View File

@ -0,0 +1,83 @@
from anytree import Node, RenderTree
class NodeTree:
def __init__(self, tree: dict):
self.tree = tree
self.node = None
self.seq_node = []
def get_node(self):
def get_tree_node(tree, parent, seqlist):
for key, value in tree.items():
if isinstance(value, dict):
pn = Node(str(key), parent)
get_tree_node(value, pn, seqlist)
else:
pn = Node("<" + str(key) + ">", parent)
seqlist.append(pn)
if self.node:
return self.node
assert isinstance(self.tree, dict)
root = Node("root")
get_tree_node(self.tree, root, self.seq_node)
self.node = root.children[0] if len(self.tree) == 1 else root
return self.node
def print(self):
treestr = ""
for pre, fill, node in RenderTree(self.get_node()):
treestr += f"{pre}{node.name}\n"
print(treestr)
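# Example with a hypothetical tree (leaf token ids are wrapped in <>), using anytree's default rendering:
#   NodeTree({100003: {2048: {7: 7}, 42: 42}}).print()
#   100003
#   ├── 2048
#   │   └── <7>
#   └── <42>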
def print_tree(tree):
def get_tree_node(tree, parent=None):
for key, value in tree.items():
pn = Node(str(key), parent)
if isinstance(value, dict):
get_tree_node(value, pn)
assert isinstance(tree, dict)
root = Node("root")
get_tree_node(tree, root)
treestr = ""
for pre, fill, node in RenderTree(root.children[0] if len(tree) == 1 else root):
treestr += f"{pre}{node.name}\n"
return treestr
# def get_tree_indexed_str(tree, data, prefix):
# if isinstance(tree, list):
# base = ""
# qlen = 0
# for i, t in enumerate(tree):
# s, l = MeaningMap.get_tree_indexed_str(t, data[i], "")
# base += s
# qlen += l
# return (base, qlen)
# else:
# if isinstance(tree, dict):
# base = ""
# qlen = 0
# last_is_dict = None
# for key, value in tree.items():
# new_prefix = (len(str(key)) + 2) * " " + prefix
# dict_string, l = MeaningMap.get_tree_indexed_str(value, data[qlen:], new_prefix)
# if dict_string:
# base += "\n" + prefix + str(key) + ": " + dict_string
# last_is_dict = True
# else:
# base += "\n" + prefix + str(data[qlen]) + " " if last_is_dict else str(data[qlen]) + " "
# last_is_dict = False
# qlen += l
# return (base, qlen)
# return (None, 1)
# def token_frequency(tree, freq):
# if isinstance(tree, dict):
# for key, value in tree.items():
# if key in freq:
# freq[key] = freq[key] + 1
# else:
# freq[key] = 1
# MeaningMap.token_frequency(value, freq)

View File

@ -8,6 +8,7 @@ import numpy as np
import configuration
import dataset.dataset as ds
import dataset.node_tree as nt
if __name__ == "__main__":
@ -45,12 +46,15 @@ if __name__ == "__main__":
runner = QwenRunner(qwen.llm)
val = ds.InitValDataset(conf).dataset
data = val.dataset
item = data.get_token(0)
print(data.print_tree(0))
md = val.meaning_dataset
batch = torch.tensor([item[:-1]], dtype=torch.int64)
map = md.get_meaning_map()
item = md.get_token(0)
nt.NodeTree(map.get_tree(md.get_meaning(0))).print()
batch = torch.tensor([item[:-16]], dtype=torch.int64)
batch = batch.cuda()
print(item)
# print(item)
next_token = runner.ChatToken(batch)
print(next_token.detach().cpu().numpy())