Refine dataset and nodetree.
This commit is contained in:
parent
a1d5fce300
commit
3c34d12fba
|
@ -6,16 +6,17 @@ from typing import Dict, Tuple
|
|||
from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
|
||||
import numpy as np
|
||||
from torch.utils.data import BatchSampler
|
||||
from anytree import Node, RenderTree
|
||||
from dataset.node_tree import NodeTree
|
||||
|
||||
# import warnings
|
||||
# warnings.filterwarnings("ignore", ".*does not have many workers.*")
|
||||
|
||||
|
||||
class MeaningMap:
|
||||
def __init__(self, size=1048576, vocab_size=4096, max_subitem=10, min_subitem=1, use_cache=True):
|
||||
def __init__(self, size=1048576, vocab_size=4096, max_subitem=10, min_subitem=1, use_cache=True, seed=42):
|
||||
assert size > 0 and vocab_size > 0 and max_subitem > 0 and min_subitem > 0, "Invalid input"
|
||||
assert min_subitem <= max_subitem, "Invalid input"
|
||||
np.random.seed(seed)
|
||||
self.size = size
|
||||
self.vocab_size = vocab_size
|
||||
self.max_subitem = max_subitem
|
||||
|
@ -183,15 +184,21 @@ class MeaningMap:
|
|||
)
|
||||
|
||||
def get_tree(self, meaning): # return meaning all sub items
|
||||
def get_tree_dict(ms_map, vocab_size, meaning): # return meaning all sub items
|
||||
tree = {}
|
||||
def get_tree_node(ms_map, meaning, vocab_size, parent, seqlist):
|
||||
ms = ms_map[meaning]
|
||||
for m in ms[ms > 0].tolist():
|
||||
tree[m] = get_tree_dict(ms_map, vocab_size, m) if m >= self.vocab_size else m
|
||||
return tree
|
||||
for m in ms[ms >= 0].tolist():
|
||||
if m >= self.vocab_size:
|
||||
pn = NodeTree(str(m), parent)
|
||||
get_tree_node(ms_map, m, vocab_size, pn, seqlist)
|
||||
else:
|
||||
pn = NodeTree("<" + str(m) + ">", parent)
|
||||
seqlist.append(pn)
|
||||
|
||||
td = get_tree_dict(self.ms_map, self.vocab_size, meaning)
|
||||
return {meaning: td}
|
||||
root = NodeTree(str(meaning))
|
||||
seqlist = []
|
||||
get_tree_node(self.ms_map, meaning, self.vocab_size, root, seqlist)
|
||||
root.seq_node = seqlist
|
||||
return root
|
||||
|
||||
def max_length(self):
|
||||
return max(self.ms_len)
|
||||
|
@ -211,7 +218,6 @@ class MeaningDataset(Dataset):
|
|||
seed=42,
|
||||
use_cache=True,
|
||||
):
|
||||
np.random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
self.start = start
|
||||
self.end = end
|
||||
|
|
|
@ -4,7 +4,7 @@ from anytree.node.nodemixin import NodeMixin
|
|||
from anytree.node.util import _repr
|
||||
|
||||
|
||||
class Node(NodeMixin):
|
||||
class NodeTree(NodeMixin):
|
||||
def __init__(self, name, parent=None, children=None, **kwargs):
|
||||
self.__dict__.update(kwargs)
|
||||
self.name = name
|
||||
|
@ -12,93 +12,17 @@ class Node(NodeMixin):
|
|||
self.parent = parent
|
||||
if children:
|
||||
self.children = children
|
||||
self.seq_node = []
|
||||
|
||||
def __repr__(self):
|
||||
args = ["%r" % self.separator.join([""] + [str(node.name) for node in self.path])]
|
||||
return _repr(self, args=args, nameblacklist=["name"])
|
||||
|
||||
|
||||
class NodeTree:
|
||||
def __init__(self, tree: dict):
|
||||
self.tree = tree
|
||||
self.node = None
|
||||
self.seq_node = []
|
||||
self.get_node()
|
||||
|
||||
def get_node(self):
|
||||
def get_tree_node(tree, parent, seqlist):
|
||||
for key, value in tree.items():
|
||||
if isinstance(value, dict):
|
||||
pn = Node(str(key), parent)
|
||||
get_tree_node(value, pn, seqlist)
|
||||
else:
|
||||
pn = Node("<" + str(key) + ">", parent)
|
||||
seqlist.append(pn)
|
||||
|
||||
if self.node:
|
||||
return self.node
|
||||
assert isinstance(self.tree, dict)
|
||||
root = Node("root")
|
||||
get_tree_node(self.tree, root, self.seq_node)
|
||||
self.node = root.children[0] if len(self.tree) == 1 else root
|
||||
return self.node
|
||||
|
||||
def set_seq_prop(self, index, prop):
|
||||
self.seq_node[index].prop = prop
|
||||
|
||||
def print(self):
|
||||
treestr = ""
|
||||
for pre, fill, node in RenderTree(self.get_node()):
|
||||
for pre, fill, node in RenderTree(self):
|
||||
treestr += f"{pre}{node.name} {node.prop}\n"
|
||||
print(treestr)
|
||||
|
||||
def print_tree(tree):
|
||||
def get_tree_node(tree, parent=None):
|
||||
for key, value in tree.items():
|
||||
pn = Node(str(key), parent)
|
||||
if isinstance(value, dict):
|
||||
get_tree_node(value, pn)
|
||||
|
||||
assert isinstance(tree, dict)
|
||||
root = Node("root")
|
||||
get_tree_node(tree, root)
|
||||
treestr = ""
|
||||
for pre, fill, node in RenderTree(root.children[0] if len(tree) == 1 else root):
|
||||
treestr += f"{pre}{node.name}\n"
|
||||
return treestr
|
||||
|
||||
# def get_tree_indexed_str(tree, data, prefix):
|
||||
# if isinstance(tree, list):
|
||||
# base = ""
|
||||
# qlen = 0
|
||||
# for i, t in enumerate(tree):
|
||||
# s, l = MeaningMap.get_tree_indexed_str(t, data[i], "")
|
||||
# base += s
|
||||
# qlen += l
|
||||
# return (base, qlen)
|
||||
# else:
|
||||
# if isinstance(tree, dict):
|
||||
# base = ""
|
||||
# qlen = 0
|
||||
# last_is_dict = None
|
||||
# for key, value in tree.items():
|
||||
# new_prefix = (len(str(key)) + 2) * " " + prefix
|
||||
# dict_string, l = MeaningMap.get_tree_indexed_str(value, data[qlen:], new_prefix)
|
||||
# if dict_string:
|
||||
# base += "\n" + prefix + str(key) + ": " + dict_string
|
||||
# last_is_dict = True
|
||||
# else:
|
||||
# base += "\n" + prefix + str(data[qlen]) + " " if last_is_dict else str(data[qlen]) + " "
|
||||
# last_is_dict = False
|
||||
# qlen += l
|
||||
# return (base, qlen)
|
||||
# return (None, 1)
|
||||
|
||||
# def token_frequency(tree, freq):
|
||||
# if isinstance(tree, dict):
|
||||
# for key, value in tree.items():
|
||||
# if key in freq:
|
||||
# freq[key] = freq[key] + 1
|
||||
# else:
|
||||
# freq[key] = 1
|
||||
# MeaningMap.token_frequency(value, freq)
|
||||
|
|
|
@ -12,49 +12,33 @@ import dataset.node_tree as nt
|
|||
|
||||
if __name__ == "__main__":
|
||||
|
||||
conf = configuration.TrainConfig()
|
||||
config = conf.model_config
|
||||
checkpoint_path = "log/bigger/version_0/checkpoints/epoch=19-step=98720.ckpt"
|
||||
|
||||
conf.name = "bigger" # current train process name
|
||||
conf.pretrain_model_name = None # "qwen/Qwen-1_8B-Chat"
|
||||
conf.learning_rate = 0.0001
|
||||
conf.use_tril_attention_mask = None
|
||||
conf.precision = "bf16-mixed" # "precision:bf16-mixed,16-mixed,32-true"
|
||||
conf.train_batch_size = 16
|
||||
conf.val_batch_size = 4
|
||||
conf.num_proc = 8
|
||||
conf.max_epochs = 1000
|
||||
conf.strategy = "auto"
|
||||
conf.resume_from_ckpt_path = None
|
||||
conf.seed = 42
|
||||
conf.dataloader_works = 2
|
||||
|
||||
conf.dataset.meaning.val_mask_level = [0, 1, 2]
|
||||
conf.dataset.meaning.val_mask_idx = [0, 0, -1]
|
||||
|
||||
config.vocab_size = 256
|
||||
config.hidden_size = 128 # 128 1024 2048 32
|
||||
config.num_hidden_layers = 3 # 6 12 24 3
|
||||
config.num_attention_heads = 16 # 8 8 16
|
||||
|
||||
torch.manual_seed(conf.seed)
|
||||
|
||||
checkpoint_path = "log/bigger/version_1/checkpoints/epoch=23-step=24792.ckpt"
|
||||
qwen = QwenModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
|
||||
qwen.eval()
|
||||
|
||||
conf = qwen.config
|
||||
torch.manual_seed(conf.seed)
|
||||
np.random.seed(conf.seed)
|
||||
runner = QwenRunner(qwen.llm)
|
||||
|
||||
# batch = torch.tensor([[41]], dtype=torch.int64)
|
||||
# print(runner.ChatTokens(batch).detach().cpu().numpy()[0])
|
||||
|
||||
val = ds.InitValDataset(conf).dataset
|
||||
md = val.meaning_dataset
|
||||
|
||||
map = md.get_meaning_map()
|
||||
|
||||
item = md.get_token(0)
|
||||
nt.NodeTree(map.get_tree(md.get_meaning(0))).print()
|
||||
|
||||
batch = torch.tensor([item[:-16]], dtype=torch.int64)
|
||||
batch = batch.cuda()
|
||||
# print(item)
|
||||
next_token = runner.ChatToken(batch)
|
||||
print(next_token.detach().cpu().numpy())
|
||||
node = map.get_tree(md.get_meaning(0))
|
||||
# node.print()
|
||||
|
||||
for i in range(1, len(item)):
|
||||
itemm = [item[:i]]
|
||||
batch = torch.tensor([item[:i]], dtype=torch.int64)
|
||||
next_token = runner.ChatTokens(batch, sample=False).detach().cpu().numpy()[0]
|
||||
if item[i] != next_token:
|
||||
node.set_seq_prop(i, "ERR_" + str(next_token))
|
||||
print(str(item[i]) + " " + str(next_token) + " ERROR")
|
||||
node.print()
|
||||
|
|
|
@ -196,16 +196,18 @@ class QwenRunner:
|
|||
# torch.backends.cuda.enable_flash_sdp(True)
|
||||
|
||||
@torch.no_grad()
|
||||
def ChatToken(self, input_ids):
|
||||
def ChatTokens(self, input_ids, sample=True):
|
||||
qwen = self.qwen
|
||||
input_ids = input_ids.to(next(qwen.parameters()).device)
|
||||
outputs, loss = self.forwardQWen(input_ids)
|
||||
next_token_scores = outputs[:, -1, :]
|
||||
|
||||
next_token_scores = self.repetition_penalty(input_ids, next_token_scores)
|
||||
next_token_scores = self.top_p(next_token_scores)
|
||||
next_tokens = self.sample(next_token_scores)
|
||||
return next_tokens
|
||||
if sample:
|
||||
return self.sample(next_token_scores)
|
||||
else:
|
||||
sorted_logits, sorted_indices = torch.sort(next_token_scores, descending=True)
|
||||
return sorted_indices[0]
|
||||
|
||||
@torch.no_grad()
|
||||
def Chat(
|
||||
|
|
|
@ -11,6 +11,7 @@ from configuration import ModelConfig, TrainConfig
|
|||
|
||||
class QwenModule(pl.LightningModule):
|
||||
def __init__(self, conf: TrainConfig = None):
|
||||
self.config = conf
|
||||
pretrained_model_dir = conf.pretrain_model_name
|
||||
learning_rate = conf.learning_rate
|
||||
mconf = conf.model_config
|
||||
|
|
12
wit/train.py
12
wit/train.py
|
@ -7,6 +7,7 @@ from logger import MLFLogger, TBLogger
|
|||
|
||||
import configuration
|
||||
import dataset.dataset as ds
|
||||
import numpy as np
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
|
@ -17,9 +18,9 @@ if __name__ == "__main__":
|
|||
conf.pretrain_model_name = None # "qwen/Qwen-1_8B-Chat"
|
||||
conf.learning_rate = 0.0001
|
||||
conf.use_tril_attention_mask = None
|
||||
conf.precision = "bf16-mixed" # "precision:bf16-mixed,16-mixed,32-true"
|
||||
conf.precision = "32-true" # "precision:bf16-mixed,16-mixed,32-true"
|
||||
conf.train_batch_size = 16
|
||||
conf.val_batch_size = 4
|
||||
conf.val_batch_size = 2
|
||||
conf.num_proc = 8
|
||||
conf.max_epochs = 1000
|
||||
conf.strategy = "auto"
|
||||
|
@ -27,15 +28,20 @@ if __name__ == "__main__":
|
|||
conf.seed = 42
|
||||
conf.dataloader_works = 2
|
||||
|
||||
conf.dataset.meaning.level_ratio = 5
|
||||
conf.dataset.meaning.level = 2
|
||||
conf.dataset.meaning.dataset_level = 5
|
||||
conf.dataset.meaning.min_subitem = 2
|
||||
conf.dataset.meaning.val_mask_level = [0, 1, 2]
|
||||
conf.dataset.meaning.val_mask_idx = [0, 0, -1]
|
||||
|
||||
config.vocab_size = 256
|
||||
config.vocab_size = 32
|
||||
config.hidden_size = 128 # 128 1024 2048 32
|
||||
config.num_hidden_layers = 3 # 6 12 24 3
|
||||
config.num_attention_heads = 16 # 8 8 16
|
||||
|
||||
torch.manual_seed(conf.seed)
|
||||
np.random.seed(conf.seed)
|
||||
qwen = QwenModule(conf)
|
||||
|
||||
train_dataloader, val_dataloader = ds.InitDataset(conf)
|
||||
|
|
Loading…
Reference in New Issue