Add meaning dataset get_relation_table support and paint to qk image.

This commit is contained in:
Colin 2025-06-25 20:13:48 +08:00
parent d8539b6b2b
commit 927c98e823
12 changed files with 176 additions and 69 deletions

34
finetune/embedding.py Normal file
View File

@ -0,0 +1,34 @@
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F
#Mean Pooling - Take attention mask into account for correct averaging
def mean_pooling(model_output, attention_mask):
token_embeddings = model_output[0] #First element of model_output contains all token embeddings
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
# Sentences we want sentence embeddings for
sentences = ['This is an example sentence', 'Each sentence is converted']
# Load model from HuggingFace Hub
tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
model = AutoModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
# Tokenize sentences
encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
# Compute token embeddings
with torch.no_grad():
model_output = model(**encoded_input)
# Perform pooling
sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
# Normalize embeddings
sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
print("Sentence embeddings:")
print(sentence_embeddings)
print(sentence_embeddings.cpu().numpy())

View File

@ -9,7 +9,14 @@ import os
from pathlib import Path from pathlib import Path
def toTensor(tensor):
if not torch.is_tensor(tensor):
tensor = torch.tensor(tensor)
return tensor
def DumpTensorToImage(tensor, name, forceSquare=False, scale=1.0, Contrast=None, GridValue=None, Value2Log=False): def DumpTensorToImage(tensor, name, forceSquare=False, scale=1.0, Contrast=None, GridValue=None, Value2Log=False):
tensor = toTensor(tensor)
if len(tensor.shape) != 2 and len(tensor.shape) != 1 and len(tensor.shape) != 3: if len(tensor.shape) != 2 and len(tensor.shape) != 1 and len(tensor.shape) != 3:
raise ("Error input dims") raise ("Error input dims")
if ("." not in name) or (name.split(".")[-1] not in {"jpg", "png", "bmp"}): if ("." not in name) or (name.split(".")[-1] not in {"jpg", "png", "bmp"}):
@ -20,6 +27,7 @@ def DumpTensorToImage(tensor, name, forceSquare=False, scale=1.0, Contrast=None,
if len(tensor.shape) == 3: if len(tensor.shape) == 3:
channel = tensor.shape[0] channel = tensor.shape[0]
x = math.ceil((channel) ** 0.5) x = math.ceil((channel) ** 0.5)
y = math.ceil((x * x) / channel)
calc = tensor.reshape((channel, tensor.shape[1] * tensor.shape[2])) calc = tensor.reshape((channel, tensor.shape[1] * tensor.shape[2]))
if not Contrast: if not Contrast:
tensormax = calc.max(1)[0] tensormax = calc.max(1)[0]
@ -33,11 +41,11 @@ def DumpTensorToImage(tensor, name, forceSquare=False, scale=1.0, Contrast=None,
calc = calc.reshape((channel, tensor.shape[1], tensor.shape[2])) calc = calc.reshape((channel, tensor.shape[1], tensor.shape[2]))
if not GridValue: if not GridValue:
GridValue = 128.0 GridValue = 128.0
calc = F.pad(calc, (0, 0, 0, 0, 0, x * x - channel), mode="constant", value=GridValue) calc = F.pad(calc, (0, 0, 0, 0, 0, x * y - channel), mode="constant", value=GridValue)
calc = calc.reshape((x, x, tensor.shape[1], tensor.shape[2])) calc = calc.reshape((y, x, tensor.shape[1], tensor.shape[2]))
calc = F.pad(calc, (0, 1, 0, 1, 0, 0), mode="constant", value=GridValue) calc = F.pad(calc, (0, 1, 0, 1, 0, 0), mode="constant", value=GridValue)
tensor = calc.permute((0, 2, 1, 3)) tensor = calc.permute((0, 2, 1, 3))
tensor = tensor.reshape((x * tensor.shape[1], x * tensor.shape[3])) tensor = tensor.reshape((y * tensor.shape[1], x * tensor.shape[3]))
DumpTensorToImage(tensor, name, forceSquare=False, scale=scale, Contrast=[0.0, 255.0], GridValue=GridValue) DumpTensorToImage(tensor, name, forceSquare=False, scale=scale, Contrast=[0.0, 255.0], GridValue=GridValue)
return return
@ -78,6 +86,7 @@ def DumpTensorToImage(tensor, name, forceSquare=False, scale=1.0, Contrast=None,
def DumpTensorToLog(tensor, name="log"): def DumpTensorToLog(tensor, name="log"):
tensor = toTensor(tensor)
tensor_mean = torch.mean(tensor).cpu().detach().numpy() tensor_mean = torch.mean(tensor).cpu().detach().numpy()
tensor_abs_mean = torch.mean(torch.abs(tensor)).cpu().detach().numpy() tensor_abs_mean = torch.mean(torch.abs(tensor)).cpu().detach().numpy()
tensor_range = (torch.max(tensor) - torch.min(tensor)).cpu().detach().numpy() tensor_range = (torch.max(tensor) - torch.min(tensor)).cpu().detach().numpy()
@ -92,6 +101,7 @@ def DumpTensorToLog(tensor, name="log"):
def DumpTensorToFile(tensor, name="tensor.pt"): def DumpTensorToFile(tensor, name="tensor.pt"):
tensor = toTensor(tensor)
torch.save(tensor.cpu(), name) torch.save(tensor.cpu(), name)

View File

@ -38,10 +38,10 @@ class ModelConfig:
class MeaningDatasetConfig: class MeaningDatasetConfig:
def __init__(self): def __init__(self):
self.level_ratio = 5 self.start = 10000
self.level = 5 self.size = 4000
self.dataset_level = 3
self.min_subitem = 2 self.min_subitem = 2
self.max_subitem = 10
self.val_mask_level = None self.val_mask_level = None
self.val_mask_idx = None self.val_mask_idx = None

View File

@ -31,12 +31,12 @@ def InitDataset(config):
if config.dataset.name == "meaning": if config.dataset.name == "meaning":
c = config.dataset.meaning c = config.dataset.meaning
vocab = config.model_config.vocab_size vocab = config.model_config.vocab_size
start = vocab * (c.level_ratio**c.level) start = c.start
size = vocab * int((c.level_ratio**c.dataset_level)) size = c.size
path = "./data/" path = "./data/"
trainfile = path + f"MeaningDataset_train_v{size}_s{start}_s{size}_lr{c.level_ratio}_ms{c.min_subitem}.pt" trainfile = path + f"MeaningDataset_train_v{size}_s{start}_s{size}_ms{c.min_subitem}_maxs{c.max_subitem}.pt"
valfile = path + f"MeaningDataset_val_v{size}_s{start}_s{size}_lr{c.level_ratio}_ms{c.min_subitem}.pt" valfile = path + f"MeaningDataset_val_v{size}_s{start}_s{size}_ms{c.min_subitem}_maxs{c.max_subitem}.pt"
if not os.path.exists(path): if not os.path.exists(path):
os.mkdir(path) os.mkdir(path)
if os.path.exists(trainfile) and os.path.exists(valfile): if os.path.exists(trainfile) and os.path.exists(valfile):
@ -48,7 +48,7 @@ def InitDataset(config):
val_dataset.set_mask(c.val_mask_level, c.val_mask_idx) val_dataset.set_mask(c.val_mask_level, c.val_mask_idx)
print(f"INFO: Load dataset end") print(f"INFO: Load dataset end")
else: else:
raw_dataset = MeaningDataset(start, start + size, vocab, None, c.level_ratio, c.min_subitem) raw_dataset = MeaningDataset(start, start + size, vocab, None, c.max_subitem, c.min_subitem)
raw_dataset.set_mask(c.val_mask_level, c.val_mask_idx) raw_dataset.set_mask(c.val_mask_level, c.val_mask_idx)
train_dataset, val_dataset = raw_dataset.split(0.9) train_dataset, val_dataset = raw_dataset.split(0.9)
torch.save(train_dataset, trainfile) torch.save(train_dataset, trainfile)
@ -80,11 +80,11 @@ def InitValDataset(config):
if config.dataset.name == "meaning": if config.dataset.name == "meaning":
c = config.dataset.meaning c = config.dataset.meaning
vocab = config.model_config.vocab_size vocab = config.model_config.vocab_size
start = vocab * (c.level_ratio**c.level) start = c.start
size = vocab * int((c.level_ratio**c.dataset_level)) size = c.size
path = "./data/" path = "./data/"
valfile = path + f"MeaningDataset_val_v{size}_s{start}_s{size}_lr{c.level_ratio}_ms{c.min_subitem}.pt" valfile = path + f"MeaningDataset_val_v{size}_s{start}_s{size}_ms{c.min_subitem}_maxs{c.max_subitem}.pt"
if not os.path.exists(path): if not os.path.exists(path):
os.mkdir(path) os.mkdir(path)
if os.path.exists(valfile): if os.path.exists(valfile):
@ -93,7 +93,7 @@ def InitValDataset(config):
val_dataset.set_mask(c.val_mask_level, c.val_mask_idx) val_dataset.set_mask(c.val_mask_level, c.val_mask_idx)
print(f"INFO: Load dataset end") print(f"INFO: Load dataset end")
else: else:
raw_dataset = MeaningDataset(start, start + size, vocab, None, c.level_ratio, c.min_subitem) raw_dataset = MeaningDataset(start, start + size, vocab, None, c.max_subitem, c.min_subitem)
raw_dataset.set_mask(c.val_mask_level, c.val_mask_idx) raw_dataset.set_mask(c.val_mask_level, c.val_mask_idx)
train_dataset, val_dataset = raw_dataset.split(0.9) train_dataset, val_dataset = raw_dataset.split(0.9)
torch.save(val_dataset, valfile) torch.save(val_dataset, valfile)

View File

@ -105,12 +105,13 @@ class MeaningMap:
index = index + 1 index = index + 1
for i in range(self.vocab_size, size): for i in range(self.vocab_size, size):
m = map[i] m = map[i] # 当前meaning的拆分的分支
m = m[m >= 0] # donot cut off the map such as [0] m = m[m >= 0] # donot cut off the map such as [0]
m_len = len(m) m_len = len(m) # 当前meaning的拆分的分支个数
m_list = m.tolist() m_list = m.tolist()
assert m_list, "map list can not be empty list" assert m_list, "map list can not be empty list"
# 获取每个子meaning的start和end并且生成序列组合成当前meaning完整的叶index<vocab_size)
idx = np.concatenate([np.arange(ms_start[m], ms_end[m]) for m in m_list]) idx = np.concatenate([np.arange(ms_start[m], ms_end[m]) for m in m_list])
idxidx = np.concatenate( idxidx = np.concatenate(
[np.ones(l, dtype=np.uint32) * i for i, l in enumerate(ms_end[m_list] - ms_start[m_list])] [np.ones(l, dtype=np.uint32) * i for i, l in enumerate(ms_end[m_list] - ms_start[m_list])]
@ -118,16 +119,16 @@ class MeaningMap:
len_ma = len(idx) len_ma = len(idx)
end = index + len_ma end = index + len_ma
if ms_data.size < end: if ms_data.size < end: # 超过存储数据结构的大小扩展一个datastep容量
ms_data = np.concatenate([ms_data, np.zeros((datastep), dtype=np.int32)]) ms_data = np.concatenate([ms_data, np.zeros((datastep), dtype=np.int32)])
ms_level = np.concatenate([ms_level, np.zeros((datastep), dtype=np.uint32)]) ms_level = np.concatenate([ms_level, np.zeros((datastep), dtype=np.uint32)])
ms_rank_idx = np.concatenate([ms_rank_idx, np.zeros((datastep), dtype=np.uint32)]) ms_rank_idx = np.concatenate([ms_rank_idx, np.zeros((datastep), dtype=np.uint32)])
ms_rank_all = np.concatenate([ms_rank_all, np.zeros((datastep), dtype=np.uint32)]) ms_rank_all = np.concatenate([ms_rank_all, np.zeros((datastep), dtype=np.uint32)])
ms_data[index:end] = ms_data[idx] ms_data[index:end] = ms_data[idx] # 拼接当前meaning的所有token到data数据结构里面
ms_level[index:end] = ms_level[idx] + 1 ms_level[index:end] = ms_level[idx] + 1 # 处理level
ms_rank_idx[index:end] = (ms_rank_idx[idx] * 16 + idxidx).astype(np.uint32) ms_rank_idx[index:end] = (ms_rank_idx[idx] * 16 + idxidx).astype(np.uint32) # 处理rank_idx
ms_rank_all[index:end] = (ms_rank_all[idx] * 16 + m_len).astype(np.uint32) ms_rank_all[index:end] = (ms_rank_all[idx] * 16 + m_len).astype(np.uint32) # 处理rank_all
ms_start[i] = index ms_start[i] = index
ms_end[i] = end ms_end[i] = end
@ -200,21 +201,47 @@ class MeaningMap:
root.seq_node = seqlist root.seq_node = seqlist
return root return root
def get_tree(self, meaning): # 返回每个token相对于上一个token的level变化
def get_tree_list(ms_map, meaning, mlist): # 返回两个list分别表示 common -> current -> common 两个变化的level距离
def get_level_change(self, meaning):
def level_change(ms_map, meaning, current_to_common, common_to_current):
ms = ms_map[meaning] ms = ms_map[meaning]
mlist.append(meaning)
mlist.append(-255) # level down marker
for m in ms[ms >= 0].tolist(): for m in ms[ms >= 0].tolist():
if m >= self.vocab_size: if m >= self.vocab_size:
get_tree_list(ms_map, m, mlist) common_to_current[-1] = common_to_current[-1] + 1
level_change(ms_map, m, current_to_common, common_to_current)
else: else:
mlist.append(m) current_to_common.append(0)
mlist.append(-1) # level up marker common_to_current.append(0)
current_to_common[-2] = current_to_common[-2] + 1
meaninglist = [] common_to_current = []
get_tree_list(self.ms_map, meaning, meaninglist) common_to_current.append(1)
return meaninglist current_to_common = []
current_to_common.append(0)
level_change(self.ms_map, meaning, current_to_common, common_to_current)
current_to_common = current_to_common[:-1]
common_to_current = common_to_current[:-1]
return current_to_common, common_to_current
# 根据meaning的层级结构范围一个二位的数组表示所有token跟前面token是否有关系
def get_relation_table(self, meaning):
current_to_common, common_to_current = self.get_level_change(meaning)
width = len(current_to_common)
relation = np.zeros((width, width), dtype=int)
relation[0, 0] = 1
for i in range(1, width, 1):
if i == width - 2:
print(1)
ori = current_to_common[i] - common_to_current[i]
start_index = width
for s in range(i - 1, -1, -1):
if ori < 0:
break
ori = ori - common_to_current[s] + current_to_common[s]
start_index = s
relation[i, start_index : i + 1] = 1.0
return relation
def max_length(self): def max_length(self):
return max(self.ms_len) return max(self.ms_len)
@ -430,6 +457,7 @@ class BatchGroupMeaningDataloader(Dataset):
self, batch_size=1, num_workers=num_workers, collate_fn=BatchGroupMeaningDataloader.detection_collate self, batch_size=1, num_workers=num_workers, collate_fn=BatchGroupMeaningDataloader.detection_collate
) )
if __name__ == "__main__": if __name__ == "__main__":
md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, use_cache=False) md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, use_cache=False)

View File

@ -12,13 +12,20 @@ meaning数据集是一个模仿自然语言以及抽象表达的数据集。
6. level表示当前token相对于root meaning的距离 6. level表示当前token相对于root meaning的距离
7. rank 7. rank
8. rank_idx表示当前token在不同层的排序编号每4位表示在一层里面的编号低4位表示最低层级的rank_idx高位无用的位用1填充 8. rank_idx表示当前token在不同层的排序编号每4位表示在一层里面的编号低4位表示最低层级的rank_idx高位无用的位用1填充
9. rank_all表示当前token在不同层的分子个数每4位表示在一层里面的编号低4位表示最低层级的rank_all高位无用的位用1填充 9. rank_all表示当前token所在的不同层的总的分支个数每4位表示在一层里面的个数低4位表示最低层级的rank_all高位无用的位用1填充
10. tree用于存储每个meaning的拆解的数据使用字典表达一个树形结构 10. tree用于存储每个meaning的拆解的数据使用字典表达一个树形结构
11. get_seq_mask返回一个sequence每个token在对应level是不是对应的index,level=0:最底层index=-1:最后一个index=0:第一个 11. get_seq_mask返回一个sequence每个token在对应level是不是对应的index,level=0:最底层index=-1:最后一个index=0:第一个
12. meaning_height 当前meaning的总高度 12. meaning_height 当前meaning的总高度
13. meaning_weight 当前meaning的总宽度 13. meaning_weight 当前meaning的总宽度
14. val_mask_level val_mask_idx: 表示用于训练的token的mask,val_mask_level=[0, 1, 2] val_mask_idx=[0, 0, -1]表示只有是第0层第0个,而且是第1层第0个,第2层最后一个的token,才参与训练 14. val_mask_level val_mask_idx: 表示用于训练的token的mask,val_mask_level=[0, 1, 2] val_mask_idx=[0, 0, -1]表示只有是第0层第0个,而且是第1层第0个,第2层最后一个的token,才参与训练
## code
```
vocab = config.model_config.vocab_size
start 数据集的样本开始meaning
size 数据集的样本个数
```
``` ```
vocab_size = 256 meaning = 115200 vocab_size = 256 meaning = 115200
@ -33,11 +40,13 @@ vocab_size = 256 meaning = 115200
/ \ / \ / \ / \ / \ / \ / \ / \
176 11 255 129 129 99 211 111 176 11 255 129 129 99 211 111
sequence = 123 42 32 176 11 255 129 245 233 129 99 23 211 111 93 176 sequence = 123 42 32 176 11 255 129 245 233 129 99 23 211 111 93 176
level = 3 3 2 4 4 4 4 2 2 4 4 3 4 4 3 3 level = 3 3 2 4 4 4 4 2 2 4 4 3 4 4 3 3
idx at 0 = 0 1 1 0 1 0 1 0 1 0 1 2 0 1 0 1
idx at 1 = 0 0 0 0 0 1 1 1 1 0 0 0 0 0 2 2
idx 0 1 1 0 1 16 17 16 17 0 1 2 0 1 32 33
rank_idx = 0 1 1 0 1 16 17 16 17 0 1 2 0 1 32 33
idx at0 = 0 1 1 0 1 0 1 0 1 0 1 2 0 1 0 1
idx at1 = 0 0 0 0 0 1 1 1 1 0 0 0 0 0 2 2
rank_all =
``` ```

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.8 KiB

View File

@ -10,4 +10,11 @@
## 不同模型深度对结果的影响 ## 不同模型深度对结果的影响
6层相对于3层没有提升的原因可能是数据集太小3层已经能完全拟合 6层相对于3层没有提升的原因可能是数据集太小3层已经能完全拟合
![alt text](model_level_number.png) ![alt text](model_level_number.png)
## qk图解释
1. key[10] = 1000.0
2. 每一行数据像素表示一个新的token和前面所有token的关系
![alt text](q@k_seq_47_layer_0.png)

View File

@ -8,7 +8,9 @@ import dataset.dataset as ds
if __name__ == "__main__": if __name__ == "__main__":
checkpoint_path = "log/bigger/version_0/checkpoints/epoch=72-step=360328.ckpt" # checkpoint_path = "log/bigger/version_0/checkpoints/epoch=72-step=360328.ckpt"
# checkpoint_path = "log/bigger/version_4/checkpoints/epoch=81-step=64288.ckpt"
checkpoint_path = "log/bigger/version_8/checkpoints/epoch=14-step=67455.ckpt"
qwen = LightModule.load_from_checkpoint(checkpoint_path=checkpoint_path) qwen = LightModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
qwen.eval() qwen.eval()
@ -19,18 +21,18 @@ if __name__ == "__main__":
runner = ModelRunner(qwen.llm) runner = ModelRunner(qwen.llm)
# batch = torch.tensor([[11, 0, 3, 7, 15, 8, 10, 7, 14, 13, 1, 12, 13]], dtype=torch.int64)
# sorted_logits, sorted_indices = runner.ChatTokens(batch, sample=False)
# print(sorted_logits.detach().cpu().numpy())
# print(sorted_indices.detach().cpu().numpy())
val = ds.InitValDataset(conf).dataset val = ds.InitValDataset(conf).dataset
md = val.meaning_dataset md = val.meaning_dataset
map = md.get_meaning_map() map = md.get_meaning_map()
item = md.get_token(0)
node = map.get_nodetree(md.get_meaning(0)) # seq:844
# node.print() # seq:849
# seq:991
# seq:995
node = map.get_nodetree(995)
item, l, rank_idx, rank_all = map.get_sequence(995)
print("len of seq:" + str(len(item)))
for i in range(1, len(item)): for i in range(1, len(item)):
itemm = [item[:i]] itemm = [item[:i]]

View File

@ -15,7 +15,7 @@ import dataset.dataset as ds
if __name__ == "__main__": if __name__ == "__main__":
checkpoint_path = "log/bigger/version_0/checkpoints/epoch=72-step=360328.ckpt" checkpoint_path = "log/bigger/version_8/checkpoints/epoch=14-step=67455.ckpt"
qwen = LightModule.load_from_checkpoint(checkpoint_path=checkpoint_path) qwen = LightModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
qwen.eval() qwen.eval()
@ -25,6 +25,7 @@ if __name__ == "__main__":
runner = ModelRunner(qwen.llm) runner = ModelRunner(qwen.llm)
def DumpQK(query, key, causal_mask, index): def DumpQK(query, key, causal_mask, index):
global relation_table
size = query.shape[2] size = query.shape[2]
scale_factor = 1 / math.sqrt(query.size(-1)) scale_factor = 1 / math.sqrt(query.size(-1))
attn_weight = query @ key.transpose(-2, -1) * scale_factor attn_weight = query @ key.transpose(-2, -1) * scale_factor
@ -35,28 +36,44 @@ if __name__ == "__main__":
attn_weight = attn_weight * attn_mask attn_weight = attn_weight * attn_mask
qk = attn_weight[0] qk = attn_weight[0]
prePath = "./temp/" + "q@k_seq_" + str(size) + "_layer_" + str(index) + ".png" prePath = "./temp/" + "q@k_seq_" + str(size) + "_layer_" + str(index) + ".png"
qk = qk.cpu()
qk = torch.cat((qk, relation_table.unsqueeze(0)), dim=0)
show.DumpTensorToImage(qk, prePath, GridValue=255) show.DumpTensorToImage(qk, prePath, GridValue=255)
# qk_seq.append(qk) # qk_seq.append(qk)
# qk_index = size # qk_index = size
qwen.llm.hook_attention = DumpQK qwen.llm.hook_attention = DumpQK
val = ds.InitValDataset(conf).dataset val = ds.InitValDataset(conf).dataset
md = val.meaning_dataset md = val.meaning_dataset
map = md.get_meaning_map() map = md.get_meaning_map()
item = md.get_token(0)
node = map.get_nodetree(md.get_meaning(0)) # seq:844
# node.print() # seq:849
# seq:991
# seq:995
meaning = 995
node = map.get_nodetree(meaning)
current_to_common, common_to_current = map.get_level_change(meaning)
node.print()
print(current_to_common)
print(common_to_current)
relation_table = map.get_relation_table(meaning)
# prePath = "./temp/" + "q@k_seq_" + "_layer_" + ".png"
# show.DumpTensorToImage(relation_table, prePath, GridValue=255)
relation_table = torch.tensor(relation_table)
item, level, rank_idx, rank_all = map.get_sequence(meaning)
print(item)
print(level)
print(rank_idx)
print(rank_all)
print("len of seq:" + str(len(item)))
batch = torch.tensor([item], dtype=torch.int64) batch = torch.tensor([item], dtype=torch.int64)
sorted_logits, sorted_indices = runner.ChatTokens(batch, sample=False) sorted_logits, sorted_indices = runner.ChatTokens(batch, sample=False)
next_token = sorted_indices.detach().cpu().numpy()[0][0] next_token = sorted_indices.detach().cpu().numpy()[0][0]
node.print()
# batch = torch.tensor([[11, 0, 3, 7, 15, 8, 10, 7]], dtype=torch.int64) # batch = torch.tensor([[11, 0, 3, 7, 15, 8, 10, 7]], dtype=torch.int64)
# sorted_logits, sorted_indices = runner.ChatTokens(batch, sample=False) # sorted_logits, sorted_indices = runner.ChatTokens(batch, sample=False)

View File

@ -25,13 +25,13 @@ if __name__ == "__main__":
loader = train_dataloader.dataset loader = train_dataloader.dataset
map = loader.meaning_dataset.get_meaning_map() map = loader.meaning_dataset.get_meaning_map()
trees = {} seqs = {}
for batch in loader: for batch in loader:
for m in batch["meaning"]: for m in batch["meaning"]:
trees[m] = map.get_tree(m) seqs[m] = map.get_sequence(m)
while True: while True:
m = int(input("input meaning: ")) m = int(input("input meaning: "))
total = 0 total = 0
for tree in trees.values(): for seq in seqs.values():
total = total + tree.count(m) total = total + seq.count(m)
print(f"meaning of {m} count as {total}") print(f"meaning of {m} count as {total}")

View File

@ -20,7 +20,7 @@ if __name__ == "__main__":
conf.learning_rate = 0.001 conf.learning_rate = 0.001
conf.use_tril_attention_mask = None conf.use_tril_attention_mask = None
conf.precision = "16-mixed" # "precision:bf16-mixed,16-mixed,32-true" conf.precision = "16-mixed" # "precision:bf16-mixed,16-mixed,32-true"
conf.train_batch_size = 16 conf.train_batch_size = 32
conf.val_batch_size = 2 conf.val_batch_size = 2
conf.num_proc = 8 conf.num_proc = 8
conf.max_epochs = 1000 conf.max_epochs = 1000
@ -29,18 +29,18 @@ if __name__ == "__main__":
conf.seed = 42 conf.seed = 42
conf.dataloader_works = 2 conf.dataloader_works = 2
conf.dataset.meaning.level_ratio = 5 conf.dataset.meaning.start = 800
conf.dataset.meaning.level = 2 conf.dataset.meaning.size = 200000
conf.dataset.meaning.dataset_level = 5
conf.dataset.meaning.min_subitem = 2 conf.dataset.meaning.min_subitem = 2
conf.dataset.meaning.max_subitem = 4
conf.dataset.meaning.val_mask_level = [0, 1, 2] conf.dataset.meaning.val_mask_level = [0, 1, 2]
conf.dataset.meaning.val_mask_idx = [0, 0, -1] conf.dataset.meaning.val_mask_idx = [0, 0, -1]
config.vocab_size = 32 config.vocab_size = 32
config.hidden_size = 128 # 128 1024 2048 32 config.hidden_size = 256 # 128 1024 2048 32
config.intermediate_size = 256 config.intermediate_size = 512
config.num_hidden_layers = 3 # 6 12 24 3 config.num_hidden_layers = 4 # 6 12 24 3
config.num_attention_heads = 8 # 8 8 16 config.num_attention_heads = 4 # 8 8 16
torch.manual_seed(conf.seed) torch.manual_seed(conf.seed)
np.random.seed(conf.seed) np.random.seed(conf.seed)