diff --git a/wit/dataset/meaning_dataset.py b/wit/dataset/meaning_dataset.py
index 6c0dd18..4f220dc 100644
--- a/wit/dataset/meaning_dataset.py
+++ b/wit/dataset/meaning_dataset.py
@@ -6,16 +6,17 @@ from typing import Dict, Tuple
 from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
 import numpy as np
 from torch.utils.data import BatchSampler
-from anytree import Node, RenderTree
+from dataset.node_tree import NodeTree
 
 # import warnings
 # warnings.filterwarnings("ignore", ".*does not have many workers.*")
 
 
 class MeaningMap:
-    def __init__(self, size=1048576, vocab_size=4096, max_subitem=10, min_subitem=1, use_cache=True):
+    def __init__(self, size=1048576, vocab_size=4096, max_subitem=10, min_subitem=1, use_cache=True, seed=42):
         assert size > 0 and vocab_size > 0 and max_subitem > 0 and min_subitem > 0, "Invalid input"
         assert min_subitem <= max_subitem, "Invalid input"
+        np.random.seed(seed)
         self.size = size
         self.vocab_size = vocab_size
         self.max_subitem = max_subitem
@@ -183,15 +184,21 @@ class MeaningMap:
         )
 
     def get_tree(self, meaning):  # return meaning all sub items
-        def get_tree_dict(ms_map, vocab_size, meaning):  # return meaning all sub items
-            tree = {}
+        def get_tree_node(ms_map, meaning, vocab_size, parent, seqlist):
             ms = ms_map[meaning]
-            for m in ms[ms > 0].tolist():
-                tree[m] = get_tree_dict(ms_map, vocab_size, m) if m >= self.vocab_size else m
-            return tree
+            for m in ms[ms >= 0].tolist():
+                if m >= self.vocab_size:
+                    pn = NodeTree(str(m), parent)
+                    get_tree_node(ms_map, m, vocab_size, pn, seqlist)
+                else:
+                    pn = NodeTree("<" + str(m) + ">", parent)
+                    seqlist.append(pn)
 
-        td = get_tree_dict(self.ms_map, self.vocab_size, meaning)
-        return {meaning: td}
+        root = NodeTree(str(meaning))
+        seqlist = []
+        get_tree_node(self.ms_map, meaning, self.vocab_size, root, seqlist)
+        root.seq_node = seqlist
+        return root
 
     def max_length(self):
         return max(self.ms_len)
@@ -211,7 +218,6 @@ class MeaningDataset(Dataset):
         seed=42,
         use_cache=True,
     ):
-        np.random.seed(seed)
         np.random.seed(seed)
         self.start = start
         self.end = end
diff --git a/wit/dataset/node_tree.py b/wit/dataset/node_tree.py
index ae20ed7..a148699 100644
--- a/wit/dataset/node_tree.py
+++ b/wit/dataset/node_tree.py
@@ -4,7 +4,7 @@ from anytree.node.nodemixin import NodeMixin
 from anytree.node.util import _repr
 
 
-class Node(NodeMixin):
+class NodeTree(NodeMixin):
     def __init__(self, name, parent=None, children=None, **kwargs):
         self.__dict__.update(kwargs)
         self.name = name
@@ -12,93 +12,17 @@ class Node(NodeMixin):
         self.parent = parent
         if children:
             self.children = children
+        self.seq_node = []
 
     def __repr__(self):
        args = ["%r" % self.separator.join([""] + [str(node.name) for node in self.path])]
        return _repr(self, args=args, nameblacklist=["name"])
 
-
-class NodeTree:
-    def __init__(self, tree: dict):
-        self.tree = tree
-        self.node = None
-        self.seq_node = []
-        self.get_node()
-
-    def get_node(self):
-        def get_tree_node(tree, parent, seqlist):
-            for key, value in tree.items():
-                if isinstance(value, dict):
-                    pn = Node(str(key), parent)
-                    get_tree_node(value, pn, seqlist)
-                else:
-                    pn = Node("<" + str(key) + ">", parent)
-                    seqlist.append(pn)
-
-        if self.node:
-            return self.node
-        assert isinstance(self.tree, dict)
-        root = Node("root")
-        get_tree_node(self.tree, root, self.seq_node)
-        self.node = root.children[0] if len(self.tree) == 1 else root
-        return self.node
-
     def set_seq_prop(self, index, prop):
         self.seq_node[index].prop = prop
 
     def print(self):
         treestr = ""
-        for pre, fill, node in RenderTree(self.get_node()):
+        for pre, fill, node in RenderTree(self):
             treestr += f"{pre}{node.name} {node.prop}\n"
         print(treestr)
-
-    def print_tree(tree):
-        def get_tree_node(tree, parent=None):
-            for key, value in tree.items():
-                pn = Node(str(key), parent)
-                if isinstance(value, dict):
-                    get_tree_node(value, pn)
-
-        assert isinstance(tree, dict)
-        root = Node("root")
-        get_tree_node(tree, root)
-        treestr = ""
-        for pre, fill, node in RenderTree(root.children[0] if len(tree) == 1 else root):
-            treestr += f"{pre}{node.name}\n"
-        return treestr
-
-    # def get_tree_indexed_str(tree, data, prefix):
-    #     if isinstance(tree, list):
-    #         base = ""
-    #         qlen = 0
-    #         for i, t in enumerate(tree):
-    #             s, l = MeaningMap.get_tree_indexed_str(t, data[i], "")
-    #             base += s
-    #             qlen += l
-    #         return (base, qlen)
-    #     else:
-    #         if isinstance(tree, dict):
-    #             base = ""
-    #             qlen = 0
-    #             last_is_dict = None
-    #             for key, value in tree.items():
-    #                 new_prefix = (len(str(key)) + 2) * " " + prefix
-    #                 dict_string, l = MeaningMap.get_tree_indexed_str(value, data[qlen:], new_prefix)
-    #                 if dict_string:
-    #                     base += "\n" + prefix + str(key) + ": " + dict_string
-    #                     last_is_dict = True
-    #                 else:
-    #                     base += "\n" + prefix + str(data[qlen]) + " " if last_is_dict else str(data[qlen]) + " "
-    #                     last_is_dict = False
-    #                 qlen += l
-    #             return (base, qlen)
-    #         return (None, 1)
-
-    # def token_frequency(tree, freq):
-    #     if isinstance(tree, dict):
-    #         for key, value in tree.items():
-    #             if key in freq:
-    #                 freq[key] = freq[key] + 1
-    #             else:
-    #                 freq[key] = 1
-    #         MeaningMap.token_frequency(value, freq)
diff --git a/wit/inference.py b/wit/inference.py
index 9005807..5d2a946 100644
--- a/wit/inference.py
+++ b/wit/inference.py
@@ -12,49 +12,33 @@ import dataset.node_tree as nt
 
 
 if __name__ == "__main__":
 
-    conf = configuration.TrainConfig()
-    config = conf.model_config
+    checkpoint_path = "log/bigger/version_0/checkpoints/epoch=19-step=98720.ckpt"
 
-    conf.name = "bigger"  # current train process name
-    conf.pretrain_model_name = None  # "qwen/Qwen-1_8B-Chat"
-    conf.learning_rate = 0.0001
-    conf.use_tril_attention_mask = None
-    conf.precision = "bf16-mixed"  # "precision:bf16-mixed,16-mixed,32-true"
-    conf.train_batch_size = 16
-    conf.val_batch_size = 4
-    conf.num_proc = 8
-    conf.max_epochs = 1000
-    conf.strategy = "auto"
-    conf.resume_from_ckpt_path = None
-    conf.seed = 42
-    conf.dataloader_works = 2
-
-    conf.dataset.meaning.val_mask_level = [0, 1, 2]
-    conf.dataset.meaning.val_mask_idx = [0, 0, -1]
-
-    config.vocab_size = 256
-    config.hidden_size = 128  # 128 1024 2048 32
-    config.num_hidden_layers = 3  # 6 12 24 3
-    config.num_attention_heads = 16  # 8 8 16
-
-    torch.manual_seed(conf.seed)
-
-    checkpoint_path = "log/bigger/version_1/checkpoints/epoch=23-step=24792.ckpt"
     qwen = QwenModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
     qwen.eval()
-
+    conf = qwen.config
+    torch.manual_seed(conf.seed)
+    np.random.seed(conf.seed)
     runner = QwenRunner(qwen.llm)
 
+    # batch = torch.tensor([[41]], dtype=torch.int64)
+    # print(runner.ChatTokens(batch).detach().cpu().numpy()[0])
+
     val = ds.InitValDataset(conf).dataset
     md = val.meaning_dataset
     map = md.get_meaning_map()
     item = md.get_token(0)
 
-    nt.NodeTree(map.get_tree(md.get_meaning(0))).print()
-    batch = torch.tensor([item[:-16]], dtype=torch.int64)
-    batch = batch.cuda()
-    # print(item)
-    next_token = runner.ChatToken(batch)
-    print(next_token.detach().cpu().numpy())
+    node = map.get_tree(md.get_meaning(0))
+    # node.print()
+
+    for i in range(1, len(item)):
+        itemm = [item[:i]]
+        batch = torch.tensor([item[:i]], dtype=torch.int64)
+        next_token = runner.ChatTokens(batch, sample=False).detach().cpu().numpy()[0]
+        if item[i] != next_token:
+            node.set_seq_prop(i, "ERR_" + str(next_token))
+            print(str(item[i]) + " " + str(next_token) + " ERROR")
+    node.print()
diff --git a/wit/model/modeling_wit.py b/wit/model/modeling_wit.py
index fac0b7c..0fc20fd 100644
--- a/wit/model/modeling_wit.py
+++ b/wit/model/modeling_wit.py
@@ -196,16 +196,18 @@ class QwenRunner:
     # torch.backends.cuda.enable_flash_sdp(True)
 
     @torch.no_grad()
-    def ChatToken(self, input_ids):
+    def ChatTokens(self, input_ids, sample=True):
         qwen = self.qwen
         input_ids = input_ids.to(next(qwen.parameters()).device)
         outputs, loss = self.forwardQWen(input_ids)
         next_token_scores = outputs[:, -1, :]
-
         next_token_scores = self.repetition_penalty(input_ids, next_token_scores)
         next_token_scores = self.top_p(next_token_scores)
-        next_tokens = self.sample(next_token_scores)
-        return next_tokens
+        if sample:
+            return self.sample(next_token_scores)
+        else:
+            sorted_logits, sorted_indices = torch.sort(next_token_scores, descending=True)
+            return sorted_indices[0]
 
     @torch.no_grad()
     def Chat(
diff --git a/wit/model/qwen_module.py b/wit/model/qwen_module.py
index ad5bf19..7ab35e8 100644
--- a/wit/model/qwen_module.py
+++ b/wit/model/qwen_module.py
@@ -11,6 +11,7 @@ from configuration import ModelConfig, TrainConfig
 
 class QwenModule(pl.LightningModule):
     def __init__(self, conf: TrainConfig = None):
+        self.config = conf
         pretrained_model_dir = conf.pretrain_model_name
         learning_rate = conf.learning_rate
         mconf = conf.model_config
diff --git a/wit/train.py b/wit/train.py
index 2701639..2b39d68 100644
--- a/wit/train.py
+++ b/wit/train.py
@@ -7,6 +7,7 @@ from logger import MLFLogger, TBLogger
 
 import configuration
 import dataset.dataset as ds
+import numpy as np
 
 
 if __name__ == "__main__":
@@ -17,9 +18,9 @@ if __name__ == "__main__":
     conf.pretrain_model_name = None  # "qwen/Qwen-1_8B-Chat"
     conf.learning_rate = 0.0001
     conf.use_tril_attention_mask = None
-    conf.precision = "bf16-mixed"  # "precision:bf16-mixed,16-mixed,32-true"
+    conf.precision = "32-true"  # "precision:bf16-mixed,16-mixed,32-true"
     conf.train_batch_size = 16
-    conf.val_batch_size = 4
+    conf.val_batch_size = 2
     conf.num_proc = 8
     conf.max_epochs = 1000
     conf.strategy = "auto"
@@ -27,15 +28,20 @@
     conf.seed = 42
     conf.dataloader_works = 2
 
+    conf.dataset.meaning.level_ratio = 5
+    conf.dataset.meaning.level = 2
+    conf.dataset.meaning.dataset_level = 5
+    conf.dataset.meaning.min_subitem = 2
     conf.dataset.meaning.val_mask_level = [0, 1, 2]
     conf.dataset.meaning.val_mask_idx = [0, 0, -1]
 
-    config.vocab_size = 256
+    config.vocab_size = 32
     config.hidden_size = 128  # 128 1024 2048 32
     config.num_hidden_layers = 3  # 6 12 24 3
     config.num_attention_heads = 16  # 8 8 16
 
     torch.manual_seed(conf.seed)
+    np.random.seed(conf.seed)
 
     qwen = QwenModule(conf)
     train_dataloader, val_dataloader = ds.InitDataset(conf)