Refine meaning dataset map.
This commit is contained in:
		
							parent
							
								
									f0469b351c
								
							
						
					
					
						commit
						3a7ce45654
					
				| 
						 | 
					@ -183,66 +183,18 @@ class MeaningMap:
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get_tree(self, meaning):  # return meaning all sub items
 | 
					    def get_tree(self, meaning):  # return meaning all sub items
 | 
				
			||||||
        tree = {}
 | 
					        def get_tree_dict(ms_map, vocab_size, meaning):  # return meaning all sub items
 | 
				
			||||||
        ms = self.ms_map[meaning]
 | 
					            tree = {}
 | 
				
			||||||
        for m in ms[ms > 0].tolist():
 | 
					            ms = ms_map[meaning]
 | 
				
			||||||
            tree[m] = self.get_tree(m) if m >= self.vocab_size else m
 | 
					            for m in ms[ms > 0].tolist():
 | 
				
			||||||
        return tree
 | 
					                tree[m] = get_tree_dict(ms_map, vocab_size, m) if m >= self.vocab_size else m
 | 
				
			||||||
 | 
					            return tree
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return get_tree_dict(self.ms_map, self.vocab_size, meaning)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def max_length(self):
 | 
					    def max_length(self):
 | 
				
			||||||
        return max(self.ms_len)
 | 
					        return max(self.ms_len)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def print_tree(tree):
 | 
					 | 
				
			||||||
        def get_tree_node(tree, parent=None):
 | 
					 | 
				
			||||||
            for key, value in tree.items():
 | 
					 | 
				
			||||||
                pn = Node(str(key), parent)
 | 
					 | 
				
			||||||
                if isinstance(value, dict):
 | 
					 | 
				
			||||||
                    get_tree_node(value, pn)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        assert isinstance(tree, dict)
 | 
					 | 
				
			||||||
        root = Node("root")
 | 
					 | 
				
			||||||
        get_tree_node(tree, root)
 | 
					 | 
				
			||||||
        treestr = ""
 | 
					 | 
				
			||||||
        for pre, fill, node in RenderTree(root.children[0] if len(tree) == 1 else root):
 | 
					 | 
				
			||||||
            treestr += f"{pre}{node.name}\n"
 | 
					 | 
				
			||||||
        return treestr
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def get_tree_indexed_str(tree, data, prefix):
 | 
					 | 
				
			||||||
        if isinstance(tree, list):
 | 
					 | 
				
			||||||
            base = ""
 | 
					 | 
				
			||||||
            qlen = 0
 | 
					 | 
				
			||||||
            for i, t in enumerate(tree):
 | 
					 | 
				
			||||||
                s, l = MeaningMap.get_tree_indexed_str(t, data[i], "")
 | 
					 | 
				
			||||||
                base += s
 | 
					 | 
				
			||||||
                qlen += l
 | 
					 | 
				
			||||||
            return (base, qlen)
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            if isinstance(tree, dict):
 | 
					 | 
				
			||||||
                base = ""
 | 
					 | 
				
			||||||
                qlen = 0
 | 
					 | 
				
			||||||
                last_is_dict = None
 | 
					 | 
				
			||||||
                for key, value in tree.items():
 | 
					 | 
				
			||||||
                    new_prefix = (len(str(key)) + 2) * " " + prefix
 | 
					 | 
				
			||||||
                    dict_string, l = MeaningMap.get_tree_indexed_str(value, data[qlen:], new_prefix)
 | 
					 | 
				
			||||||
                    if dict_string:
 | 
					 | 
				
			||||||
                        base += "\n" + prefix + str(key) + ": " + dict_string
 | 
					 | 
				
			||||||
                        last_is_dict = True
 | 
					 | 
				
			||||||
                    else:
 | 
					 | 
				
			||||||
                        base += "\n" + prefix + str(data[qlen]) + "  " if last_is_dict else str(data[qlen]) + "  "
 | 
					 | 
				
			||||||
                        last_is_dict = False
 | 
					 | 
				
			||||||
                    qlen += l
 | 
					 | 
				
			||||||
                return (base, qlen)
 | 
					 | 
				
			||||||
            return (None, 1)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def token_frequency(tree, freq):
 | 
					 | 
				
			||||||
        if isinstance(tree, dict):
 | 
					 | 
				
			||||||
            for key, value in tree.items():
 | 
					 | 
				
			||||||
                if key in freq:
 | 
					 | 
				
			||||||
                    freq[key] = freq[key] + 1
 | 
					 | 
				
			||||||
                else:
 | 
					 | 
				
			||||||
                    freq[key] = 1
 | 
					 | 
				
			||||||
                MeaningMap.token_frequency(value, freq)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
