Refine dataset and nodetree.
This commit is contained in:
parent
a1d5fce300
commit
3c34d12fba
|
@ -6,16 +6,17 @@ from typing import Dict, Tuple
|
||||||
from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
|
from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from torch.utils.data import BatchSampler
|
from torch.utils.data import BatchSampler
|
||||||
from anytree import Node, RenderTree
|
from dataset.node_tree import NodeTree
|
||||||
|
|
||||||
# import warnings
|
# import warnings
|
||||||
# warnings.filterwarnings("ignore", ".*does not have many workers.*")
|
# warnings.filterwarnings("ignore", ".*does not have many workers.*")
|
||||||
|
|
||||||
|
|
||||||
class MeaningMap:
|
class MeaningMap:
|
||||||
def __init__(self, size=1048576, vocab_size=4096, max_subitem=10, min_subitem=1, use_cache=True):
|
def __init__(self, size=1048576, vocab_size=4096, max_subitem=10, min_subitem=1, use_cache=True, seed=42):
|
||||||
assert size > 0 and vocab_size > 0 and max_subitem > 0 and min_subitem > 0, "Invalid input"
|
assert size > 0 and vocab_size > 0 and max_subitem > 0 and min_subitem > 0, "Invalid input"
|
||||||
assert min_subitem <= max_subitem, "Invalid input"
|
assert min_subitem <= max_subitem, "Invalid input"
|
||||||
|
np.random.seed(seed)
|
||||||
self.size = size
|
self.size = size
|
||||||
self.vocab_size = vocab_size
|
self.vocab_size = vocab_size
|
||||||
self.max_subitem = max_subitem
|
self.max_subitem = max_subitem
|
||||||
|
@ -183,15 +184,21 @@ class MeaningMap:
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_tree(self, meaning): # return meaning all sub items
|
def get_tree(self, meaning): # return meaning all sub items
|
||||||
def get_tree_dict(ms_map, vocab_size, meaning): # return meaning all sub items
|
def get_tree_node(ms_map, meaning, vocab_size, parent, seqlist):
|
||||||
tree = {}
|
|
||||||
ms = ms_map[meaning]
|
ms = ms_map[meaning]
|
||||||
for m in ms[ms > 0].tolist():
|
for m in ms[ms >= 0].tolist():
|
||||||
tree[m] = get_tree_dict(ms_map, vocab_size, m) if m >= self.vocab_size else m
|
if m >= self.vocab_size:
|
||||||
return tree
|
pn = NodeTree(str(m), parent)
|
||||||
|
get_tree_node(ms_map, m, vocab_size, pn, seqlist)
|
||||||
|
else:
|
||||||
|
pn = NodeTree("<" + str(m) + ">", parent)
|
||||||
|
seqlist.append(pn)
|
||||||
|
|
||||||
td = get_tree_dict(self.ms_map, self.vocab_size, meaning)
|
root = NodeTree(str(meaning))
|
||||||
return {meaning: td}
|
seqlist = []
|
||||||
|
get_tree_node(self.ms_map, meaning, self.vocab_size, root, seqlist)
|
||||||
|
root.seq_node = seqlist
|
||||||
|
return root
|
||||||
|
|
||||||
def max_length(self):
|
def max_length(self):
|
||||||
return max(self.ms_len)
|
return max(self.ms_len)
|
||||||
|
@ -211,7 +218,6 @@ class MeaningDataset(Dataset):
|
||||||
seed=42,
|
seed=42,
|
||||||
use_cache=True,
|
use_cache=True,
|
||||||
):
|
):
|
||||||
np.random.seed(seed)
|
|
||||||
np.random.seed(seed)
|
np.random.seed(seed)
|
||||||
self.start = start
|
self.start = start
|
||||||
self.end = end
|
self.end = end
|
||||||
|
|
|
@ -4,7 +4,7 @@ from anytree.node.nodemixin import NodeMixin
|
||||||
from anytree.node.util import _repr
|
from anytree.node.util import _repr
|
||||||
|
|
||||||
|
|
||||||
class Node(NodeMixin):
|
class NodeTree(NodeMixin):
|
||||||
def __init__(self, name, parent=None, children=None, **kwargs):
|
def __init__(self, name, parent=None, children=None, **kwargs):
|
||||||
self.__dict__.update(kwargs)
|
self.__dict__.update(kwargs)
|
||||||
self.name = name
|
self.name = name
|
||||||
|
@ -12,93 +12,17 @@ class Node(NodeMixin):
|
||||||
self.parent = parent
|
self.parent = parent
|
||||||
if children:
|
if children:
|
||||||
self.children = children
|
self.children = children
|
||||||
|
self.seq_node = []
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
args = ["%r" % self.separator.join([""] + [str(node.name) for node in self.path])]
|
args = ["%r" % self.separator.join([""] + [str(node.name) for node in self.path])]
|
||||||
return _repr(self, args=args, nameblacklist=["name"])
|
return _repr(self, args=args, nameblacklist=["name"])
|
||||||
|
|
||||||
|
|
||||||
class NodeTree:
|
|
||||||
def __init__(self, tree: dict):
|
|
||||||
self.tree = tree
|
|
||||||
self.node = None
|
|
||||||
self.seq_node = []
|
|
||||||
self.get_node()
|
|
||||||
|
|
||||||
def get_node(self):
|
|
||||||
def get_tree_node(tree, parent, seqlist):
|
|
||||||
for key, value in tree.items():
|
|
||||||
if isinstance(value, dict):
|
|
||||||
pn = Node(str(key), parent)
|
|
||||||
get_tree_node(value, pn, seqlist)
|
|
||||||
else:
|
|
||||||
pn = Node("<" + str(key) + ">", parent)
|
|
||||||
seqlist.append(pn)
|
|
||||||
|
|
||||||
if self.node:
|
|
||||||
return self.node
|
|
||||||
assert isinstance(self.tree, dict)
|
|
||||||
root = Node("root")
|
|
||||||
get_tree_node(self.tree, root, self.seq_node)
|
|
||||||
self.node = root.children[0] if len(self.tree) == 1 else root
|
|
||||||
return self.node
|
|
||||||
|
|
||||||
def set_seq_prop(self, index, prop):
|
def set_seq_prop(self, index, prop):
|
||||||
self.seq_node[index].prop = prop
|
self.seq_node[index].prop = prop
|
||||||
|
|
||||||
def print(self):
|
def print(self):
|
||||||
treestr = ""
|
treestr = ""
|
||||||
for pre, fill, node in RenderTree(self.get_node()):
|
for pre, fill, node in RenderTree(self):
|
||||||
treestr += f"{pre}{node.name} {node.prop}\n"
|
treestr += f"{pre}{node.name} {node.prop}\n"
|
||||||
print(treestr)
|
print(treestr)
|
||||||
|
|
||||||
def print_tree(tree):
|
|
||||||
def get_tree_node(tree, parent=None):
|
|
||||||
for key, value in tree.items():
|
|
||||||
pn = Node(str(key), parent)
|
|
||||||
if isinstance(value, dict):
|
|
||||||
get_tree_node(value, pn)
|
|
||||||
|
|
||||||
assert isinstance(tree, dict)
|
|
||||||
root = Node("root")
|
|
||||||
get_tree_node(tree, root)
|
|
||||||
treestr = ""
|
|
||||||
for pre, fill, node in RenderTree(root.children[0] if len(tree) == 1 else root):
|
|
||||||
treestr += f"{pre}{node.name}\n"
|
|
||||||
return treestr
|
|
||||||
|
|
||||||
# def get_tree_indexed_str(tree, data, prefix):
|
|
||||||
# if isinstance(tree, list):
|
|
||||||
# base = ""
|
|
||||||
# qlen = 0
|
|
||||||
# for i, t in enumerate(tree):
|
|
||||||
# s, l = MeaningMap.get_tree_indexed_str(t, data[i], "")
|
|
||||||
# base += s
|
|
||||||
# qlen += l
|
|
||||||
# return (base, qlen)
|
|
||||||
# else:
|
|
||||||
# if isinstance(tree, dict):
|
|
||||||
# base = ""
|
|
||||||
# qlen = 0
|
|
||||||
# last_is_dict = None
|
|
||||||
# for key, value in tree.items():
|
|
||||||
# new_prefix = (len(str(key)) + 2) * " " + prefix
|
|
||||||
# dict_string, l = MeaningMap.get_tree_indexed_str(value, data[qlen:], new_prefix)
|
|
||||||
# if dict_string:
|
|
||||||
# base += "\n" + prefix + str(key) + ": " + dict_string
|
|
||||||
# last_is_dict = True
|
|
||||||
# else:
|
|
||||||
# base += "\n" + prefix + str(data[qlen]) + " " if last_is_dict else str(data[qlen]) + " "
|
|
||||||
# last_is_dict = False
|
|
||||||
# qlen += l
|
|
||||||
# return (base, qlen)
|
|
||||||
# return (None, 1)
|
|
||||||
|
|
||||||
# def token_frequency(tree, freq):
|
|
||||||
# if isinstance(tree, dict):
|
|
||||||
# for key, value in tree.items():
|
|
||||||
# if key in freq:
|
|
||||||
# freq[key] = freq[key] + 1
|
|
||||||
# else:
|
|
||||||
# freq[key] = 1
|
|
||||||
# MeaningMap.token_frequency(value, freq)
|
|
||||||
|
|
|
@ -12,49 +12,33 @@ import dataset.node_tree as nt
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
||||||
conf = configuration.TrainConfig()
|
checkpoint_path = "log/bigger/version_0/checkpoints/epoch=19-step=98720.ckpt"
|
||||||
config = conf.model_config
|
|
||||||
|
|
||||||
conf.name = "bigger" # current train process name
|
|
||||||
conf.pretrain_model_name = None # "qwen/Qwen-1_8B-Chat"
|
|
||||||
conf.learning_rate = 0.0001
|
|
||||||
conf.use_tril_attention_mask = None
|
|
||||||
conf.precision = "bf16-mixed" # "precision:bf16-mixed,16-mixed,32-true"
|
|
||||||
conf.train_batch_size = 16
|
|
||||||
conf.val_batch_size = 4
|
|
||||||
conf.num_proc = 8
|
|
||||||
conf.max_epochs = 1000
|
|
||||||
conf.strategy = "auto"
|
|
||||||
conf.resume_from_ckpt_path = None
|
|
||||||
conf.seed = 42
|
|
||||||
conf.dataloader_works = 2
|
|
||||||
|
|
||||||
conf.dataset.meaning.val_mask_level = [0, 1, 2]
|
|
||||||
conf.dataset.meaning.val_mask_idx = [0, 0, -1]
|
|
||||||
|
|
||||||
config.vocab_size = 256
|
|
||||||
config.hidden_size = 128 # 128 1024 2048 32
|
|
||||||
config.num_hidden_layers = 3 # 6 12 24 3
|
|
||||||
config.num_attention_heads = 16 # 8 8 16
|
|
||||||
|
|
||||||
torch.manual_seed(conf.seed)
|
|
||||||
|
|
||||||
checkpoint_path = "log/bigger/version_1/checkpoints/epoch=23-step=24792.ckpt"
|
|
||||||
qwen = QwenModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
|
qwen = QwenModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
|
||||||
qwen.eval()
|
qwen.eval()
|
||||||
|
conf = qwen.config
|
||||||
|
torch.manual_seed(conf.seed)
|
||||||
|
np.random.seed(conf.seed)
|
||||||
runner = QwenRunner(qwen.llm)
|
runner = QwenRunner(qwen.llm)
|
||||||
|
|
||||||
|
# batch = torch.tensor([[41]], dtype=torch.int64)
|
||||||
|
# print(runner.ChatTokens(batch).detach().cpu().numpy()[0])
|
||||||
|
|
||||||
val = ds.InitValDataset(conf).dataset
|
val = ds.InitValDataset(conf).dataset
|
||||||
md = val.meaning_dataset
|
md = val.meaning_dataset
|
||||||
|
|
||||||
map = md.get_meaning_map()
|
map = md.get_meaning_map()
|
||||||
|
|
||||||
item = md.get_token(0)
|
item = md.get_token(0)
|
||||||
nt.NodeTree(map.get_tree(md.get_meaning(0))).print()
|
|
||||||
|
|
||||||
batch = torch.tensor([item[:-16]], dtype=torch.int64)
|
node = map.get_tree(md.get_meaning(0))
|
||||||
batch = batch.cuda()
|
# node.print()
|
||||||
# print(item)
|
|
||||||
next_token = runner.ChatToken(batch)
|
for i in range(1, len(item)):
|
||||||
print(next_token.detach().cpu().numpy())
|
itemm = [item[:i]]
|
||||||
|
batch = torch.tensor([item[:i]], dtype=torch.int64)
|
||||||
|
next_token = runner.ChatTokens(batch, sample=False).detach().cpu().numpy()[0]
|
||||||
|
if item[i] != next_token:
|
||||||
|
node.set_seq_prop(i, "ERR_" + str(next_token))
|
||||||
|
print(str(item[i]) + " " + str(next_token) + " ERROR")
|
||||||
|
node.print()
|
||||||
|
|
|
@ -196,16 +196,18 @@ class QwenRunner:
|
||||||
# torch.backends.cuda.enable_flash_sdp(True)
|
# torch.backends.cuda.enable_flash_sdp(True)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def ChatToken(self, input_ids):
|
def ChatTokens(self, input_ids, sample=True):
|
||||||
qwen = self.qwen
|
qwen = self.qwen
|
||||||
input_ids = input_ids.to(next(qwen.parameters()).device)
|
input_ids = input_ids.to(next(qwen.parameters()).device)
|
||||||
outputs, loss = self.forwardQWen(input_ids)
|
outputs, loss = self.forwardQWen(input_ids)
|
||||||
next_token_scores = outputs[:, -1, :]
|
next_token_scores = outputs[:, -1, :]
|
||||||
|
|
||||||
next_token_scores = self.repetition_penalty(input_ids, next_token_scores)
|
next_token_scores = self.repetition_penalty(input_ids, next_token_scores)
|
||||||
next_token_scores = self.top_p(next_token_scores)
|
next_token_scores = self.top_p(next_token_scores)
|
||||||
next_tokens = self.sample(next_token_scores)
|
if sample:
|
||||||
return next_tokens
|
return self.sample(next_token_scores)
|
||||||
|
else:
|
||||||
|
sorted_logits, sorted_indices = torch.sort(next_token_scores, descending=True)
|
||||||
|
return sorted_indices[0]
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def Chat(
|
def Chat(
|
||||||
|
|
|
@ -11,6 +11,7 @@ from configuration import ModelConfig, TrainConfig
|
||||||
|
|
||||||
class QwenModule(pl.LightningModule):
|
class QwenModule(pl.LightningModule):
|
||||||
def __init__(self, conf: TrainConfig = None):
|
def __init__(self, conf: TrainConfig = None):
|
||||||
|
self.config = conf
|
||||||
pretrained_model_dir = conf.pretrain_model_name
|
pretrained_model_dir = conf.pretrain_model_name
|
||||||
learning_rate = conf.learning_rate
|
learning_rate = conf.learning_rate
|
||||||
mconf = conf.model_config
|
mconf = conf.model_config
|
||||||
|
|
12
wit/train.py
12
wit/train.py
|
@ -7,6 +7,7 @@ from logger import MLFLogger, TBLogger
|
||||||
|
|
||||||
import configuration
|
import configuration
|
||||||
import dataset.dataset as ds
|
import dataset.dataset as ds
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
||||||
|
@ -17,9 +18,9 @@ if __name__ == "__main__":
|
||||||
conf.pretrain_model_name = None # "qwen/Qwen-1_8B-Chat"
|
conf.pretrain_model_name = None # "qwen/Qwen-1_8B-Chat"
|
||||||
conf.learning_rate = 0.0001
|
conf.learning_rate = 0.0001
|
||||||
conf.use_tril_attention_mask = None
|
conf.use_tril_attention_mask = None
|
||||||
conf.precision = "bf16-mixed" # "precision:bf16-mixed,16-mixed,32-true"
|
conf.precision = "32-true" # "precision:bf16-mixed,16-mixed,32-true"
|
||||||
conf.train_batch_size = 16
|
conf.train_batch_size = 16
|
||||||
conf.val_batch_size = 4
|
conf.val_batch_size = 2
|
||||||
conf.num_proc = 8
|
conf.num_proc = 8
|
||||||
conf.max_epochs = 1000
|
conf.max_epochs = 1000
|
||||||
conf.strategy = "auto"
|
conf.strategy = "auto"
|
||||||
|
@ -27,15 +28,20 @@ if __name__ == "__main__":
|
||||||
conf.seed = 42
|
conf.seed = 42
|
||||||
conf.dataloader_works = 2
|
conf.dataloader_works = 2
|
||||||
|
|
||||||
|
conf.dataset.meaning.level_ratio = 5
|
||||||
|
conf.dataset.meaning.level = 2
|
||||||
|
conf.dataset.meaning.dataset_level = 5
|
||||||
|
conf.dataset.meaning.min_subitem = 2
|
||||||
conf.dataset.meaning.val_mask_level = [0, 1, 2]
|
conf.dataset.meaning.val_mask_level = [0, 1, 2]
|
||||||
conf.dataset.meaning.val_mask_idx = [0, 0, -1]
|
conf.dataset.meaning.val_mask_idx = [0, 0, -1]
|
||||||
|
|
||||||
config.vocab_size = 256
|
config.vocab_size = 32
|
||||||
config.hidden_size = 128 # 128 1024 2048 32
|
config.hidden_size = 128 # 128 1024 2048 32
|
||||||
config.num_hidden_layers = 3 # 6 12 24 3
|
config.num_hidden_layers = 3 # 6 12 24 3
|
||||||
config.num_attention_heads = 16 # 8 8 16
|
config.num_attention_heads = 16 # 8 8 16
|
||||||
|
|
||||||
torch.manual_seed(conf.seed)
|
torch.manual_seed(conf.seed)
|
||||||
|
np.random.seed(conf.seed)
|
||||||
qwen = QwenModule(conf)
|
qwen = QwenModule(conf)
|
||||||
|
|
||||||
train_dataloader, val_dataloader = ds.InitDataset(conf)
|
train_dataloader, val_dataloader = ds.InitDataset(conf)
|
||||||
|
|
Loading…
Reference in New Issue