Refine dataset and nodetree.
parent a1d5fce300
commit 3c34d12fba
@@ -6,16 +6,17 @@ from typing import Dict, Tuple
 from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
 import numpy as np
 from torch.utils.data import BatchSampler
-from anytree import Node, RenderTree
+from dataset.node_tree import NodeTree

 # import warnings
 # warnings.filterwarnings("ignore", ".*does not have many workers.*")


 class MeaningMap:
-    def __init__(self, size=1048576, vocab_size=4096, max_subitem=10, min_subitem=1, use_cache=True):
+    def __init__(self, size=1048576, vocab_size=4096, max_subitem=10, min_subitem=1, use_cache=True, seed=42):
         assert size > 0 and vocab_size > 0 and max_subitem > 0 and min_subitem > 0, "Invalid input"
         assert min_subitem <= max_subitem, "Invalid input"
+        np.random.seed(seed)
         self.size = size
         self.vocab_size = vocab_size
         self.max_subitem = max_subitem
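The constructor now takes a seed and reseeds NumPy before generating the map, so two maps built with identical arguments should come out identical. A minimal sketch of that property, assuming MeaningMap is in scope and that ms_map is the generated lookup table (the argument values here are illustrative, not from the commit):

    import numpy as np

    m1 = MeaningMap(size=1024, vocab_size=32, max_subitem=4, min_subitem=2, use_cache=False, seed=42)
    m2 = MeaningMap(size=1024, vocab_size=32, max_subitem=4, min_subitem=2, use_cache=False, seed=42)
    assert np.array_equal(m1.ms_map, m2.ms_map)  # same seed, same generated table

Note that np.random.seed mutates NumPy's global RNG state, so constructing a map reseeds every later np.random call in the process; keeping a numpy.random.Generator on the instance would avoid that side effect.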
@@ -183,15 +184,21 @@ class MeaningMap:
         )

     def get_tree(self, meaning):  # return all sub items of the meaning
-        def get_tree_dict(ms_map, vocab_size, meaning):  # return meaning all sub items
-            tree = {}
+        def get_tree_node(ms_map, meaning, vocab_size, parent, seqlist):
             ms = ms_map[meaning]
-            for m in ms[ms > 0].tolist():
-                tree[m] = get_tree_dict(ms_map, vocab_size, m) if m >= self.vocab_size else m
-            return tree
+            for m in ms[ms >= 0].tolist():
+                if m >= self.vocab_size:
+                    pn = NodeTree(str(m), parent)
+                    get_tree_node(ms_map, m, vocab_size, pn, seqlist)
+                else:
+                    pn = NodeTree("<" + str(m) + ">", parent)
+                    seqlist.append(pn)

-        td = get_tree_dict(self.ms_map, self.vocab_size, meaning)
-        return {meaning: td}
+        root = NodeTree(str(meaning))
+        seqlist = []
+        get_tree_node(self.ms_map, meaning, self.vocab_size, root, seqlist)
+        root.seq_node = seqlist
+        return root

     def max_length(self):
         return max(self.ms_len)
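get_tree no longer returns a nested dict for a wrapper class to re-walk: it builds the anytree structure directly and stores the leaf nodes, in token order, on root.seq_node. A sketch of consuming the new return value, given a MeaningMap instance named map as in the query script later in this commit (the meaning id is illustrative):

    root = map.get_tree(4096)      # NodeTree root named after the meaning id
    print(len(root.seq_node))      # number of leaves, in sequence order
    root.set_seq_prop(0, "ok")     # annotate the first leaf by sequence index

Also note the filter changed from ms > 0 to ms >= 0, so entries with value 0 are no longer silently dropped from the tree.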
@@ -211,7 +218,6 @@ class MeaningDataset(Dataset):
         seed=42,
         use_cache=True,
     ):
-        np.random.seed(seed)
         np.random.seed(seed)
         self.start = start
         self.end = end
@@ -4,7 +4,7 @@ from anytree.node.nodemixin import NodeMixin
 from anytree.node.util import _repr


-class Node(NodeMixin):
+class NodeTree(NodeMixin):
     def __init__(self, name, parent=None, children=None, **kwargs):
         self.__dict__.update(kwargs)
         self.name = name
@@ -12,93 +12,17 @@ class Node(NodeMixin):
         self.parent = parent
         if children:
             self.children = children
+        self.seq_node = []

     def __repr__(self):
         args = ["%r" % self.separator.join([""] + [str(node.name) for node in self.path])]
         return _repr(self, args=args, nameblacklist=["name"])

-
-class NodeTree:
-    def __init__(self, tree: dict):
-        self.tree = tree
-        self.node = None
-        self.seq_node = []
-        self.get_node()
-
-    def get_node(self):
-        def get_tree_node(tree, parent, seqlist):
-            for key, value in tree.items():
-                if isinstance(value, dict):
-                    pn = Node(str(key), parent)
-                    get_tree_node(value, pn, seqlist)
-                else:
-                    pn = Node("<" + str(key) + ">", parent)
-                    seqlist.append(pn)
-
-        if self.node:
-            return self.node
-        assert isinstance(self.tree, dict)
-        root = Node("root")
-        get_tree_node(self.tree, root, self.seq_node)
-        self.node = root.children[0] if len(self.tree) == 1 else root
-        return self.node
-
     def set_seq_prop(self, index, prop):
         self.seq_node[index].prop = prop

     def print(self):
         treestr = ""
-        for pre, fill, node in RenderTree(self.get_node()):
+        for pre, fill, node in RenderTree(self):
             treestr += f"{pre}{node.name} {node.prop}\n"
         print(treestr)

-    def print_tree(tree):
-        def get_tree_node(tree, parent=None):
-            for key, value in tree.items():
-                pn = Node(str(key), parent)
-                if isinstance(value, dict):
-                    get_tree_node(value, pn)
-
-        assert isinstance(tree, dict)
-        root = Node("root")
-        get_tree_node(tree, root)
-        treestr = ""
-        for pre, fill, node in RenderTree(root.children[0] if len(tree) == 1 else root):
-            treestr += f"{pre}{node.name}\n"
-        return treestr
-
-    # def get_tree_indexed_str(tree, data, prefix):
-    #     if isinstance(tree, list):
-    #         base = ""
-    #         qlen = 0
-    #         for i, t in enumerate(tree):
-    #             s, l = MeaningMap.get_tree_indexed_str(t, data[i], "")
-    #             base += s
-    #             qlen += l
-    #         return (base, qlen)
-    #     else:
-    #         if isinstance(tree, dict):
-    #             base = ""
-    #             qlen = 0
-    #             last_is_dict = None
-    #             for key, value in tree.items():
-    #                 new_prefix = (len(str(key)) + 2) * " " + prefix
-    #                 dict_string, l = MeaningMap.get_tree_indexed_str(value, data[qlen:], new_prefix)
-    #                 if dict_string:
-    #                     base += "\n" + prefix + str(key) + ": " + dict_string
-    #                     last_is_dict = True
-    #                 else:
-    #                     base += "\n" + prefix + str(data[qlen]) + "  " if last_is_dict else str(data[qlen]) + "  "
-    #                     last_is_dict = False
-    #                 qlen += l
-    #             return (base, qlen)
-    #         return (None, 1)
-
-    # def token_frequency(tree, freq):
-    #     if isinstance(tree, dict):
-    #         for key, value in tree.items():
-    #             if key in freq:
-    #                 freq[key] = freq[key] + 1
-    #             else:
-    #                 freq[key] = 1
-    #             MeaningMap.token_frequency(value, freq)
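With this change the split between Node(NodeMixin) and the dict-driven NodeTree wrapper collapses into a single anytree node class, so standard anytree tooling works on it directly. A minimal hand-built sketch (node names are illustrative; getattr guards nodes that were never given a prop):

    from anytree import RenderTree
    from dataset.node_tree import NodeTree

    root = NodeTree("100")
    a = NodeTree("<1>", parent=root)
    b = NodeTree("<2>", parent=root)
    root.seq_node = [a, b]            # leaves in sequence order
    root.set_seq_prop(0, "ok")        # tag leaf 0
    for pre, fill, node in RenderTree(root):
        print(f"{pre}{node.name} {getattr(node, 'prop', '')}")

One caveat: print() above formats node.prop for every node, but only leaves touched by set_seq_prop receive one, so rendering a tree with untagged nodes can raise AttributeError unless prop is assigned somewhere else.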
@@ -12,49 +12,33 @@ import dataset.node_tree as nt

 if __name__ == "__main__":

-    conf = configuration.TrainConfig()
-    config = conf.model_config
-    checkpoint_path = "log/bigger/version_0/checkpoints/epoch=19-step=98720.ckpt"
-
-    conf.name = "bigger"  # current train process name
-    conf.pretrain_model_name = None  # "qwen/Qwen-1_8B-Chat"
-    conf.learning_rate = 0.0001
-    conf.use_tril_attention_mask = None
-    conf.precision = "bf16-mixed"  # "precision:bf16-mixed,16-mixed,32-true"
-    conf.train_batch_size = 16
-    conf.val_batch_size = 4
-    conf.num_proc = 8
-    conf.max_epochs = 1000
-    conf.strategy = "auto"
-    conf.resume_from_ckpt_path = None
-    conf.seed = 42
-    conf.dataloader_works = 2
-
-    conf.dataset.meaning.val_mask_level = [0, 1, 2]
-    conf.dataset.meaning.val_mask_idx = [0, 0, -1]
-
-    config.vocab_size = 256
-    config.hidden_size = 128  # 128 1024 2048  32
-    config.num_hidden_layers = 3  # 6 12 24  3
-    config.num_attention_heads = 16  # 8 8 16
-
-    torch.manual_seed(conf.seed)
-
+    checkpoint_path = "log/bigger/version_1/checkpoints/epoch=23-step=24792.ckpt"
     qwen = QwenModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
     qwen.eval()

+    conf = qwen.config
+    torch.manual_seed(conf.seed)
+    np.random.seed(conf.seed)
     runner = QwenRunner(qwen.llm)

+    # batch = torch.tensor([[41]], dtype=torch.int64)
+    # print(runner.ChatTokens(batch).detach().cpu().numpy()[0])
+
     val = ds.InitValDataset(conf).dataset
     md = val.meaning_dataset

     map = md.get_meaning_map()

     item = md.get_token(0)
-    nt.NodeTree(map.get_tree(md.get_meaning(0))).print()
-
-    batch = torch.tensor([item[:-16]], dtype=torch.int64)
-    batch = batch.cuda()
-    # print(item)
-    next_token = runner.ChatToken(batch)
-    print(next_token.detach().cpu().numpy())
+    node = map.get_tree(md.get_meaning(0))
+    # node.print()
+
+    for i in range(1, len(item)):
+        itemm = [item[:i]]
+        batch = torch.tensor([item[:i]], dtype=torch.int64)
+        next_token = runner.ChatTokens(batch, sample=False).detach().cpu().numpy()[0]
+        if item[i] != next_token:
+            node.set_seq_prop(i, "ERR_" + str(next_token))
+            print(str(item[i]) + "  " + str(next_token) + "  ERROR")
+    node.print()
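The rewritten main loop is a per-position greedy probe: for every prefix of the validation item it asks the model for its top-ranked next token and, on a mismatch, tags the corresponding leaf of the meaning tree before rendering it. The loop indexes item and node.seq_node with the same i, which relies on the tokens from get_token lining up one-to-one with the leaves get_tree collected. A sketch of that assumed invariant, reusing the names above:

    item = md.get_token(0)
    node = map.get_tree(md.get_meaning(0))
    assert len(item) == len(node.seq_node)  # assumed: one leaf per token, same order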
@@ -196,16 +196,18 @@ class QwenRunner:
         # torch.backends.cuda.enable_flash_sdp(True)

     @torch.no_grad()
-    def ChatToken(self, input_ids):
+    def ChatTokens(self, input_ids, sample=True):
         qwen = self.qwen
         input_ids = input_ids.to(next(qwen.parameters()).device)
         outputs, loss = self.forwardQWen(input_ids)
         next_token_scores = outputs[:, -1, :]

         next_token_scores = self.repetition_penalty(input_ids, next_token_scores)
         next_token_scores = self.top_p(next_token_scores)
-        next_tokens = self.sample(next_token_scores)
-        return next_tokens
+        if sample:
+            return self.sample(next_token_scores)
+        else:
+            sorted_logits, sorted_indices = torch.sort(next_token_scores, descending=True)
+            return sorted_indices[0]

     @torch.no_grad()
     def Chat(
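ChatTokens now has two exits: with sample=True it keeps the old repetition-penalty, top-p, sample path; with sample=False it returns the full descending ranking of next-token scores for the first batch row, so the caller in the query script above takes [0] to get the greedy argmax. A toy illustration of what the sample=False branch returns:

    import torch

    scores = torch.tensor([[0.1, 2.5, 0.3]])  # one batch row of next-token scores
    sorted_logits, sorted_indices = torch.sort(scores, descending=True)
    print(sorted_indices[0])                  # tensor([1, 2, 0]); element 0 is the greedy token

Note the greedy branch still runs after repetition_penalty and top_p, so the ranking reflects the transformed scores, and only batch row 0 is returned.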
@@ -11,6 +11,7 @@ from configuration import ModelConfig, TrainConfig

 class QwenModule(pl.LightningModule):
     def __init__(self, conf: TrainConfig = None):
+        self.config = conf
         pretrained_model_dir = conf.pretrain_model_name
         learning_rate = conf.learning_rate
         mconf = conf.model_config

wit/train.py
@@ -7,6 +7,7 @@ from logger import MLFLogger, TBLogger

 import configuration
 import dataset.dataset as ds
+import numpy as np

 if __name__ == "__main__":

@@ -17,9 +18,9 @@ if __name__ == "__main__":
     conf.pretrain_model_name = None  # "qwen/Qwen-1_8B-Chat"
     conf.learning_rate = 0.0001
     conf.use_tril_attention_mask = None
-    conf.precision = "bf16-mixed"  # "precision:bf16-mixed,16-mixed,32-true"
+    conf.precision = "32-true"  # "precision:bf16-mixed,16-mixed,32-true"
     conf.train_batch_size = 16
-    conf.val_batch_size = 4
+    conf.val_batch_size = 2
     conf.num_proc = 8
     conf.max_epochs = 1000
     conf.strategy = "auto"
@@ -27,15 +28,20 @@ if __name__ == "__main__":
     conf.seed = 42
     conf.dataloader_works = 2

+    conf.dataset.meaning.level_ratio = 5
+    conf.dataset.meaning.level = 2
+    conf.dataset.meaning.dataset_level = 5
+    conf.dataset.meaning.min_subitem = 2
     conf.dataset.meaning.val_mask_level = [0, 1, 2]
     conf.dataset.meaning.val_mask_idx = [0, 0, -1]

-    config.vocab_size = 256
+    config.vocab_size = 32
     config.hidden_size = 128  # 128 1024 2048  32
     config.num_hidden_layers = 3  # 6 12 24  3
     config.num_attention_heads = 16  # 8 8 16

     torch.manual_seed(conf.seed)
+    np.random.seed(conf.seed)
     qwen = QwenModule(conf)

     train_dataloader, val_dataloader = ds.InitDataset(conf)