Refine dataset and nodetree.
This commit is contained in:
		
							parent
							
								
									a1d5fce300
								
							
						
					
					
						commit
						3c34d12fba
					
				| 
						 | 
				
			
			@ -6,16 +6,17 @@ from typing import Dict, Tuple
 | 
			
		|||
from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
 | 
			
		||||
import numpy as np
 | 
			
		||||
from torch.utils.data import BatchSampler
 | 
			
		||||
from anytree import Node, RenderTree
 | 
			
		||||
from dataset.node_tree import NodeTree
 | 
			
		||||
 | 
			
		||||
# import warnings
 | 
			
		||||
# warnings.filterwarnings("ignore", ".*does not have many workers.*")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class MeaningMap:
 | 
			
		||||
    def __init__(self, size=1048576, vocab_size=4096, max_subitem=10, min_subitem=1, use_cache=True):
 | 
			
		||||
    def __init__(self, size=1048576, vocab_size=4096, max_subitem=10, min_subitem=1, use_cache=True, seed=42):
 | 
			
		||||
        assert size > 0 and vocab_size > 0 and max_subitem > 0 and min_subitem > 0, "Invalid input"
 | 
			
		||||
        assert min_subitem <= max_subitem, "Invalid input"
 | 
			
		||||
        np.random.seed(seed)
 | 
			
		||||
        self.size = size
 | 
			
		||||
        self.vocab_size = vocab_size
 | 
			
		||||
        self.max_subitem = max_subitem
 | 
			
		||||
| 
						 | 
				
			
			@ -183,15 +184,21 @@ class MeaningMap:
 | 
			
		|||
        )
 | 
			
		||||
 | 
			
		||||
    def get_tree(self, meaning):  # return meaning all sub items
 | 
			
		||||
        def get_tree_dict(ms_map, vocab_size, meaning):  # return meaning all sub items
 | 
			
		||||
            tree = {}
 | 
			
		||||
        def get_tree_node(ms_map, meaning, vocab_size, parent, seqlist):
 | 
			
		||||
            ms = ms_map[meaning]
 | 
			
		||||
            for m in ms[ms > 0].tolist():
 | 
			
		||||
                tree[m] = get_tree_dict(ms_map, vocab_size, m) if m >= self.vocab_size else m
 | 
			
		||||
            return tree
 | 
			
		||||
            for m in ms[ms >= 0].tolist():
 | 
			
		||||
                if m >= self.vocab_size:
 | 
			
		||||
                    pn = NodeTree(str(m), parent)
 | 
			
		||||
                    get_tree_node(ms_map, m, vocab_size, pn, seqlist)
 | 
			
		||||
                else:
 | 
			
		||||
                    pn = NodeTree("<" + str(m) + ">", parent)
 | 
			
		||||
                    seqlist.append(pn)
 | 
			
		||||
 | 
			
		||||
        td = get_tree_dict(self.ms_map, self.vocab_size, meaning)
 | 
			
		||||
        return {meaning: td}
 | 
			
		||||
        root = NodeTree(str(meaning))
 | 
			
		||||
        seqlist = []
 | 
			
		||||
        get_tree_node(self.ms_map, meaning, self.vocab_size, root, seqlist)
 | 
			
		||||
        root.seq_node = seqlist
 | 
			
		||||
        return root
 | 
			
		||||
 | 
			
		||||
    def max_length(self):
 | 
			
		||||
        return max(self.ms_len)
 | 
			
		||||
| 
						 | 
				
			
			@ -211,7 +218,6 @@ class MeaningDataset(Dataset):
 | 
			
		|||
        seed=42,
 | 
			
		||||
        use_cache=True,
 | 
			
		||||
    ):
 | 
			
		||||
        np.random.seed(seed)
 | 
			
		||||
        np.random.seed(seed)
 | 
			
		||||
        self.start = start
 | 
			
		||||
        self.end = end
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -4,7 +4,7 @@ from anytree.node.nodemixin import NodeMixin
 | 
			
		|||
from anytree.node.util import _repr
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Node(NodeMixin):
 | 
			
		||||
class NodeTree(NodeMixin):
 | 
			
		||||
    def __init__(self, name, parent=None, children=None, **kwargs):
 | 
			
		||||
        self.__dict__.update(kwargs)
 | 
			
		||||
        self.name = name
 | 
			
		||||
