Refine meaning dataset map.

This commit is contained in:
Colin 2025-02-22 16:50:16 +08:00
parent f0469b351c
commit 3a7ce45654
3 changed files with 119 additions and 106 deletions

View File

@ -183,66 +183,18 @@ class MeaningMap:
) )
def get_tree(self, meaning): # return meaning all sub items def get_tree(self, meaning): # return meaning all sub items
tree = {} def get_tree_dict(ms_map, vocab_size, meaning): # return meaning all sub items
ms = self.ms_map[meaning] tree = {}
for m in ms[ms > 0].tolist(): ms = ms_map[meaning]
tree[m] = self.get_tree(m) if m >= self.vocab_size else m for m in ms[ms > 0].tolist():
return tree tree[m] = get_tree_dict(ms_map, vocab_size, m) if m >= self.vocab_size else m
return tree
return get_tree_dict(self.ms_map, self.vocab_size, meaning)
def max_length(self): def max_length(self):
return max(self.ms_len) return max(self.ms_len)
def print_tree(tree):
def get_tree_node(tree, parent=None):
for key, value in tree.items():
pn = Node(str(key), parent)
if isinstance(value, dict):
get_tree_node(value, pn)
assert isinstance(tree, dict)
root = Node("root")
get_tree_node(tree, root)
treestr = ""
for pre, fill, node in RenderTree(root.children[0] if len(tree) == 1 else root):
treestr += f"{pre}{node.name}\n"
return treestr
def get_tree_indexed_str(tree, data, prefix):
if isinstance(tree, list):
base = ""
qlen = 0
for i, t in enumerate(tree):
s, l = MeaningMap.get_tree_indexed_str(t, data[i], "")
base += s
qlen += l
return (base, qlen)
else:
if isinstance(tree, dict):
base = ""
qlen = 0
last_is_dict = None
for key, value in tree.items():
new_prefix = (len(str(key)) + 2) * " " + prefix
dict_string, l = MeaningMap.get_tree_indexed_str(value, data[qlen:], new_prefix)
if dict_string:
base += "\n" + prefix + str(key) + ": " + dict_string
last_is_dict = True
else:
base += "\n" + prefix + str(data[qlen]) + " " if last_is_dict else str(data[qlen]) + " "
last_is_dict = False
qlen += l
return (base, qlen)
return (None, 1)
def token_frequency(tree, freq):
if isinstance(tree, dict):
for key, value in tree.items():
if key in freq:
freq[key] = freq[key] + 1
else:
freq[key] = 1
MeaningMap.token_frequency(value, freq)
class MeaningDataset(Dataset): class MeaningDataset(Dataset):
@ -259,17 +211,23 @@ class MeaningDataset(Dataset):
use_cache=True, use_cache=True,
): ):
np.random.seed(seed) np.random.seed(seed)
map = MeaningMap(end, vocab_size, max_subitem, min_subitem, use_cache=use_cache)
np.random.seed(seed) np.random.seed(seed)
self.start = start
self.end = end
self.vocab_size = vocab_size
self.max_subitem = max_subitem
self.min_subitem = min_subitem
self.use_cache = use_cache
self.min_seq_len = min_seq_len
print("Build MeaningDataset from MeaningMap.") print("Build MeaningDataset from MeaningMap.")
self.val_mask_level = None self.val_mask_level = None
self.val_mask_idx = None self.val_mask_idx = None
self.tree = []
self.seq = [] self.seq = []
self.level = [] self.level = []
self.rank_idx = [] self.rank_idx = []
self.rank_all = [] self.rank_all = []
self.seq_meaning = [] self.seq_meaning = []
map = self.get_meaning_map()
self.m_height = map.ms_height self.m_height = map.ms_height
self.m_weight = map.ms_weight self.m_weight = map.ms_weight
if size: if size:
@ -281,7 +239,6 @@ class MeaningDataset(Dataset):
for m in meanings: for m in meanings:
d, l, i, a = map.get_sequence(m) d, l, i, a = map.get_sequence(m)
if len(d) >= min_seq_len: if len(d) >= min_seq_len:
self.tree.append({m: map.get_tree(m)})
self.seq.append(d) self.seq.append(d)
self.level.append(l) self.level.append(l)
self.seq_meaning.append(m) self.seq_meaning.append(m)
@ -327,6 +284,9 @@ class MeaningDataset(Dataset):
def len(self): def len(self):
return len(self.seq) return len(self.seq)
def get_meaning_map(self):
return MeaningMap(self.end, self.vocab_size, self.max_subitem, self.min_subitem, self.use_cache)
def set_mask(self, level=None, idx=None): def set_mask(self, level=None, idx=None):
if self.val_mask_level is not None and self.val_mask_idx is not None: if self.val_mask_level is not None and self.val_mask_idx is not None:
assert len(self.val_mask_level) > 0, "len must > 0" assert len(self.val_mask_level) > 0, "len must > 0"
@ -346,26 +306,17 @@ class MeaningDataset(Dataset):
output["input_ids"] = data output["input_ids"] = data
output["labels"] = data.clone() output["labels"] = data.clone()
output["token_type_ids"] = torch.zeros(data.shape) output["token_type_ids"] = torch.zeros(data.shape)
output["tree"] = [self.tree[i] for i in idx_list]
output["val_mask"] = self.get_seq_mask_tensor(idx_list) output["val_mask"] = self.get_seq_mask_tensor(idx_list)
return output return output
def get_token(self, idx): # must equal sequence length def get_token(self, idx): # must equal sequence length
return self.seq[idx] return self.seq[idx]
def get_tree(self, idx): def get_meaning(self, idx):
return self.tree[idx] return self.seq_meaning[idx]
def print_tree(self, idx):
tokens = self.seq[idx]
tree = self.get_tree(idx)
s = str(tokens) + "\n"
s += MeaningMap.print_tree(tree)
return s
def copy(self, start, end): def copy(self, start, end):
new = copy.deepcopy(self) new = copy.deepcopy(self)
new.tree = new.tree[start:end]
new.seq = new.seq[start:end] new.seq = new.seq[start:end]
new.level = new.level[start:end] new.level = new.level[start:end]
new.rank_idx = new.rank_idx[start:end] new.rank_idx = new.rank_idx[start:end]
@ -380,12 +331,6 @@ class MeaningDataset(Dataset):
middle = int(l * ratio) middle = int(l * ratio)
return self.copy(0, middle), self.copy(middle, l) return self.copy(0, middle), self.copy(middle, l)
def token_frequency(self):
freq = {}
for t in self.tree:
MeaningMap.token_frequency(t, freq)
return freq
def get_seq_mask(self, idx, level, index): def get_seq_mask(self, idx, level, index):
# assert index < 15, "index must < 15" # assert index < 15, "index must < 15"
# assert level < 8, "level must < 8" # assert level < 8, "level must < 8"
@ -410,12 +355,12 @@ class MeaningDataset(Dataset):
class BatchGroupMeaningDataloader(Dataset): class BatchGroupMeaningDataloader(Dataset):
def __init__(self, dataset: MeaningDataset, batch_size, shuffle=True, drop_last=True): def __init__(self, meaning_dataset: MeaningDataset, batch_size, shuffle=True, drop_last=True):
self.dataset = dataset self.meaning_dataset = meaning_dataset
self.batch_size = batch_size self.batch_size = batch_size
self.drop_last = drop_last self.drop_last = drop_last
seq_len = [len(s) for s in dataset.seq] seq_len = [len(s) for s in meaning_dataset.seq]
unique, counts = np.unique(seq_len, return_counts=True) unique, counts = np.unique(seq_len, return_counts=True)
gl = {} gl = {}
for u in unique: for u in unique:
@ -451,16 +396,16 @@ class BatchGroupMeaningDataloader(Dataset):
return len(self.indexBatch) return len(self.indexBatch)
def __getitem__(self, idx): def __getitem__(self, idx):
return self.dataset.get_batch(self.indexBatch[idx]) return self.meaning_dataset.get_batch(self.indexBatch[idx])
def get_tree(self, idx): def get_tree(self, idx):
return [self.dataset.get_tree(i) for i in self.indexBatch[idx]] return [self.meaning_dataset.get_tree(i) for i in self.indexBatch[idx]]
def print_tree(self, idx): def print_tree(self, idx):
idx_list = self.indexBatch[idx] idx_list = self.indexBatch[idx]
s = "--------------------------------------------------------\n" s = "--------------------------------------------------------\n"
for i in idx_list: for i in idx_list:
s += self.dataset.print_tree(i) s += self.meaning_dataset.print_tree(i)
s += "--------------------------------------------------------\n" s += "--------------------------------------------------------\n"
return s return s
@ -556,7 +501,6 @@ if __name__ == "__main__":
), ),
) )
), "False" ), "False"
freq = md.token_frequency()
md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, min_subitem=2, use_cache=False) md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, min_subitem=2, use_cache=False)
md.set_mask([0, 1], [0, -1]) md.set_mask([0, 1], [0, -1])
@ -572,21 +516,3 @@ if __name__ == "__main__":
print(m) print(m)
ne2 = next(it) ne2 = next(it)
ne3 = next(it) ne3 = next(it)
# map1 = dl.get_tree(0)
# map2 = dl.get_tree(1)
# print(dl.print_tree(0))
# dl = DataLoader(
# train,
# num_workers=1,
# persistent_workers=True,
# shuffle=False,
# )
# it = iter(dl)
# ne1 = next(it)
# ne2 = next(it)
# ne3 = next(it)
# for i in range(10):
# print(next(it)["input_ids"].numpy().tolist())