class MeaningDataset(Dataset):
 | 
					class MeaningDataset(Dataset):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -259,17 +211,23 @@ class MeaningDataset(Dataset):
 | 
				
			||||||
        use_cache=True,
 | 
					        use_cache=True,
 | 
				
			||||||
    ):
 | 
					    ):
 | 
				
			||||||
        np.random.seed(seed)
 | 
					        np.random.seed(seed)
 | 
				
			||||||
        map = MeaningMap(end, vocab_size, max_subitem, min_subitem, use_cache=use_cache)
 | 
					 | 
				
			||||||
        np.random.seed(seed)
 | 
					        np.random.seed(seed)
 | 
				
			||||||
 | 
					        self.start = start
 | 
				
			||||||
 | 
					        self.end = end
 | 
				
			||||||
 | 
					        self.vocab_size = vocab_size
 | 
				
			||||||
 | 
					        self.max_subitem = max_subitem
 | 
				
			||||||
 | 
					        self.min_subitem = min_subitem
 | 
				
			||||||
 | 
					        self.use_cache = use_cache
 | 
				
			||||||
 | 
					        self.min_seq_len = min_seq_len
 | 
				
			||||||
        print("Build MeaningDataset from MeaningMap.")
 | 
					        print("Build MeaningDataset from MeaningMap.")
 | 
				
			||||||
        self.val_mask_level = None
 | 
					        self.val_mask_level = None
 | 
				
			||||||
        self.val_mask_idx = None
 | 
					        self.val_mask_idx = None
 | 
				
			||||||
        self.tree = []
 | 
					 | 
				
			||||||
        self.seq = []
 | 
					        self.seq = []
 | 
				
			||||||
        self.level = []
 | 
					        self.level = []
 | 
				
			||||||
        self.rank_idx = []
 | 
					        self.rank_idx = []
 | 
				
			||||||
        self.rank_all = []
 | 
					        self.rank_all = []
 | 
				
			||||||
        self.seq_meaning = []
 | 
					        self.seq_meaning = []
 | 
				
			||||||
 | 
					        map = self.get_meaning_map()
 | 
				
			||||||
        self.m_height = map.ms_height
 | 
					        self.m_height = map.ms_height
 | 
				
			||||||
        self.m_weight = map.ms_weight
 | 
					        self.m_weight = map.ms_weight
 | 
				
			||||||
        if size:
 | 
					        if size:
 | 
				
			||||||
| 
						 | 
					@ -281,7 +239,6 @@ class MeaningDataset(Dataset):
 | 
				
			||||||
        for m in meanings:
 | 
					        for m in meanings:
 | 
				
			||||||
            d, l, i, a = map.get_sequence(m)
 | 
					            d, l, i, a = map.get_sequence(m)
 | 
				
			||||||
            if len(d) >= min_seq_len:
 | 
					            if len(d) >= min_seq_len:
 | 
				
			||||||
                self.tree.append({m: map.get_tree(m)})
 | 
					 | 
				
			||||||
                self.seq.append(d)
 | 
					                self.seq.append(d)
 | 
				
			||||||
                self.level.append(l)
 | 
					                self.level.append(l)
 | 
				
			||||||
                self.seq_meaning.append(m)
 | 
					                self.seq_meaning.append(m)
 | 
				
			||||||
| 
						 | 
					@ -327,6 +284,9 @@ class MeaningDataset(Dataset):
 | 
				
			||||||
    def len(self):
 | 
					    def len(self):
 | 
				
			||||||
        return len(self.seq)
 | 
					        return len(self.seq)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def get_meaning_map(self):
 | 
				
			||||||
 | 
					        return MeaningMap(self.end, self.vocab_size, self.max_subitem, self.min_subitem, self.use_cache)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def set_mask(self, level=None, idx=None):
 | 
					    def set_mask(self, level=None, idx=None):
 | 
				
			||||||
        if self.val_mask_level is not None and self.val_mask_idx is not None:
 | 
					        if self.val_mask_level is not None and self.val_mask_idx is not None:
 | 
				
			||||||
            assert len(self.val_mask_level) > 0, "len must > 0"
 | 
					            assert len(self.val_mask_level) > 0, "len must > 0"
 | 
				
			||||||
| 
						 | 
					@ -346,26 +306,17 @@ class MeaningDataset(Dataset):
 | 
				
			||||||
        output["input_ids"] = data
 | 
					        output["input_ids"] = data
 | 
				
			||||||
        output["labels"] = data.clone()
 | 
					        output["labels"] = data.clone()
 | 
				
			||||||
        output["token_type_ids"] = torch.zeros(data.shape)
 | 
					        output["token_type_ids"] = torch.zeros(data.shape)
 | 
				
			||||||
        output["tree"] = [self.tree[i] for i in idx_list]
 | 
					 | 
				
			||||||
        output["val_mask"] = self.get_seq_mask_tensor(idx_list)
 | 
					        output["val_mask"] = self.get_seq_mask_tensor(idx_list)
 | 
				
			||||||
        return output
 | 
					        return output
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get_token(self, idx):  # must equal sequence length
 | 
					    def get_token(self, idx):  # must equal sequence length
 | 
				
			||||||
        return self.seq[idx]
 | 
					        return self.seq[idx]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get_tree(self, idx):
 | 
					    def get_meaning(self, idx):
 | 
				
			||||||
        return self.tree[idx]
 | 
					        return self.seq_meaning[idx]
 | 
				
			||||||
 | 
					 | 
				
			||||||
    def print_tree(self, idx):
 | 
					 | 
				
			||||||
        tokens = self.seq[idx]
 | 
					 | 
				
			||||||
        tree = self.get_tree(idx)
 | 
					 | 
				
			||||||
        s = str(tokens) + "\n"
 | 
					 | 
				
			||||||
        s += MeaningMap.print_tree(tree)
 | 
					 | 
				
			||||||
        return s
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def copy(self, start, end):
 | 
					    def copy(self, start, end):
 | 
				
			||||||
        new = copy.deepcopy(self)
 | 
					        new = copy.deepcopy(self)
 | 
				
			||||||
        new.tree = new.tree[start:end]
 | 
					 | 
				
			||||||
        new.seq = new.seq[start:end]
 | 
					        new.seq = new.seq[start:end]
 | 
				
			||||||
        new.level = new.level[start:end]
 | 
					        new.level = new.level[start:end]
 | 
				
			||||||
        new.rank_idx = new.rank_idx[start:end]
 | 
					        new.rank_idx = new.rank_idx[start:end]
 | 
				
			||||||
| 
						 | 
					@ -380,12 +331,6 @@ class MeaningDataset(Dataset):
 | 
				
			||||||
        middle = int(l * ratio)
 | 
					        middle = int(l * ratio)
 | 
				
			||||||
        return self.copy(0, middle), self.copy(middle, l)
 | 
					        return self.copy(0, middle), self.copy(middle, l)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def token_frequency(self):
 | 
					 | 
				
			||||||
        freq = {}
 | 
					 | 
				
			||||||
        for t in self.tree:
 | 
					 | 
				
			||||||
            MeaningMap.token_frequency(t, freq)
 | 
					 | 
				
			||||||
        return freq
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def get_seq_mask(self, idx, level, index):
 | 
					    def get_seq_mask(self, idx, level, index):
 | 
				
			||||||
        # assert index < 15, "index must < 15"
 | 
					        # assert index < 15, "index must < 15"
 | 
				
			||||||
        # assert level < 8, "level must < 8"
 | 
					        # assert level < 8, "level must < 8"
 | 
				
			||||||