| 
						 | 
				
			
			@ -12,93 +12,17 @@ class Node(NodeMixin):
 | 
			
		|||
        self.parent = parent
 | 
			
		||||
        if children:
 | 
			
		||||
            self.children = children
 | 
			
		||||
        self.seq_node = []
 | 
			
		||||
 | 
			
		||||
    def __repr__(self):
 | 
			
		||||
        args = ["%r" % self.separator.join([""] + [str(node.name) for node in self.path])]
 | 
			
		||||
        return _repr(self, args=args, nameblacklist=["name"])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class NodeTree:
 | 
			
		||||
    def __init__(self, tree: dict):
 | 
			
		||||
        self.tree = tree
 | 
			
		||||
        self.node = None
 | 
			
		||||
        self.seq_node = []
 | 
			
		||||
        self.get_node()
 | 
			
		||||
 | 
			
		||||
    def get_node(self):
 | 
			
		||||
        def get_tree_node(tree, parent, seqlist):
 | 
			
		||||
            for key, value in tree.items():
 | 
			
		||||
                if isinstance(value, dict):
 | 
			
		||||
                    pn = Node(str(key), parent)
 | 
			
		||||
                    get_tree_node(value, pn, seqlist)
 | 
			
		||||
                else:
 | 
			
		||||
                    pn = Node("<" + str(key) + ">", parent)
 | 
			
		||||
                    seqlist.append(pn)
 | 
			
		||||
 | 
			
		||||
        if self.node:
 | 
			
		||||
            return self.node
 | 
			
		||||
        assert isinstance(self.tree, dict)
 | 
			
		||||
        root = Node("root")
 | 
			
		||||
        get_tree_node(self.tree, root, self.seq_node)
 | 
			
		||||
        self.node = root.children[0] if len(self.tree) == 1 else root
 | 
			
		||||
        return self.node
 | 
			
		||||
 | 
			
		||||
    def set_seq_prop(self, index, prop):
 | 
			
		||||
        self.seq_node[index].prop = prop
 | 
			
		||||
 | 
			
		||||
    def print(self):
 | 
			
		||||
        treestr = ""
 | 
			
		||||
        for pre, fill, node in RenderTree(self.get_node()):
 | 
			
		||||
        for pre, fill, node in RenderTree(self):
 | 
			
		||||
            treestr += f"{pre}{node.name} {node.prop}\n"
 | 
			
		||||
        print(treestr)
 | 
			
		||||
 | 
			
		||||
    def print_tree(tree):
 | 
			
		||||
        def get_tree_node(tree, parent=None):
 | 
			
		||||
            for key, value in tree.items():
 | 
			
		||||
                pn = Node(str(key), parent)
 | 
			
		||||
                if isinstance(value, dict):
 | 
			
		||||
                    get_tree_node(value, pn)
 | 
			
		||||
 | 
			
		||||
        assert isinstance(tree, dict)
 | 
			
		||||
        root = Node("root")
 | 
			
		||||
        get_tree_node(tree, root)
 | 
			
		||||
        treestr = ""
 | 
			
		||||
        for pre, fill, node in RenderTree(root.children[0] if len(tree) == 1 else root):
 | 
			
		||||
            treestr += f"{pre}{node.name}\n"
 | 
			
		||||
        return treestr
 | 
			
		||||
 | 
			
		||||
    # def get_tree_indexed_str(tree, data, prefix):
 | 
			
		||||
    #     if isinstance(tree, list):
 | 
			
		||||
    #         base = ""
 | 
			
		||||
    #         qlen = 0
 | 
			
		||||
    #         for i, t in enumerate(tree):
 | 
			
		||||
    #             s, l = MeaningMap.get_tree_indexed_str(t, data[i], "")
 | 
			
		||||
    #             base += s
 | 
			
		||||
    #             qlen += l
 | 
			
		||||
    #         return (base, qlen)
 | 
			
		||||
    #     else:
 | 
			
		||||
    #         if isinstance(tree, dict):
 | 
			
		||||
    #             base = ""
 | 
			
		||||
    #             qlen = 0
 | 
			
		||||
    #             last_is_dict = None
 | 
			
		||||
    #             for key, value in tree.items():
 | 
			
		||||
    #                 new_prefix = (len(str(key)) + 2) * " " + prefix
 | 
			
		||||
    #                 dict_string, l = MeaningMap.get_tree_indexed_str(value, data[qlen:], new_prefix)
 | 
			
		||||
    #                 if dict_string:
 | 
			
		||||
    #                     base += "\n" + prefix + str(key) + ": " + dict_string
 | 
			
		||||
    #                     last_is_dict = True
 | 
			
		||||
    #                 else:
 | 
			
		||||
    #                     base += "\n" + prefix + str(data[qlen]) + "  " if last_is_dict else str(data[qlen]) + "  "
 | 
			
		||||
    #                     last_is_dict = False
 | 
			
		||||
    #                 qlen += l
 | 
			
		||||
    #             return (base, qlen)
 | 
			
		||||
    #         return (None, 1)
 | 
			
		||||
 | 
			
		||||
    # def token_frequency(tree, freq):
 | 
			
		||||
    #     if isinstance(tree, dict):
 | 
			
		||||
    #         for key, value in tree.items():
 | 
			
		||||
    #             if key in freq:
 | 
			
		||||
    #                 freq[key] = freq[key] + 1
 | 
			
		||||
    #             else:
 | 
			
		||||
    #                 freq[key] = 1
 | 
			
		||||
    #             MeaningMap.token_frequency(value, freq)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -12,49 +12,33 @@ import dataset.node_tree as nt
 | 
			
		|||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
 | 
			
		||||
    conf = configuration.TrainConfig()
 | 
			
		||||
    config = conf.model_config
 | 
			
		||||
    checkpoint_path = "log/bigger/version_0/checkpoints/epoch=19-step=98720.ckpt"
 | 
			
		||||
 | 
			
		||||
    conf.name = "bigger"  # current train process name
 | 
			
		||||
    conf.pretrain_model_name = None  # "qwen/Qwen-1_8B-Chat"
 | 
			
		||||
    conf.learning_rate = 0.0001
 | 
			
		||||
    conf.use_tril_attention_mask = None
 | 
			
		||||
    conf.precision = "bf16-mixed"  # "precision:bf16-mixed,16-mixed,32-true"
 | 
			
		||||
    conf.train_batch_size = 16
 | 
			
		||||
    conf.val_batch_size = 4
 | 
			
		||||
    conf.num_proc = 8
 | 
			
		||||
    conf.max_epochs = 1000
 | 
			
		||||
    conf.strategy = "auto"
 | 
			
		||||
    conf.resume_from_ckpt_path = None
 | 
			
		||||
    conf.seed = 42
 | 
			
		||||
    conf.dataloader_works = 2
 | 
			
		||||
 | 
			
		||||
    conf.dataset.meaning.val_mask_level = [0, 1, 2]
 | 
			
		||||
    conf.dataset.meaning.val_mask_idx = [0, 0, -1]
 | 
			
		||||
 | 
			
		||||
    config.vocab_size = 256
 | 
			
		||||
    config.hidden_size = 128  # 128 1024 2048  32
 | 
			
		||||
    config.num_hidden_layers = 3  # 6 12 24  3
 | 
			
		||||
    config.num_attention_heads = 16  # 8 8 16
 | 
			
		||||
 | 
			
		||||
    torch.manual_seed(conf.seed)
 | 
			
		||||
 | 
			
		||||
    checkpoint_path = "log/bigger/version_1/checkpoints/epoch=23-step=24792.ckpt"
 | 
			
		||||
    qwen = QwenModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
 | 
			
		||||
    qwen.eval()
 | 
			
		||||
 | 
			
		||||
    conf = qwen.config
 | 
			
		||||
    torch.manual_seed(conf.seed)
 | 
			
		||||
    np.random.seed(conf.seed)
 | 
			
		||||
    runner = QwenRunner(qwen.llm)
 | 
			
		||||
 | 
			
		||||
    # batch = torch.tensor([[41]], dtype=torch.int64)
 | 
			
		||||
    # print(runner.ChatTokens(batch).detach().cpu().numpy()[0])
 | 
			
		||||
 | 
			
		||||
    val = ds.InitValDataset(conf).dataset
 | 
			
		||||
    md = val.meaning_dataset
 | 
			
		||||
 | 
			
		||||
    map = md.get_meaning_map()
 | 
			
		||||
 | 
			
		||||
    item = md.get_token(0)
 | 
			
		||||
    nt.NodeTree(map.get_tree(md.get_meaning(0))).print()
 | 
			
		||||
 | 
			
		||||
    batch = torch.tensor([item[:-16]], dtype=torch.int64)
 | 
			
		||||
    batch = batch.cuda()
 | 
			
		||||
    # print(item)
 | 
			
		||||
    next_token = runner.ChatToken(batch)
 | 
			
		||||
    print(next_token.detach().cpu().numpy())
 | 
			
		||||
    node = map.get_tree(md.get_meaning(0))
 | 
			
		||||
    # node.print()
 | 
			
		||||
 | 
			
		||||
    for i in range(1, len(item)):
 | 
			
		||||
        itemm = [item[:i]]
 | 
			
		||||
        batch = torch.tensor([item[:i]], dtype=torch.int64)
 | 
			
		||||
        next_token = runner.ChatTokens(batch, sample=False).detach().cpu().numpy()[0]
 | 
			
		||||
        if item[i] != next_token:
 | 
			
		||||
            node.set_seq_prop(i, "ERR_" + str(next_token))
 | 
			
		||||
            print(str(item[i]) + "  " + str(next_token) + "  ERROR")
 | 
			
		||||
    node.print()
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -196,16 +196,18 @@ class QwenRunner:
 | 
			
		|||
        # torch.backends.cuda.enable_flash_sdp(True)
 | 
			
		||||
 | 
			
		||||
    @torch.no_grad()
 | 
			
		||||
    def ChatToken(self, input_ids):
 | 
			
		||||
    def ChatTokens(self, input_ids, sample=True):
 | 
			
		||||
        qwen = self.qwen
 | 
			
		||||
        input_ids = input_ids.to(next(qwen.parameters()).device)
 | 
			
		||||
        outputs, loss = self.forwardQWen(input_ids)
 | 
			
		||||
        next_token_scores = outputs[:, -1, :]
 | 
			
		||||
 | 
			
		||||
        next_token_scores = self.repetition_penalty(input_ids, next_token_scores)
 | 
			
		||||
        next_token_scores = self.top_p(next_token_scores)
 | 
			
		||||
        next_tokens = self.sample(next_token_scores)
 | 
			
		||||
        return next_tokens
 | 
			
		||||
        if sample:
 | 
			
		||||
            return self.sample(next_token_scores)
 | 
			
		||||
        else:
 | 
			
		||||
            sorted_logits, sorted_indices = torch.sort(next_token_scores, descending=True)
 | 
			
		||||
            return sorted_indices[0]
 | 
			
		||||
 | 
			
		||||
    @torch.no_grad()
 | 
			
		||||
    def Chat(
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -11,6 +11,7 @@ from configuration import ModelConfig, TrainConfig
 | 
			
		|||
 | 
			
		||||
class QwenModule(pl.LightningModule):
 | 
			
		||||
    def __init__(self, conf: TrainConfig = None):
 | 
			
		||||
        self.config = conf
 | 
			
		||||
        pretrained_model_dir = conf.pretrain_model_name
 | 
			
		||||
        learning_rate = conf.learning_rate
 | 
			
		||||
        mconf = conf.model_config
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										12
									
								
								wit/train.py
								
								
								
								
							
							
						
						
									
										12
									
								
								wit/train.py
								
								
								
								
							| 
						 | 
				
			
			@ -7,6 +7,7 @@ from logger import MLFLogger, TBLogger
 | 
			
		|||
 | 
			
		||||
import configuration
 | 
			
		||||
import dataset.dataset as ds
 | 
			
		||||
import numpy as np
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -17,9 +18,9 @@ if __name__ == "__main__":
 | 
			
		|||
    conf.pretrain_model_name = None  # "qwen/Qwen-1_8B-Chat"
 | 
			
		||||
    conf.learning_rate = 0.0001
 | 
			
		||||
    conf.use_tril_attention_mask = None
 | 
			
		||||
    conf.precision = "bf16-mixed"  # "precision:bf16-mixed,16-mixed,32-true"
 | 
			
		||||
    conf.precision = "32-true"  # "precision:bf16-mixed,16-mixed,32-true"
 | 
			
		||||
    conf.train_batch_size = 16
 | 
			
		||||
    conf.val_batch_size = 4
 | 
			
		||||
    conf.val_batch_size = 2
 | 
			
		||||
    conf.num_proc = 8
 | 
			
		||||
    conf.max_epochs = 1000
 | 
			
		||||
    conf.strategy = "auto"
 | 
			
		||||
| 
						 | 
				
			
			@ -27,15 +28,20 @@ if __name__ == "__main__":
 | 
			
		|||
    conf.seed = 42
 | 
			
		||||
    conf.dataloader_works = 2
 | 
			
		||||
 | 
			
		||||
    conf.dataset.meaning.level_ratio = 5
 | 
			
		||||
    conf.dataset.meaning.level = 2
 | 
			
		||||
    conf.dataset.meaning.dataset_level = 5
 | 
			
		||||
    conf.dataset.meaning.min_subitem = 2
 | 
			
		||||
    conf.dataset.meaning.val_mask_level = [0, 1, 2]
 | 
			
		||||
    conf.dataset.meaning.val_mask_idx = [0, 0, -1]
 | 
			
		||||
 | 
			
		||||
    config.vocab_size = 256
 | 
			
		||||
    config.vocab_size = 32
 | 
			
		||||
    config.hidden_size = 128  # 128 1024 2048  32
 | 
			
		||||
    config.num_hidden_layers = 3  # 6 12 24  3
 | 
			
		||||
    config.num_attention_heads = 16  # 8 8 16
 | 
			
		||||
 | 
			
		||||
    torch.manual_seed(conf.seed)
 | 
			
		||||
    np.random.seed(conf.seed)
 | 
			
		||||
    qwen = QwenModule(conf)
 | 
			
		||||
 | 
			
		||||
    train_dataloader, val_dataloader = ds.InitDataset(conf)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue