Refine meaning dataset map.
This commit is contained in:
parent
f0469b351c
commit
3a7ce45654
|
@ -183,66 +183,18 @@ class MeaningMap:
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_tree(self, meaning): # return meaning all sub items
|
def get_tree(self, meaning): # return meaning all sub items
|
||||||
|
def get_tree_dict(ms_map, vocab_size, meaning): # return meaning all sub items
|
||||||
tree = {}
|
tree = {}
|
||||||
ms = self.ms_map[meaning]
|
ms = ms_map[meaning]
|
||||||
for m in ms[ms > 0].tolist():
|
for m in ms[ms > 0].tolist():
|
||||||
tree[m] = self.get_tree(m) if m >= self.vocab_size else m
|
tree[m] = get_tree_dict(ms_map, vocab_size, m) if m >= self.vocab_size else m
|
||||||
return tree
|
return tree
|
||||||
|
|
||||||
|
return get_tree_dict(self.ms_map, self.vocab_size, meaning)
|
||||||
|
|
||||||
def max_length(self):
|
def max_length(self):
|
||||||
return max(self.ms_len)
|
return max(self.ms_len)
|
||||||
|
|
||||||
def print_tree(tree):
|
|
||||||
def get_tree_node(tree, parent=None):
|
|
||||||
for key, value in tree.items():
|
|
||||||
pn = Node(str(key), parent)
|
|
||||||
if isinstance(value, dict):
|
|
||||||
get_tree_node(value, pn)
|
|
||||||
|
|
||||||
assert isinstance(tree, dict)
|
|
||||||
root = Node("root")
|
|
||||||
get_tree_node(tree, root)
|
|
||||||
treestr = ""
|
|
||||||
for pre, fill, node in RenderTree(root.children[0] if len(tree) == 1 else root):
|
|
||||||
treestr += f"{pre}{node.name}\n"
|
|
||||||
return treestr
|
|
||||||
|
|
||||||
def get_tree_indexed_str(tree, data, prefix):
|
|
||||||
if isinstance(tree, list):
|
|
||||||
base = ""
|
|
||||||
qlen = 0
|
|
||||||
for i, t in enumerate(tree):
|
|
||||||
s, l = MeaningMap.get_tree_indexed_str(t, data[i], "")
|
|
||||||
base += s
|
|
||||||
qlen += l
|
|
||||||
return (base, qlen)
|
|
||||||
else:
|
|
||||||
if isinstance(tree, dict):
|
|
||||||
base = ""
|
|
||||||
qlen = 0
|
|
||||||
last_is_dict = None
|
|
||||||
for key, value in tree.items():
|
|
||||||
new_prefix = (len(str(key)) + 2) * " " + prefix
|
|
||||||
dict_string, l = MeaningMap.get_tree_indexed_str(value, data[qlen:], new_prefix)
|
|
||||||
if dict_string:
|
|
||||||
base += "\n" + prefix + str(key) + ": " + dict_string
|
|
||||||
last_is_dict = True
|
|
||||||
else:
|
|
||||||
base += "\n" + prefix + str(data[qlen]) + " " if last_is_dict else str(data[qlen]) + " "
|
|
||||||
last_is_dict = False
|
|
||||||
qlen += l
|
|
||||||
return (base, qlen)
|
|
||||||
return (None, 1)
|
|
||||||
|
|
||||||
def token_frequency(tree, freq):
|
|
||||||
if isinstance(tree, dict):
|
|
||||||
for key, value in tree.items():
|
|
||||||
if key in freq:
|
|
||||||
freq[key] = freq[key] + 1
|
|
||||||
else:
|
|
||||||
freq[key] = 1
|
|
||||||
MeaningMap.token_frequency(value, freq)
|
|
||||||
|
|
||||||
|
|
||||||
class MeaningDataset(Dataset):
|
class MeaningDataset(Dataset):
|
||||||
|
|
||||||
|
@ -259,17 +211,23 @@ class MeaningDataset(Dataset):
|
||||||
use_cache=True,
|
use_cache=True,
|
||||||
):
|
):
|
||||||
np.random.seed(seed)
|
np.random.seed(seed)
|
||||||
map = MeaningMap(end, vocab_size, max_subitem, min_subitem, use_cache=use_cache)
|
|
||||||
np.random.seed(seed)
|
np.random.seed(seed)
|
||||||
|
self.start = start
|
||||||
|
self.end = end
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.max_subitem = max_subitem
|
||||||
|
self.min_subitem = min_subitem
|
||||||
|
self.use_cache = use_cache
|
||||||
|
self.min_seq_len = min_seq_len
|
||||||
print("Build MeaningDataset from MeaningMap.")
|
print("Build MeaningDataset from MeaningMap.")
|
||||||
self.val_mask_level = None
|
self.val_mask_level = None
|
||||||
self.val_mask_idx = None
|
self.val_mask_idx = None
|
||||||
self.tree = []
|
|
||||||
self.seq = []
|
self.seq = []
|
||||||
self.level = []
|
self.level = []
|
||||||
self.rank_idx = []
|
self.rank_idx = []
|
||||||
self.rank_all = []
|
self.rank_all = []
|
||||||
self.seq_meaning = []
|
self.seq_meaning = []
|
||||||
|
map = self.get_meaning_map()
|
||||||
self.m_height = map.ms_height
|
self.m_height = map.ms_height
|
||||||
self.m_weight = map.ms_weight
|
self.m_weight = map.ms_weight
|
||||||
if size:
|
if size:
|
||||||
|
@ -281,7 +239,6 @@ class MeaningDataset(Dataset):
|
||||||
for m in meanings:
|
for m in meanings:
|
||||||
d, l, i, a = map.get_sequence(m)
|
d, l, i, a = map.get_sequence(m)
|
||||||
if len(d) >= min_seq_len:
|
if len(d) >= min_seq_len:
|
||||||
self.tree.append({m: map.get_tree(m)})
|
|
||||||
self.seq.append(d)
|
self.seq.append(d)
|
||||||
self.level.append(l)
|
self.level.append(l)
|
||||||
self.seq_meaning.append(m)
|
self.seq_meaning.append(m)
|
||||||
|
@ -327,6 +284,9 @@ class MeaningDataset(Dataset):
|
||||||
def len(self):
|
def len(self):
|
||||||
return len(self.seq)
|
return len(self.seq)
|
||||||
|
|
||||||
|
def get_meaning_map(self):
|
||||||
|
return MeaningMap(self.end, self.vocab_size, self.max_subitem, self.min_subitem, self.use_cache)
|
||||||
|
|
||||||
def set_mask(self, level=None, idx=None):
|
def set_mask(self, level=None, idx=None):
|
||||||
if self.val_mask_level is not None and self.val_mask_idx is not None:
|
if self.val_mask_level is not None and self.val_mask_idx is not None:
|
||||||
assert len(self.val_mask_level) > 0, "len must > 0"
|
assert len(self.val_mask_level) > 0, "len must > 0"
|
||||||
|
@ -346,26 +306,17 @@ class MeaningDataset(Dataset):
|
||||||
output["input_ids"] = data
|
output["input_ids"] = data
|
||||||
output["labels"] = data.clone()
|
output["labels"] = data.clone()
|
||||||
output["token_type_ids"] = torch.zeros(data.shape)
|
output["token_type_ids"] = torch.zeros(data.shape)
|
||||||
output["tree"] = [self.tree[i] for i in idx_list]
|
|
||||||
output["val_mask"] = self.get_seq_mask_tensor(idx_list)
|
output["val_mask"] = self.get_seq_mask_tensor(idx_list)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def get_token(self, idx): # must equal sequence length
|
def get_token(self, idx): # must equal sequence length
|
||||||
return self.seq[idx]
|
return self.seq[idx]
|
||||||
|
|
||||||
def get_tree(self, idx):
|
def get_meaning(self, idx):
|
||||||
return self.tree[idx]
|
return self.seq_meaning[idx]
|
||||||
|
|
||||||
def print_tree(self, idx):
|
|
||||||
tokens = self.seq[idx]
|
|
||||||
tree = self.get_tree(idx)
|
|
||||||
s = str(tokens) + "\n"
|
|
||||||
s += MeaningMap.print_tree(tree)
|
|
||||||
return s
|
|
||||||
|
|
||||||
def copy(self, start, end):
|
def copy(self, start, end):
|
||||||
new = copy.deepcopy(self)
|
new = copy.deepcopy(self)
|
||||||
new.tree = new.tree[start:end]
|
|
||||||
new.seq = new.seq[start:end]
|
new.seq = new.seq[start:end]
|
||||||
new.level = new.level[start:end]
|
new.level = new.level[start:end]
|
||||||
new.rank_idx = new.rank_idx[start:end]
|
new.rank_idx = new.rank_idx[start:end]
|
||||||
|
@ -380,12 +331,6 @@ class MeaningDataset(Dataset):
|
||||||
middle = int(l * ratio)
|
middle = int(l * ratio)
|
||||||
return self.copy(0, middle), self.copy(middle, l)
|
return self.copy(0, middle), self.copy(middle, l)
|
||||||
|
|
||||||
def token_frequency(self):
|
|
||||||
freq = {}
|
|
||||||
for t in self.tree:
|
|
||||||
MeaningMap.token_frequency(t, freq)
|
|
||||||
return freq
|
|
||||||
|
|
||||||
def get_seq_mask(self, idx, level, index):
|
def get_seq_mask(self, idx, level, index):
|
||||||
# assert index < 15, "index must < 15"
|
# assert index < 15, "index must < 15"
|
||||||
# assert level < 8, "level must < 8"
|
# assert level < 8, "level must < 8"
|
||||||
|
@ -410,12 +355,12 @@ class MeaningDataset(Dataset):
|
||||||
|
|
||||||
|
|
||||||
class BatchGroupMeaningDataloader(Dataset):
|
class BatchGroupMeaningDataloader(Dataset):
|
||||||
def __init__(self, dataset: MeaningDataset, batch_size, shuffle=True, drop_last=True):
|
def __init__(self, meaning_dataset: MeaningDataset, batch_size, shuffle=True, drop_last=True):
|
||||||
self.dataset = dataset
|
self.meaning_dataset = meaning_dataset
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
self.drop_last = drop_last
|
self.drop_last = drop_last
|
||||||
|
|
||||||
seq_len = [len(s) for s in dataset.seq]
|
seq_len = [len(s) for s in meaning_dataset.seq]
|
||||||
unique, counts = np.unique(seq_len, return_counts=True)
|
unique, counts = np.unique(seq_len, return_counts=True)
|
||||||
gl = {}
|
gl = {}
|
||||||
for u in unique:
|
for u in unique:
|
||||||
|
@ -451,16 +396,16 @@ class BatchGroupMeaningDataloader(Dataset):
|
||||||
return len(self.indexBatch)
|
return len(self.indexBatch)
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
def __getitem__(self, idx):
|
||||||
return self.dataset.get_batch(self.indexBatch[idx])
|
return self.meaning_dataset.get_batch(self.indexBatch[idx])
|
||||||
|
|
||||||
def get_tree(self, idx):
|
def get_tree(self, idx):
|
||||||
return [self.dataset.get_tree(i) for i in self.indexBatch[idx]]
|
return [self.meaning_dataset.get_tree(i) for i in self.indexBatch[idx]]
|
||||||
|
|
||||||
def print_tree(self, idx):
|
def print_tree(self, idx):
|
||||||
idx_list = self.indexBatch[idx]
|
idx_list = self.indexBatch[idx]
|
||||||
s = "--------------------------------------------------------\n"
|
s = "--------------------------------------------------------\n"
|
||||||
for i in idx_list:
|
for i in idx_list:
|
||||||
s += self.dataset.print_tree(i)
|
s += self.meaning_dataset.print_tree(i)
|
||||||
s += "--------------------------------------------------------\n"
|
s += "--------------------------------------------------------\n"
|
||||||
return s
|
return s
|
||||||
|
|
||||||
|
@ -556,7 +501,6 @@ if __name__ == "__main__":
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
), "False"
|
), "False"
|
||||||
freq = md.token_frequency()
|
|
||||||
|
|
||||||
md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, min_subitem=2, use_cache=False)
|
md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, min_subitem=2, use_cache=False)
|
||||||
md.set_mask([0, 1], [0, -1])
|
md.set_mask([0, 1], [0, -1])
|
||||||
|
@ -572,21 +516,3 @@ if __name__ == "__main__":
|
||||||
print(m)
|
print(m)
|
||||||
ne2 = next(it)
|
ne2 = next(it)
|
||||||
ne3 = next(it)
|
ne3 = next(it)
|
||||||
|
|
||||||
# map1 = dl.get_tree(0)
|
|
||||||
# map2 = dl.get_tree(1)
|
|
||||||
# print(dl.print_tree(0))
|
|
||||||
|
|
||||||
# dl = DataLoader(
|
|
||||||
# train,
|
|
||||||
# num_workers=1,
|
|
||||||
# persistent_workers=True,
|
|
||||||
# shuffle=False,
|
|
||||||
# )
|
|
||||||
# it = iter(dl)
|
|
||||||
# ne1 = next(it)
|
|
||||||
# ne2 = next(it)
|
|
||||||
# ne3 = next(it)
|
|
||||||
|
|
||||||
# for i in range(10):
|
|
||||||
# print(next(it)["input_ids"].numpy().tolist())
|
|
||||||
|
|
|
@ -0,0 +1,83 @@
|
||||||
|
from anytree import Node, RenderTree
|
||||||
|
|
||||||
|
|
||||||
|
class NodeTree:
|
||||||
|
def __init__(self, tree: dict):
|
||||||
|
self.tree = tree
|
||||||
|
self.node = None
|
||||||
|
self.seq_node = []
|
||||||
|
|
||||||
|
def get_node(self):
|
||||||
|
def get_tree_node(tree, parent, seqlist):
|
||||||
|
for key, value in tree.items():
|
||||||
|
if isinstance(value, dict):
|
||||||
|
pn = Node(str(key), parent)
|
||||||
|
get_tree_node(value, pn, seqlist)
|
||||||
|
else:
|
||||||
|
pn = Node("<" + str(key) + ">", parent)
|
||||||
|
seqlist.append(pn)
|
||||||
|
|
||||||
|
if self.node:
|
||||||
|
return self.node
|
||||||
|
assert isinstance(self.tree, dict)
|
||||||
|
root = Node("root")
|
||||||
|
get_tree_node(self.tree, root, self.seq_node)
|
||||||
|
self.node = root.children[0] if len(self.tree) == 1 else root
|
||||||
|
return self.node
|
||||||
|
|
||||||
|
def print(self):
|
||||||
|
treestr = ""
|
||||||
|
for pre, fill, node in RenderTree(self.get_node()):
|
||||||
|
treestr += f"{pre}{node.name}\n"
|
||||||
|
print(treestr)
|
||||||
|
|
||||||
|
def print_tree(tree):
|
||||||
|
def get_tree_node(tree, parent=None):
|
||||||
|
for key, value in tree.items():
|
||||||
|
pn = Node(str(key), parent)
|
||||||
|
if isinstance(value, dict):
|
||||||
|
get_tree_node(value, pn)
|
||||||
|
|
||||||
|
assert isinstance(tree, dict)
|
||||||
|
root = Node("root")
|
||||||
|
get_tree_node(tree, root)
|
||||||
|
treestr = ""
|
||||||
|
for pre, fill, node in RenderTree(root.children[0] if len(tree) == 1 else root):
|
||||||
|
treestr += f"{pre}{node.name}\n"
|
||||||
|
return treestr
|
||||||
|
|
||||||
|
# def get_tree_indexed_str(tree, data, prefix):
|
||||||
|
# if isinstance(tree, list):
|
||||||
|
# base = ""
|
||||||
|
# qlen = 0
|
||||||
|
# for i, t in enumerate(tree):
|
||||||
|
# s, l = MeaningMap.get_tree_indexed_str(t, data[i], "")
|
||||||
|
# base += s
|
||||||
|
# qlen += l
|
||||||
|
# return (base, qlen)
|
||||||
|
# else:
|
||||||
|
# if isinstance(tree, dict):
|
||||||
|
# base = ""
|
||||||
|
# qlen = 0
|
||||||
|
# last_is_dict = None
|
||||||
|
# for key, value in tree.items():
|
||||||
|
# new_prefix = (len(str(key)) + 2) * " " + prefix
|
||||||
|
# dict_string, l = MeaningMap.get_tree_indexed_str(value, data[qlen:], new_prefix)
|
||||||
|
# if dict_string:
|
||||||
|
# base += "\n" + prefix + str(key) + ": " + dict_string
|
||||||
|
# last_is_dict = True
|
||||||
|
# else:
|
||||||
|
# base += "\n" + prefix + str(data[qlen]) + " " if last_is_dict else str(data[qlen]) + " "
|
||||||
|
# last_is_dict = False
|
||||||
|
# qlen += l
|
||||||
|
# return (base, qlen)
|
||||||
|
# return (None, 1)
|
||||||
|
|
||||||
|
# def token_frequency(tree, freq):
|
||||||
|
# if isinstance(tree, dict):
|
||||||
|
# for key, value in tree.items():
|
||||||
|
# if key in freq:
|
||||||
|
# freq[key] = freq[key] + 1
|
||||||
|
# else:
|
||||||
|
# freq[key] = 1
|
||||||
|
# MeaningMap.token_frequency(value, freq)
|
|
@ -8,6 +8,7 @@ import numpy as np
|
||||||
|
|
||||||
import configuration
|
import configuration
|
||||||
import dataset.dataset as ds
|
import dataset.dataset as ds
|
||||||
|
import dataset.node_tree as nt
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
||||||
|
@ -45,12 +46,15 @@ if __name__ == "__main__":
|
||||||
runner = QwenRunner(qwen.llm)
|
runner = QwenRunner(qwen.llm)
|
||||||
|
|
||||||
val = ds.InitValDataset(conf).dataset
|
val = ds.InitValDataset(conf).dataset
|
||||||
data = val.dataset
|
md = val.meaning_dataset
|
||||||
item = data.get_token(0)
|
|
||||||
print(data.print_tree(0))
|
|
||||||
|
|
||||||
batch = torch.tensor([item[:-1]], dtype=torch.int64)
|
map = md.get_meaning_map()
|
||||||
|
|
||||||
|
item = md.get_token(0)
|
||||||
|
nt.NodeTree(map.get_tree(md.get_meaning(0))).print()
|
||||||
|
|
||||||
|
batch = torch.tensor([item[:-16]], dtype=torch.int64)
|
||||||
batch = batch.cuda()
|
batch = batch.cuda()
|
||||||
print(item)
|
# print(item)
|
||||||
next_token = runner.ChatToken(batch)
|
next_token = runner.ChatToken(batch)
|
||||||
print(next_token.detach().cpu().numpy())
|
print(next_token.detach().cpu().numpy())
|
||||||
|
|
Loading…
Reference in New Issue