83
wit/dataset/node_tree.py Normal file
View File

@ -0,0 +1,83 @@
from anytree import Node, RenderTree
class NodeTree:
    """Render a nested meaning-tree dict as an anytree node hierarchy."""

    def __init__(self, tree: dict):
        self.tree = tree      # nested dict: {meaning: subtree-dict or leaf token}
        self.node = None      # lazily built anytree root, cached after first get_node()
        self.seq_node = []    # leaf nodes collected in left-to-right sequence order

    def get_node(self):
        """Build (once) and return the anytree representation of ``self.tree``."""
        if self.node:
            return self.node

        def build(subtree, parent, leaves):
            # Dict values are inner meanings; anything else is a leaf token,
            # rendered as "<token>" and appended to the sequence list.
            for key, value in subtree.items():
                if isinstance(value, dict):
                    child = Node(str(key), parent)
                    build(value, child, leaves)
                else:
                    child = Node("<" + str(key) + ">", parent)
                    leaves.append(child)

        assert isinstance(self.tree, dict)
        root = Node("root")
        build(self.tree, root, self.seq_node)
        # When the dict has a single top-level meaning, promote it to the root.
        self.node = root.children[0] if len(self.tree) == 1 else root
        return self.node

    def print(self):
        """Pretty-print the tree structure to stdout."""
        rendered = "".join(
            f"{pre}{node.name}\n" for pre, fill, node in RenderTree(self.get_node())
        )
        print(rendered)
def print_tree(tree):
    """Return a printable string rendering of a nested meaning-tree dict.

    Unlike NodeTree.print, leaves are shown bare (no angle brackets) and the
    string is returned instead of printed.
    """

    def attach(subtree, parent=None):
        # Every key becomes a node; recurse only into dict-valued children.
        for key, value in subtree.items():
            node = Node(str(key), parent)
            if isinstance(value, dict):
                attach(value, node)

    assert isinstance(tree, dict)
    root = Node("root")
    attach(tree, root)
    # Promote a lone top-level meaning to the displayed root.
    top = root.children[0] if len(tree) == 1 else root
    return "".join(f"{pre}{node.name}\n" for pre, fill, node in RenderTree(top))
# NOTE(review): the commented-out helpers below were carried over from the old
# MeaningMap module and still reference MeaningMap; they appear to be dead code —
# delete once confirmed unused by callers.
# def get_tree_indexed_str(tree, data, prefix):
# if isinstance(tree, list):
# base = ""
# qlen = 0
# for i, t in enumerate(tree):
# s, l = MeaningMap.get_tree_indexed_str(t, data[i], "")
# base += s
# qlen += l
# return (base, qlen)
# else:
# if isinstance(tree, dict):
# base = ""
# qlen = 0
# last_is_dict = None
# for key, value in tree.items():
# new_prefix = (len(str(key)) + 2) * " " + prefix
# dict_string, l = MeaningMap.get_tree_indexed_str(value, data[qlen:], new_prefix)
# if dict_string:
# base += "\n" + prefix + str(key) + ": " + dict_string
# last_is_dict = True
# else:
# base += "\n" + prefix + str(data[qlen]) + " " if last_is_dict else str(data[qlen]) + " "
# last_is_dict = False
# qlen += l
# return (base, qlen)
# return (None, 1)
# def token_frequency(tree, freq):
# if isinstance(tree, dict):
# for key, value in tree.items():
# if key in freq:
# freq[key] = freq[key] + 1
# else:
# freq[key] = 1
# MeaningMap.token_frequency(value, freq)

View File

@ -8,6 +8,7 @@ import numpy as np
import configuration import configuration
import dataset.dataset as ds import dataset.dataset as ds
import dataset.node_tree as nt
if __name__ == "__main__": if __name__ == "__main__":
@ -45,12 +46,15 @@ if __name__ == "__main__":
runner = QwenRunner(qwen.llm) runner = QwenRunner(qwen.llm)
val = ds.InitValDataset(conf).dataset val = ds.InitValDataset(conf).dataset
data = val.dataset md = val.meaning_dataset
item = data.get_token(0)
print(data.print_tree(0))
batch = torch.tensor([item[:-1]], dtype=torch.int64) map = md.get_meaning_map()
item = md.get_token(0)
nt.NodeTree(map.get_tree(md.get_meaning(0))).print()
batch = torch.tensor([item[:-16]], dtype=torch.int64)
batch = batch.cuda() batch = batch.cuda()
print(item) # print(item)
next_token = runner.ChatToken(batch) next_token = runner.ChatToken(batch)
print(next_token.detach().cpu().numpy()) print(next_token.detach().cpu().numpy())