| 
						 | 
					@ -410,12 +355,12 @@ class MeaningDataset(Dataset):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class BatchGroupMeaningDataloader(Dataset):
 | 
					class BatchGroupMeaningDataloader(Dataset):
 | 
				
			||||||
    def __init__(self, dataset: MeaningDataset, batch_size, shuffle=True, drop_last=True):
 | 
					    def __init__(self, meaning_dataset: MeaningDataset, batch_size, shuffle=True, drop_last=True):
 | 
				
			||||||
        self.dataset = dataset
 | 
					        self.meaning_dataset = meaning_dataset
 | 
				
			||||||
        self.batch_size = batch_size
 | 
					        self.batch_size = batch_size
 | 
				
			||||||
        self.drop_last = drop_last
 | 
					        self.drop_last = drop_last
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        seq_len = [len(s) for s in dataset.seq]
 | 
					        seq_len = [len(s) for s in meaning_dataset.seq]
 | 
				
			||||||
        unique, counts = np.unique(seq_len, return_counts=True)
 | 
					        unique, counts = np.unique(seq_len, return_counts=True)
 | 
				
			||||||
        gl = {}
 | 
					        gl = {}
 | 
				
			||||||
        for u in unique:
 | 
					        for u in unique:
 | 
				
			||||||
| 
						 | 
					@ -451,16 +396,16 @@ class BatchGroupMeaningDataloader(Dataset):
 | 
				
			||||||
        return len(self.indexBatch)
 | 
					        return len(self.indexBatch)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __getitem__(self, idx):
 | 
					    def __getitem__(self, idx):
 | 
				
			||||||
        return self.dataset.get_batch(self.indexBatch[idx])
 | 
					        return self.meaning_dataset.get_batch(self.indexBatch[idx])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get_tree(self, idx):
 | 
					    def get_tree(self, idx):
 | 
				
			||||||
        return [self.dataset.get_tree(i) for i in self.indexBatch[idx]]
 | 
					        return [self.meaning_dataset.get_tree(i) for i in self.indexBatch[idx]]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def print_tree(self, idx):
 | 
					    def print_tree(self, idx):
 | 
				
			||||||
        idx_list = self.indexBatch[idx]
 | 
					        idx_list = self.indexBatch[idx]
 | 
				
			||||||
        s = "--------------------------------------------------------\n"
 | 
					        s = "--------------------------------------------------------\n"
 | 
				
			||||||
        for i in idx_list:
 | 
					        for i in idx_list:
 | 
				
			||||||
            s += self.dataset.print_tree(i)
 | 
					            s += self.meaning_dataset.print_tree(i)
 | 
				
			||||||
            s += "--------------------------------------------------------\n"
 | 
					            s += "--------------------------------------------------------\n"
 | 
				
			||||||
        return s
 | 
					        return s
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -556,7 +501,6 @@ if __name__ == "__main__":
 | 
				
			||||||
            ),
 | 
					            ),
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
    ), "False"
 | 
					    ), "False"
 | 
				
			||||||
    freq = md.token_frequency()
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, min_subitem=2, use_cache=False)
 | 
					    md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, min_subitem=2, use_cache=False)
 | 
				
			||||||
    md.set_mask([0, 1], [0, -1])
 | 
					    md.set_mask([0, 1], [0, -1])
 | 
				
			||||||
