Refine meaning dataset map.
parent f0469b351c
commit 3a7ce45654
@@ -183,66 +183,18 @@ class MeaningMap:
         )

     def get_tree(self, meaning):  # return meaning all sub items
-        tree = {}
-        ms = self.ms_map[meaning]
-        for m in ms[ms > 0].tolist():
-            tree[m] = self.get_tree(m) if m >= self.vocab_size else m
-        return tree
+        def get_tree_dict(ms_map, vocab_size, meaning):  # return meaning all sub items
+            tree = {}
+            ms = ms_map[meaning]
+            for m in ms[ms > 0].tolist():
+                tree[m] = get_tree_dict(ms_map, vocab_size, m) if m >= self.vocab_size else m
+            return tree
+
+        return get_tree_dict(self.ms_map, self.vocab_size, meaning)

     def max_length(self):
         return max(self.ms_len)

-    def print_tree(tree):
-        def get_tree_node(tree, parent=None):
-            for key, value in tree.items():
-                pn = Node(str(key), parent)
-                if isinstance(value, dict):
-                    get_tree_node(value, pn)
-
-        assert isinstance(tree, dict)
-        root = Node("root")
-        get_tree_node(tree, root)
-        treestr = ""
-        for pre, fill, node in RenderTree(root.children[0] if len(tree) == 1 else root):
-            treestr += f"{pre}{node.name}\n"
-        return treestr
-
-    def get_tree_indexed_str(tree, data, prefix):
-        if isinstance(tree, list):
-            base = ""
-            qlen = 0
-            for i, t in enumerate(tree):
-                s, l = MeaningMap.get_tree_indexed_str(t, data[i], "")
-                base += s
-                qlen += l
-            return (base, qlen)
-        else:
-            if isinstance(tree, dict):
-                base = ""
-                qlen = 0
-                last_is_dict = None
-                for key, value in tree.items():
-                    new_prefix = (len(str(key)) + 2) * " " + prefix
-                    dict_string, l = MeaningMap.get_tree_indexed_str(value, data[qlen:], new_prefix)
-                    if dict_string:
-                        base += "\n" + prefix + str(key) + ": " + dict_string
-                        last_is_dict = True
-                    else:
-                        base += "\n" + prefix + str(data[qlen]) + " " if last_is_dict else str(data[qlen]) + " "
-                        last_is_dict = False
-                    qlen += l
-                return (base, qlen)
-            return (None, 1)
-
-    def token_frequency(tree, freq):
-        if isinstance(tree, dict):
-            for key, value in tree.items():
-                if key in freq:
-                    freq[key] = freq[key] + 1
-                else:
-                    freq[key] = 1
-                MeaningMap.token_frequency(value, freq)


 class MeaningDataset(Dataset):
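For illustration, a minimal standalone sketch of the nested-dict shape that get_tree / get_tree_dict produce, assuming a toy ms_map of zero-padded numpy rows and a toy vocab_size (values are hypothetical, not from this commit). Ids at or above vocab_size expand recursively; smaller ids stay as leaf tokens.

import numpy as np

# Toy meaning map: ids >= vocab_size are meanings that expand further,
# ids < vocab_size are leaf tokens, zeros are padding.
vocab_size = 4
ms_map = {
    6: np.array([5, 2, 0]),  # meaning 6 -> sub-meaning 5, token 2
    5: np.array([1, 3, 0]),  # meaning 5 -> tokens 1 and 3
}

def get_tree_dict(ms_map, vocab_size, meaning):
    tree = {}
    ms = ms_map[meaning]
    for m in ms[ms > 0].tolist():
        tree[m] = get_tree_dict(ms_map, vocab_size, m) if m >= vocab_size else m
    return tree

print(get_tree_dict(ms_map, vocab_size, 6))  # {5: {1: 1, 3: 3}, 2: 2}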
@@ -259,17 +211,23 @@ class MeaningDataset(Dataset):
        use_cache=True,
    ):
        np.random.seed(seed)
        map = MeaningMap(end, vocab_size, max_subitem, min_subitem, use_cache=use_cache)
        np.random.seed(seed)
        self.start = start
        self.end = end
        self.vocab_size = vocab_size
        self.max_subitem = max_subitem
        self.min_subitem = min_subitem
        self.use_cache = use_cache
        self.min_seq_len = min_seq_len
        print("Build MeaningDataset from MeaningMap.")
        self.val_mask_level = None
        self.val_mask_idx = None
        self.tree = []
        self.seq = []
        self.level = []
        self.rank_idx = []
        self.rank_all = []
        self.seq_meaning = []
        map = self.get_meaning_map()
        self.m_height = map.ms_height
        self.m_weight = map.ms_weight
        if size:
@@ -281,7 +239,6 @@ class MeaningDataset(Dataset):
        for m in meanings:
            d, l, i, a = map.get_sequence(m)
            if len(d) >= min_seq_len:
                self.tree.append({m: map.get_tree(m)})
                self.seq.append(d)
                self.level.append(l)
                self.seq_meaning.append(m)
@@ -327,6 +284,9 @@ class MeaningDataset(Dataset):
     def len(self):
         return len(self.seq)

+    def get_meaning_map(self):
+        return MeaningMap(self.end, self.vocab_size, self.max_subitem, self.min_subitem, self.use_cache)
+
     def set_mask(self, level=None, idx=None):
         if self.val_mask_level is not None and self.val_mask_idx is not None:
             assert len(self.val_mask_level) > 0, "len must > 0"
@@ -346,26 +306,17 @@ class MeaningDataset(Dataset):
        output["input_ids"] = data
        output["labels"] = data.clone()
        output["token_type_ids"] = torch.zeros(data.shape)
        output["tree"] = [self.tree[i] for i in idx_list]
        output["val_mask"] = self.get_seq_mask_tensor(idx_list)
        return output

    def get_token(self, idx):  # must equal sequence length
        return self.seq[idx]

    def get_tree(self, idx):
        return self.tree[idx]

    def print_tree(self, idx):
        tokens = self.seq[idx]
        tree = self.get_tree(idx)
        s = str(tokens) + "\n"
        s += MeaningMap.print_tree(tree)
        return s
    def get_meaning(self, idx):
        return self.seq_meaning[idx]

    def copy(self, start, end):
        new = copy.deepcopy(self)
        new.tree = new.tree[start:end]
        new.seq = new.seq[start:end]
        new.level = new.level[start:end]
        new.rank_idx = new.rank_idx[start:end]
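For context, a minimal sketch of the batch dict that get_batch assembles for a list of equal-length sequences; the toy seq and idx_list values here are hypothetical, and the real method also attaches "tree" and "val_mask" entries.

import torch

# Two equal-length token sequences standing in for self.seq entries.
seq = [[3, 1, 2, 7], [5, 5, 1, 2]]
idx_list = [0, 1]

# Stack the selected sequences and build the same keys get_batch fills in.
data = torch.tensor([seq[i] for i in idx_list], dtype=torch.int64)
output = {
    "input_ids": data,
    "labels": data.clone(),
    "token_type_ids": torch.zeros(data.shape),
}
print(output["input_ids"].shape)  # torch.Size([2, 4])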
@@ -380,12 +331,6 @@ class MeaningDataset(Dataset):
         middle = int(l * ratio)
         return self.copy(0, middle), self.copy(middle, l)

-    def token_frequency(self):
-        freq = {}
-        for t in self.tree:
-            MeaningMap.token_frequency(t, freq)
-        return freq
-
     def get_seq_mask(self, idx, level, index):
         # assert index < 15, "index must < 15"
         # assert level < 8, "level must < 8"
@@ -410,12 +355,12 @@


 class BatchGroupMeaningDataloader(Dataset):
-    def __init__(self, dataset: MeaningDataset, batch_size, shuffle=True, drop_last=True):
-        self.dataset = dataset
+    def __init__(self, meaning_dataset: MeaningDataset, batch_size, shuffle=True, drop_last=True):
+        self.meaning_dataset = meaning_dataset
         self.batch_size = batch_size
         self.drop_last = drop_last

-        seq_len = [len(s) for s in dataset.seq]
+        seq_len = [len(s) for s in meaning_dataset.seq]
         unique, counts = np.unique(seq_len, return_counts=True)
         gl = {}
         for u in unique:
@@ -451,16 +396,16 @@ class BatchGroupMeaningDataloader(Dataset):
         return len(self.indexBatch)

     def __getitem__(self, idx):
-        return self.dataset.get_batch(self.indexBatch[idx])
+        return self.meaning_dataset.get_batch(self.indexBatch[idx])

     def get_tree(self, idx):
-        return [self.dataset.get_tree(i) for i in self.indexBatch[idx]]
+        return [self.meaning_dataset.get_tree(i) for i in self.indexBatch[idx]]

     def print_tree(self, idx):
         idx_list = self.indexBatch[idx]
         s = "--------------------------------------------------------\n"
         for i in idx_list:
-            s += self.dataset.print_tree(i)
+            s += self.meaning_dataset.print_tree(i)
             s += "--------------------------------------------------------\n"
         return s
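The dataloader batches sequences of identical length together. A minimal sketch of that bucketing idea follows, with toy lengths and without the shuffle handling of the real class (names here are illustrative, not the commit's exact implementation).

import numpy as np

seq_len = [4, 6, 4, 4, 6]  # len(s) for each dataset sequence
batch_size = 2

# Bucket sequence indices by length, then chunk each bucket into batches.
index_batches = []
for u in np.unique(seq_len):
    idx = np.where(np.array(seq_len) == u)[0]
    # leftover indices that do not fill a whole batch are dropped (drop_last)
    for start in range(0, len(idx) - batch_size + 1, batch_size):
        index_batches.append(idx[start:start + batch_size].tolist())

print(index_batches)  # [[0, 2], [1, 4]]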
@@ -556,7 +501,6 @@ if __name__ == "__main__":
             ),
         )
     ), "False"
-    freq = md.token_frequency()

     md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, min_subitem=2, use_cache=False)
     md.set_mask([0, 1], [0, -1])
@@ -572,21 +516,3 @@ if __name__ == "__main__":
     print(m)
     ne2 = next(it)
     ne3 = next(it)
-
-    # map1 = dl.get_tree(0)
-    # map2 = dl.get_tree(1)
-    # print(dl.print_tree(0))
-
-    # dl = DataLoader(
-    #     train,
-    #     num_workers=1,
-    #     persistent_workers=True,
-    #     shuffle=False,
-    # )
-    # it = iter(dl)
-    # ne1 = next(it)
-    # ne2 = next(it)
-    # ne3 = next(it)
-
-    # for i in range(10):
-    #     print(next(it)["input_ids"].numpy().tolist())
@@ -0,0 +1,83 @@
from anytree import Node, RenderTree


class NodeTree:
    def __init__(self, tree: dict):
        self.tree = tree
        self.node = None
        self.seq_node = []

    def get_node(self):
        def get_tree_node(tree, parent, seqlist):
            for key, value in tree.items():
                if isinstance(value, dict):
                    pn = Node(str(key), parent)
                    get_tree_node(value, pn, seqlist)
                else:
                    pn = Node("<" + str(key) + ">", parent)
                    seqlist.append(pn)

        if self.node:
            return self.node
        assert isinstance(self.tree, dict)
        root = Node("root")
        get_tree_node(self.tree, root, self.seq_node)
        self.node = root.children[0] if len(self.tree) == 1 else root
        return self.node

    def print(self):
        treestr = ""
        for pre, fill, node in RenderTree(self.get_node()):
            treestr += f"{pre}{node.name}\n"
        print(treestr)


def print_tree(tree):
    def get_tree_node(tree, parent=None):
        for key, value in tree.items():
            pn = Node(str(key), parent)
            if isinstance(value, dict):
                get_tree_node(value, pn)

    assert isinstance(tree, dict)
    root = Node("root")
    get_tree_node(tree, root)
    treestr = ""
    for pre, fill, node in RenderTree(root.children[0] if len(tree) == 1 else root):
        treestr += f"{pre}{node.name}\n"
    return treestr


# def get_tree_indexed_str(tree, data, prefix):
#     if isinstance(tree, list):
#         base = ""
#         qlen = 0
#         for i, t in enumerate(tree):
#             s, l = MeaningMap.get_tree_indexed_str(t, data[i], "")
#             base += s
#             qlen += l
#         return (base, qlen)
#     else:
#         if isinstance(tree, dict):
#             base = ""
#             qlen = 0
#             last_is_dict = None
#             for key, value in tree.items():
#                 new_prefix = (len(str(key)) + 2) * " " + prefix
#                 dict_string, l = MeaningMap.get_tree_indexed_str(value, data[qlen:], new_prefix)
#                 if dict_string:
#                     base += "\n" + prefix + str(key) + ": " + dict_string
#                     last_is_dict = True
#                 else:
#                     base += "\n" + prefix + str(data[qlen]) + " " if last_is_dict else str(data[qlen]) + " "
#                     last_is_dict = False
#                 qlen += l
#             return (base, qlen)
#         return (None, 1)

# def token_frequency(tree, freq):
#     if isinstance(tree, dict):
#         for key, value in tree.items():
#             if key in freq:
#                 freq[key] = freq[key] + 1
#             else:
#                 freq[key] = 1
#             MeaningMap.token_frequency(value, freq)
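A small self-contained sketch that mirrors what NodeTree.get_node and NodeTree.print do with anytree, using a hypothetical nested dict in the shape MeaningMap.get_tree returns (not code from this commit).

from anytree import Node, RenderTree

# Nested dict: inner dicts are meanings, plain values are leaf tokens.
tree = {9: {5: 5, 3: 3}}

def build(d, parent):
    for key, value in d.items():
        if isinstance(value, dict):
            node = Node(str(key), parent)
            build(value, node)
        else:
            Node("<" + str(key) + ">", parent)  # leaves rendered in angle brackets

root = Node("root")
build(tree, root)
for pre, _, node in RenderTree(root.children[0] if len(tree) == 1 else root):
    print(f"{pre}{node.name}")
# 9
# ├── <5>
# └── <3>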
@@ -8,6 +8,7 @@ import numpy as np

 import configuration
 import dataset.dataset as ds
+import dataset.node_tree as nt

 if __name__ == "__main__":
@@ -45,12 +46,15 @@ if __name__ == "__main__":
     runner = QwenRunner(qwen.llm)

     val = ds.InitValDataset(conf).dataset
-    data = val.dataset
-    item = data.get_token(0)
-    print(data.print_tree(0))
+    md = val.meaning_dataset

-    batch = torch.tensor([item[:-1]], dtype=torch.int64)
+    map = md.get_meaning_map()

+    item = md.get_token(0)
+    nt.NodeTree(map.get_tree(md.get_meaning(0))).print()

+    batch = torch.tensor([item[:-16]], dtype=torch.int64)
     batch = batch.cuda()
-    print(item)
+    # print(item)
     next_token = runner.ChatToken(batch)
     print(next_token.detach().cpu().numpy())