Refine dataset and nodetree.

This commit is contained in:
Colin 2025-02-24 21:38:31 +08:00
parent a1d5fce300
commit 3c34d12fba
6 changed files with 53 additions and 130 deletions

View File

@ -6,16 +6,17 @@ from typing import Dict, Tuple
from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
import numpy as np
from torch.utils.data import BatchSampler
from anytree import Node, RenderTree
from dataset.node_tree import NodeTree
# import warnings
# warnings.filterwarnings("ignore", ".*does not have many workers.*")
class MeaningMap:
def __init__(self, size=1048576, vocab_size=4096, max_subitem=10, min_subitem=1, use_cache=True):
def __init__(self, size=1048576, vocab_size=4096, max_subitem=10, min_subitem=1, use_cache=True, seed=42):
assert size > 0 and vocab_size > 0 and max_subitem > 0 and min_subitem > 0, "Invalid input"
assert min_subitem <= max_subitem, "Invalid input"
np.random.seed(seed)
self.size = size
self.vocab_size = vocab_size
self.max_subitem = max_subitem
@ -183,15 +184,21 @@ class MeaningMap:
)
def get_tree(self, meaning):  # return meaning all sub items
    """Expand `meaning` into a NodeTree of all of its sub-items.

    Non-leaf meanings (>= vocab_size) become inner nodes and recurse;
    token-level values are wrapped as "<token>" leaf nodes and collected,
    in sequence order, into root.seq_node.
    """

    def get_tree_node(ms_map, meaning, vocab_size, parent, seqlist):
        ms = ms_map[meaning]
        for m in ms[ms >= 0].tolist():
            # Fix: use the vocab_size parameter that is actually passed in;
            # the original read self.vocab_size and left the parameter dead
            # (same value at runtime, so behavior is unchanged).
            if m >= vocab_size:
                pn = NodeTree(str(m), parent)
                get_tree_node(ms_map, m, vocab_size, pn, seqlist)
            else:
                pn = NodeTree("<" + str(m) + ">", parent)
                seqlist.append(pn)

    root = NodeTree(str(meaning))
    seqlist = []
    get_tree_node(self.ms_map, meaning, self.vocab_size, root, seqlist)
    root.seq_node = seqlist  # flat, in-order list of the leaf nodes
    return root
def max_length(self):
    """Return the length of the longest meaning sequence."""
    longest = max(self.ms_len)
    return longest
@ -211,7 +218,6 @@ class MeaningDataset(Dataset):
seed=42,
use_cache=True,
):
np.random.seed(seed)
np.random.seed(seed)
self.start = start
self.end = end

View File

@ -4,7 +4,7 @@ from anytree.node.nodemixin import NodeMixin
from anytree.node.util import _repr
class NodeTree(NodeMixin):
    """Tree node backed by anytree's NodeMixin, carrying arbitrary extra attributes."""

    def __init__(self, name, parent=None, children=None, **kwargs):
        # Extra keyword attributes go in first so the core fields below win on conflict.
        self.__dict__.update(kwargs)
        self.name = name
        self.parent = parent  # NodeMixin hooks this assignment to attach the node
        if children:
            self.children = children
        self.seq_node = []  # in-order leaf list, filled by tree builders

    def __repr__(self):
        path_repr = self.separator.join([""] + [str(node.name) for node in self.path])
        return _repr(self, args=[repr(path_repr)], nameblacklist=["name"])
# NOTE(review): legacy dict-wrapper NodeTree removed by this commit and
# superseded by the NodeMixin-based NodeTree; indentation was lost in this
# paste, so the lines below are kept byte-identical rather than restyled.
class NodeTree:
def __init__(self, tree: dict):
self.tree = tree
self.node = None
self.seq_node = []
self.get_node()
# Lazily build (and cache) an anytree view of the nested dict.
def get_node(self):
def get_tree_node(tree, parent, seqlist):
for key, value in tree.items():
if isinstance(value, dict):
pn = Node(str(key), parent)
get_tree_node(value, pn, seqlist)
else:
# Leaves are rendered as "<token>" and collected in order.
pn = Node("<" + str(key) + ">", parent)
seqlist.append(pn)
if self.node:
return self.node
assert isinstance(self.tree, dict)
root = Node("root")
get_tree_node(self.tree, root, self.seq_node)
# Unwrap the synthetic root when the dict has a single top-level key.
self.node = root.children[0] if len(self.tree) == 1 else root
return self.node
def set_seq_prop(self, index, prop):
    """Attach `prop` to the index-th leaf node in seq_node."""
    target = self.seq_node[index]
    setattr(target, "prop", prop)
def print(self):
    """Render this tree to stdout, appending each node's `prop` when set.

    Only nodes tagged via set_seq_prop carry a `prop` attribute (nothing in
    __init__ or the tree builders sets one), so read it with a default —
    the original `node.prop` raised AttributeError on every untagged node.
    """
    treestr = ""
    for pre, fill, node in RenderTree(self):
        treestr += f"{pre}{node.name} {getattr(node, 'prop', '')}\n"
    print(treestr)
def print_tree(tree):
    """Render a nested-dict meaning tree as an ASCII tree string (legacy dict form)."""

    def build(subtree, parent=None):
        # Recreate the dict hierarchy as anytree nodes; non-dict values are leaves.
        for key, value in subtree.items():
            node = Node(str(key), parent)
            if isinstance(value, dict):
                build(value, node)

    assert isinstance(tree, dict)
    root = Node("root")
    build(tree, root)
    # Unwrap the synthetic root when the dict has a single top-level key.
    top = root.children[0] if len(tree) == 1 else root
    rendered = ""
    for pre, fill, node in RenderTree(top):
        rendered += f"{pre}{node.name}\n"
    return rendered
# def get_tree_indexed_str(tree, data, prefix):
# if isinstance(tree, list):
# base = ""
# qlen = 0
# for i, t in enumerate(tree):
# s, l = MeaningMap.get_tree_indexed_str(t, data[i], "")
# base += s
# qlen += l
# return (base, qlen)
# else:
# if isinstance(tree, dict):
# base = ""
# qlen = 0
# last_is_dict = None
# for key, value in tree.items():
# new_prefix = (len(str(key)) + 2) * " " + prefix
# dict_string, l = MeaningMap.get_tree_indexed_str(value, data[qlen:], new_prefix)
# if dict_string:
# base += "\n" + prefix + str(key) + ": " + dict_string
# last_is_dict = True
# else:
# base += "\n" + prefix + str(data[qlen]) + " " if last_is_dict else str(data[qlen]) + " "
# last_is_dict = False
# qlen += l
# return (base, qlen)
# return (None, 1)
# def token_frequency(tree, freq):
# if isinstance(tree, dict):
# for key, value in tree.items():
# if key in freq:
# freq[key] = freq[key] + 1
# else:
# freq[key] = 1
# MeaningMap.token_frequency(value, freq)

View File

@ -12,49 +12,33 @@ import dataset.node_tree as nt
if __name__ == "__main__":
conf = configuration.TrainConfig()
config = conf.model_config
checkpoint_path = "log/bigger/version_0/checkpoints/epoch=19-step=98720.ckpt"
conf.name = "bigger" # current train process name
conf.pretrain_model_name = None # "qwen/Qwen-1_8B-Chat"
conf.learning_rate = 0.0001
conf.use_tril_attention_mask = None
conf.precision = "bf16-mixed" # "precision:bf16-mixed,16-mixed,32-true"
conf.train_batch_size = 16
conf.val_batch_size = 4
conf.num_proc = 8
conf.max_epochs = 1000
conf.strategy = "auto"
conf.resume_from_ckpt_path = None
conf.seed = 42
conf.dataloader_works = 2
conf.dataset.meaning.val_mask_level = [0, 1, 2]
conf.dataset.meaning.val_mask_idx = [0, 0, -1]
config.vocab_size = 256
config.hidden_size = 128 # 128 1024 2048 32
config.num_hidden_layers = 3 # 6 12 24 3
config.num_attention_heads = 16 # 8 8 16
torch.manual_seed(conf.seed)
checkpoint_path = "log/bigger/version_1/checkpoints/epoch=23-step=24792.ckpt"
qwen = QwenModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
qwen.eval()
conf = qwen.config
torch.manual_seed(conf.seed)
np.random.seed(conf.seed)
runner = QwenRunner(qwen.llm)
# batch = torch.tensor([[41]], dtype=torch.int64)
# print(runner.ChatTokens(batch).detach().cpu().numpy()[0])
val = ds.InitValDataset(conf).dataset
md = val.meaning_dataset
map = md.get_meaning_map()
item = md.get_token(0)
nt.NodeTree(map.get_tree(md.get_meaning(0))).print()
batch = torch.tensor([item[:-16]], dtype=torch.int64)
batch = batch.cuda()
# print(item)
next_token = runner.ChatToken(batch)
print(next_token.detach().cpu().numpy())
node = map.get_tree(md.get_meaning(0))
# node.print()
for i in range(1, len(item)):
itemm = [item[:i]]
batch = torch.tensor([item[:i]], dtype=torch.int64)
next_token = runner.ChatTokens(batch, sample=False).detach().cpu().numpy()[0]
if item[i] != next_token:
node.set_seq_prop(i, "ERR_" + str(next_token))
print(str(item[i]) + " " + str(next_token) + " ERROR")
node.print()

View File

@ -196,16 +196,18 @@ class QwenRunner:
# torch.backends.cuda.enable_flash_sdp(True)
@torch.no_grad()
def ChatTokens(self, input_ids, sample=True):
    """Score the next token for `input_ids`.

    When `sample` is True, draw one token per batch row via top-p sampling.
    Otherwise return the descending-score ranking of vocab indices for the
    first batch row (callers take element [0] as the greedy token).
    """
    qwen = self.qwen
    input_ids = input_ids.to(next(qwen.parameters()).device)
    outputs, loss = self.forwardQWen(input_ids)
    next_token_scores = outputs[:, -1, :]  # logits at the last position
    next_token_scores = self.repetition_penalty(input_ids, next_token_scores)
    next_token_scores = self.top_p(next_token_scores)
    if sample:
        return self.sample(next_token_scores)
    # Fix: the sorted logits were computed and never used — discard them.
    _, sorted_indices = torch.sort(next_token_scores, descending=True)
    return sorted_indices[0]
@torch.no_grad()
def Chat(

View File

@ -11,6 +11,7 @@ from configuration import ModelConfig, TrainConfig
class QwenModule(pl.LightningModule):
def __init__(self, conf: TrainConfig = None):
self.config = conf
pretrained_model_dir = conf.pretrain_model_name
learning_rate = conf.learning_rate
mconf = conf.model_config

View File

@ -7,6 +7,7 @@ from logger import MLFLogger, TBLogger
import configuration
import dataset.dataset as ds
import numpy as np
if __name__ == "__main__":
@ -17,9 +18,9 @@ if __name__ == "__main__":
conf.pretrain_model_name = None # "qwen/Qwen-1_8B-Chat"
conf.learning_rate = 0.0001
conf.use_tril_attention_mask = None
conf.precision = "bf16-mixed" # "precision:bf16-mixed,16-mixed,32-true"
conf.precision = "32-true" # "precision:bf16-mixed,16-mixed,32-true"
conf.train_batch_size = 16
conf.val_batch_size = 4
conf.val_batch_size = 2
conf.num_proc = 8
conf.max_epochs = 1000
conf.strategy = "auto"
@ -27,15 +28,20 @@ if __name__ == "__main__":
conf.seed = 42
conf.dataloader_works = 2
conf.dataset.meaning.level_ratio = 5
conf.dataset.meaning.level = 2
conf.dataset.meaning.dataset_level = 5
conf.dataset.meaning.min_subitem = 2
conf.dataset.meaning.val_mask_level = [0, 1, 2]
conf.dataset.meaning.val_mask_idx = [0, 0, -1]
config.vocab_size = 256
config.vocab_size = 32
config.hidden_size = 128 # 128 1024 2048 32
config.num_hidden_layers = 3 # 6 12 24 3
config.num_attention_heads = 16 # 8 8 16
torch.manual_seed(conf.seed)
np.random.seed(conf.seed)
qwen = QwenModule(conf)
train_dataloader, val_dataloader = ds.InitDataset(conf)