| 
						 | 
					@ -572,21 +516,3 @@ if __name__ == "__main__":
 | 
				
			||||||
    print(m)
 | 
					    print(m)
 | 
				
			||||||
    ne2 = next(it)
 | 
					    ne2 = next(it)
 | 
				
			||||||
    ne3 = next(it)
 | 
					    ne3 = next(it)
 | 
				
			||||||
 | 
					 | 
				
			||||||
    # map1 = dl.get_tree(0)
 | 
					 | 
				
			||||||
    # map2 = dl.get_tree(1)
 | 
					 | 
				
			||||||
    # print(dl.print_tree(0))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # dl = DataLoader(
 | 
					 | 
				
			||||||
    #     train,
 | 
					 | 
				
			||||||
    #     num_workers=1,
 | 
					 | 
				
			||||||
    #     persistent_workers=True,
 | 
					 | 
				
			||||||
    #     shuffle=False,
 | 
					 | 
				
			||||||
    # )
 | 
					 | 
				
			||||||
    # it = iter(dl)
 | 
					 | 
				
			||||||
    # ne1 = next(it)
 | 
					 | 
				
			||||||
    # ne2 = next(it)
 | 
					 | 
				
			||||||
    # ne3 = next(it)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # for i in range(10):
 | 
					 | 
				
			||||||
    #     print(next(it)["input_ids"].numpy().tolist())
 | 
					 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,83 @@
 | 
				
			||||||
 | 
					from anytree import Node, RenderTree
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class NodeTree:
 | 
				
			||||||
 | 
					    def __init__(self, tree: dict):
 | 
				
			||||||
 | 
					        self.tree = tree
 | 
				
			||||||
 | 
					        self.node = None
 | 
				
			||||||
 | 
					        self.seq_node = []
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def get_node(self):
 | 
				
			||||||
 | 
					        def get_tree_node(tree, parent, seqlist):
 | 
				
			||||||
 | 
					            for key, value in tree.items():
 | 
				
			||||||
 | 
					                if isinstance(value, dict):
 | 
				
			||||||
 | 
					                    pn = Node(str(key), parent)
 | 
				
			||||||
 | 
					                    get_tree_node(value, pn, seqlist)
 | 
				
			||||||
 | 
					                else:
 | 
				
			||||||
 | 
					                    pn = Node("<" + str(key) + ">", parent)
 | 
				
			||||||
 | 
					                    seqlist.append(pn)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if self.node:
 | 
				
			||||||
 | 
					            return self.node
 | 
				
			||||||
 | 
					        assert isinstance(self.tree, dict)
 | 
				
			||||||
 | 
					        root = Node("root")
 | 
				
			||||||
 | 
					        get_tree_node(self.tree, root, self.seq_node)
 | 
				
			||||||
 | 
					        self.node = root.children[0] if len(self.tree) == 1 else root
 | 
				
			||||||
 | 
					        return self.node
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def print(self):
 | 
				
			||||||
 | 
					        treestr = ""
 | 
				
			||||||
 | 
					        for pre, fill, node in RenderTree(self.get_node()):
 | 
				
			||||||
 | 
					            treestr += f"{pre}{node.name}\n"
 | 
				
			||||||
 | 
					        print(treestr)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def print_tree(tree):
 | 
				
			||||||
 | 
					        def get_tree_node(tree, parent=None):
 | 
				
			||||||
 | 
					            for key, value in tree.items():
 | 
				
			||||||
 | 
					                pn = Node(str(key), parent)
 | 
				
			||||||
 | 
					                if isinstance(value, dict):
 | 
				
			||||||
 | 
					                    get_tree_node(value, pn)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        assert isinstance(tree, dict)
 | 
				
			||||||
 | 
					        root = Node("root")
 | 
				
			||||||
 | 
					        get_tree_node(tree, root)
 | 
				
			||||||
 | 
					        treestr = ""
 | 
				
			||||||
 | 
					        for pre, fill, node in RenderTree(root.children[0] if len(tree) == 1 else root):
 | 
				
			||||||
 | 
					            treestr += f"{pre}{node.name}\n"
 | 
				
			||||||
 | 
					        return treestr
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # def get_tree_indexed_str(tree, data, prefix):
 | 
				
			||||||
 | 
					    #     if isinstance(tree, list):
 | 
				
			||||||
 | 
					    #         base = ""
 | 
				
			||||||
 | 
					    #         qlen = 0
 | 
				
			||||||
 | 
					    #         for i, t in enumerate(tree):
 | 
				
			||||||
 | 
					    #             s, l = MeaningMap.get_tree_indexed_str(t, data[i], "")
 | 
				
			||||||
 | 
					    #             base += s
 | 
				
			||||||
 | 
					    #             qlen += l
 | 
				
			||||||
 | 
					    #         return (base, qlen)
 | 
				
			||||||
 | 
					    #     else:
 | 
				
			||||||
 | 
					    #         if isinstance(tree, dict):
 | 
				
			||||||
 | 
					    #             base = ""
 | 
				
			||||||
 | 
					    #             qlen = 0
 | 
				
			||||||
 | 
					    #             last_is_dict = None
 | 
				
			||||||
 | 
					    #             for key, value in tree.items():
 | 
				
			||||||
 | 
					    #                 new_prefix = (len(str(key)) + 2) * " " + prefix
 | 
				
			||||||
 | 
					    #                 dict_string, l = MeaningMap.get_tree_indexed_str(value, data[qlen:], new_prefix)
 | 
				
			||||||
 | 
					    #                 if dict_string:
 | 
				
			||||||
 | 
					    #                     base += "\n" + prefix + str(key) + ": " + dict_string
 | 
				
			||||||
 | 
					    #                     last_is_dict = True
 | 
				
			||||||
 | 
					    #                 else:
 | 
				
			||||||
 | 
					    #                     base += "\n" + prefix + str(data[qlen]) + "  " if last_is_dict else str(data[qlen]) + "  "
 | 
				
			||||||
 | 
					    #                     last_is_dict = False
 | 
				
			||||||
 | 
					    #                 qlen += l
 | 
				
			||||||
 | 
					    #             return (base, qlen)
 | 
				
			||||||
 | 
					    #         return (None, 1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # def token_frequency(tree, freq):
 | 
				
			||||||
 | 
					    #     if isinstance(tree, dict):
 | 
				
			||||||
 | 
					    #         for key, value in tree.items():
 | 
				
			||||||
 | 
					    #             if key in freq:
 | 
				
			||||||
 | 
					    #                 freq[key] = freq[key] + 1
 | 
				
			||||||
 | 
					    #             else:
 | 
				
			||||||
 | 
					    #                 freq[key] = 1
 | 
				
			||||||
 | 
					    #             MeaningMap.token_frequency(value, freq)
 | 
				
			||||||
| 
						 | 
					@ -8,6 +8,7 @@ import numpy as np
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import configuration
 | 
					import configuration
 | 
				
			||||||
import dataset.dataset as ds
 | 
					import dataset.dataset as ds
 | 
				
			||||||
 | 
					import dataset.node_tree as nt
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if __name__ == "__main__":
 | 
					if __name__ == "__main__":
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -45,12 +46,15 @@ if __name__ == "__main__":
 | 
				
			||||||
    runner = QwenRunner(qwen.llm)
 | 
					    runner = QwenRunner(qwen.llm)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    val = ds.InitValDataset(conf).dataset
 | 
					    val = ds.InitValDataset(conf).dataset
 | 
				
			||||||
    data = val.dataset
 | 
					    md = val.meaning_dataset
 | 
				
			||||||
    item = data.get_token(0)
 | 
					 | 
				
			||||||
    print(data.print_tree(0))
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    batch = torch.tensor([item[:-1]], dtype=torch.int64)
 | 
					    map = md.get_meaning_map()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    item = md.get_token(0)
 | 
				
			||||||
 | 
					    nt.NodeTree(map.get_tree(md.get_meaning(0))).print()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    batch = torch.tensor([item[:-16]], dtype=torch.int64)
 | 
				
			||||||
    batch = batch.cuda()
 | 
					    batch = batch.cuda()
 | 
				
			||||||
    print(item)
 | 
					    # print(item)
 | 
				
			||||||
    next_token = runner.ChatToken(batch)
 | 
					    next_token = runner.ChatToken(batch)
 | 
				
			||||||
    print(next_token.detach().cpu().numpy())
 | 
					    print(next_token.detach().cpu().numpy())
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue