Add meaning dataset get_relation_table support and paint it onto the qk image.
This commit is contained in:
parent d8539b6b2b
commit 927c98e823
@@ -0,0 +1,34 @@
+from transformers import AutoTokenizer, AutoModel
+import torch
+import torch.nn.functional as F
+
+# Mean Pooling - Take attention mask into account for correct averaging
+def mean_pooling(model_output, attention_mask):
+    token_embeddings = model_output[0]  # First element of model_output contains all token embeddings
+    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
+
+
+# Sentences we want sentence embeddings for
+sentences = ['This is an example sentence', 'Each sentence is converted']
+
+# Load model from HuggingFace Hub
+tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
+model = AutoModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
+
+# Tokenize sentences
+encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
+
+# Compute token embeddings
+with torch.no_grad():
+    model_output = model(**encoded_input)
+
+# Perform pooling
+sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
+
+# Normalize embeddings
+sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
+
+print("Sentence embeddings:")
+print(sentence_embeddings)
+print(sentence_embeddings.cpu().numpy())
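As a quick sanity check of the pooling math above — a standalone sketch, with invented tensor values, not part of the commit — masked positions must not contribute to the average:

```
import torch

# Two real tokens and one padded position; pooling should ignore the padding.
token_embeddings = torch.tensor([[[1.0, 2.0], [3.0, 4.0], [9.0, 9.0]]])  # (1, 3, 2)
attention_mask = torch.tensor([[1, 1, 0]])
mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
pooled = (token_embeddings * mask).sum(1) / mask.sum(1).clamp(min=1e-9)
print(pooled)  # tensor([[2., 3.]]) -- mean of the two real tokens only
```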
@@ -9,7 +9,14 @@ import os
 from pathlib import Path
 
 
+def toTensor(tensor):
+    if not torch.is_tensor(tensor):
+        tensor = torch.tensor(tensor)
+    return tensor
+
+
 def DumpTensorToImage(tensor, name, forceSquare=False, scale=1.0, Contrast=None, GridValue=None, Value2Log=False):
+    tensor = toTensor(tensor)
     if len(tensor.shape) != 2 and len(tensor.shape) != 1 and len(tensor.shape) != 3:
         raise ValueError("Error input dims")
     if ("." not in name) or (name.split(".")[-1] not in {"jpg", "png", "bmp"}):
@@ -20,6 +27,7 @@ def DumpTensorToImage(tensor, name, forceSquare=False, scale=1.0, Contrast=None,
     if len(tensor.shape) == 3:
         channel = tensor.shape[0]
         x = math.ceil((channel) ** 0.5)
+        y = math.ceil(channel / x)
         calc = tensor.reshape((channel, tensor.shape[1] * tensor.shape[2]))
         if not Contrast:
             tensormax = calc.max(1)[0]
@@ -33,11 +41,11 @@ def DumpTensorToImage(tensor, name, forceSquare=False, scale=1.0, Contrast=None,
         calc = calc.reshape((channel, tensor.shape[1], tensor.shape[2]))
         if not GridValue:
             GridValue = 128.0
-        calc = F.pad(calc, (0, 0, 0, 0, 0, x * x - channel), mode="constant", value=GridValue)
+        calc = F.pad(calc, (0, 0, 0, 0, 0, x * y - channel), mode="constant", value=GridValue)
-        calc = calc.reshape((x, x, tensor.shape[1], tensor.shape[2]))
+        calc = calc.reshape((y, x, tensor.shape[1], tensor.shape[2]))
         calc = F.pad(calc, (0, 1, 0, 1, 0, 0), mode="constant", value=GridValue)
         tensor = calc.permute((0, 2, 1, 3))
-        tensor = tensor.reshape((x * tensor.shape[1], x * tensor.shape[3]))
+        tensor = tensor.reshape((y * tensor.shape[1], x * tensor.shape[3]))
         DumpTensorToImage(tensor, name, forceSquare=False, scale=scale, Contrast=[0.0, 255.0], GridValue=GridValue)
         return
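The x/y change above fixes the tile grid for 3-D tensors: x columns from the square-root heuristic, y rows to hold the remaining channels, and non-negative padding up to x * y tiles filled with GridValue. A standalone sketch of that arithmetic (channel counts invented):

```
import math

for channel in (3, 5, 8, 17):
    x = math.ceil(channel ** 0.5)   # grid width
    y = math.ceil(channel / x)      # grid height: enough rows for every channel
    pad = x * y - channel           # blank tiles filled with GridValue
    assert pad >= 0
    print(f"channel={channel:2d} -> grid {y}x{x}, {pad} padded tiles")
```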
@@ -78,6 +86,7 @@ def DumpTensorToImage(tensor, name, forceSquare=False, scale=1.0, Contrast=None,
 
 
 def DumpTensorToLog(tensor, name="log"):
+    tensor = toTensor(tensor)
     tensor_mean = torch.mean(tensor).cpu().detach().numpy()
     tensor_abs_mean = torch.mean(torch.abs(tensor)).cpu().detach().numpy()
     tensor_range = (torch.max(tensor) - torch.min(tensor)).cpu().detach().numpy()
@@ -92,6 +101,7 @@ def DumpTensorToLog(tensor, name="log"):
 
 
 def DumpTensorToFile(tensor, name="tensor.pt"):
+    tensor = toTensor(tensor)
     torch.save(tensor.cpu(), name)
@@ -38,10 +38,10 @@ class ModelConfig:
 
 
 class MeaningDatasetConfig:
     def __init__(self):
-        self.level_ratio = 5
-        self.level = 5
-        self.dataset_level = 3
+        self.start = 10000
+        self.size = 4000
         self.min_subitem = 2
+        self.max_subitem = 10
         self.val_mask_level = None
         self.val_mask_idx = None
@@ -31,12 +31,12 @@ def InitDataset(config):
     if config.dataset.name == "meaning":
         c = config.dataset.meaning
         vocab = config.model_config.vocab_size
-        start = vocab * (c.level_ratio**c.level)
-        size = vocab * int((c.level_ratio**c.dataset_level))
+        start = c.start
+        size = c.size
 
         path = "./data/"
-        trainfile = path + f"MeaningDataset_train_v{size}_s{start}_s{size}_lr{c.level_ratio}_ms{c.min_subitem}.pt"
-        valfile = path + f"MeaningDataset_val_v{size}_s{start}_s{size}_lr{c.level_ratio}_ms{c.min_subitem}.pt"
+        trainfile = path + f"MeaningDataset_train_v{size}_s{start}_s{size}_ms{c.min_subitem}_maxs{c.max_subitem}.pt"
+        valfile = path + f"MeaningDataset_val_v{size}_s{start}_s{size}_ms{c.min_subitem}_maxs{c.max_subitem}.pt"
         if not os.path.exists(path):
            os.mkdir(path)
        if os.path.exists(trainfile) and os.path.exists(valfile):
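For reference, the cache-file name the new template produces, with hypothetical config values (note that the template writes `size` into both the `v` field and the second `s` field):

```
size, start, min_subitem, max_subitem = 4000, 10000, 2, 10
print(f"MeaningDataset_train_v{size}_s{start}_s{size}_ms{min_subitem}_maxs{max_subitem}.pt")
# MeaningDataset_train_v4000_s10000_s4000_ms2_maxs10.pt
```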
@@ -48,7 +48,7 @@ def InitDataset(config):
             val_dataset.set_mask(c.val_mask_level, c.val_mask_idx)
             print(f"INFO: Load dataset end")
         else:
-            raw_dataset = MeaningDataset(start, start + size, vocab, None, c.level_ratio, c.min_subitem)
+            raw_dataset = MeaningDataset(start, start + size, vocab, None, c.max_subitem, c.min_subitem)
             raw_dataset.set_mask(c.val_mask_level, c.val_mask_idx)
             train_dataset, val_dataset = raw_dataset.split(0.9)
             torch.save(train_dataset, trainfile)
@@ -80,11 +80,11 @@ def InitValDataset(config):
     if config.dataset.name == "meaning":
         c = config.dataset.meaning
         vocab = config.model_config.vocab_size
-        start = vocab * (c.level_ratio**c.level)
-        size = vocab * int((c.level_ratio**c.dataset_level))
+        start = c.start
+        size = c.size
 
         path = "./data/"
-        valfile = path + f"MeaningDataset_val_v{size}_s{start}_s{size}_lr{c.level_ratio}_ms{c.min_subitem}.pt"
+        valfile = path + f"MeaningDataset_val_v{size}_s{start}_s{size}_ms{c.min_subitem}_maxs{c.max_subitem}.pt"
         if not os.path.exists(path):
             os.mkdir(path)
         if os.path.exists(valfile):
@@ -93,7 +93,7 @@ def InitValDataset(config):
             val_dataset.set_mask(c.val_mask_level, c.val_mask_idx)
             print(f"INFO: Load dataset end")
         else:
-            raw_dataset = MeaningDataset(start, start + size, vocab, None, c.level_ratio, c.min_subitem)
+            raw_dataset = MeaningDataset(start, start + size, vocab, None, c.max_subitem, c.min_subitem)
             raw_dataset.set_mask(c.val_mask_level, c.val_mask_idx)
             train_dataset, val_dataset = raw_dataset.split(0.9)
             torch.save(val_dataset, valfile)
@@ -105,12 +105,13 @@ class MeaningMap:
             index = index + 1
 
         for i in range(self.vocab_size, size):
-            m = map[i]
+            m = map[i]  # the branches this meaning splits into
             m = m[m >= 0]  # do not cut off the map such as [0]
-            m_len = len(m)
+            m_len = len(m)  # number of branches this meaning splits into
             m_list = m.tolist()
             assert m_list, "map list can not be empty list"
 
+            # Get the start and end of every sub-meaning and stitch them into the complete leaf-index sequence (< vocab_size) of the current meaning
             idx = np.concatenate([np.arange(ms_start[m], ms_end[m]) for m in m_list])
             idxidx = np.concatenate(
                 [np.ones(l, dtype=np.uint32) * i for i, l in enumerate(ms_end[m_list] - ms_start[m_list])]
@@ -118,16 +119,16 @@ class MeaningMap:
             len_ma = len(idx)
 
             end = index + len_ma
-            if ms_data.size < end:
+            if ms_data.size < end:  # beyond the storage capacity: grow every buffer by one datastep
                 ms_data = np.concatenate([ms_data, np.zeros((datastep), dtype=np.int32)])
                 ms_level = np.concatenate([ms_level, np.zeros((datastep), dtype=np.uint32)])
                 ms_rank_idx = np.concatenate([ms_rank_idx, np.zeros((datastep), dtype=np.uint32)])
                 ms_rank_all = np.concatenate([ms_rank_all, np.zeros((datastep), dtype=np.uint32)])
 
-            ms_data[index:end] = ms_data[idx]
+            ms_data[index:end] = ms_data[idx]  # splice all tokens of the current meaning into the data buffer
-            ms_level[index:end] = ms_level[idx] + 1
+            ms_level[index:end] = ms_level[idx] + 1  # update level
-            ms_rank_idx[index:end] = (ms_rank_idx[idx] * 16 + idxidx).astype(np.uint32)
+            ms_rank_idx[index:end] = (ms_rank_idx[idx] * 16 + idxidx).astype(np.uint32)  # update rank_idx
-            ms_rank_all[index:end] = (ms_rank_all[idx] * 16 + m_len).astype(np.uint32)
+            ms_rank_all[index:end] = (ms_rank_all[idx] * 16 + m_len).astype(np.uint32)  # update rank_all
 
             ms_start[i] = index
             ms_end[i] = end
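The splice relies on an append-with-growth buffer pattern; a minimal standalone sketch of it (datastep and values invented, not the repository code):

```
import numpy as np

datastep = 8
ms_data = np.zeros(datastep, dtype=np.int32)
index, new = 6, np.array([1, 2, 3], dtype=np.int32)
end = index + len(new)
if ms_data.size < end:  # beyond the storage capacity: grow by one datastep
    ms_data = np.concatenate([ms_data, np.zeros(datastep, dtype=np.int32)])
ms_data[index:end] = new
print(ms_data.size, ms_data[index:end])  # 16 [1 2 3]
```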
@@ -200,21 +201,47 @@ class MeaningMap:
         root.seq_node = seqlist
         return root
 
-    def get_tree(self, meaning):
-        def get_tree_list(ms_map, meaning, mlist):
-            ms = ms_map[meaning]
-            mlist.append(meaning)
-            mlist.append(-255)  # level down marker
-            for m in ms[ms >= 0].tolist():
-                if m >= self.vocab_size:
-                    get_tree_list(ms_map, m, mlist)
-                else:
-                    mlist.append(m)
-            mlist.append(-1)  # level up marker
-
-        meaninglist = []
-        get_tree_list(self.ms_map, meaning, meaninglist)
-        return meaninglist
+    # Return the level change of each token relative to the previous token:
+    # two lists giving the level distances of the two transitions common -> current -> common
+    def get_level_change(self, meaning):
+        def level_change(ms_map, meaning, current_to_common, common_to_current):
+            ms = ms_map[meaning]
+            for m in ms[ms >= 0].tolist():
+                if m >= self.vocab_size:
+                    common_to_current[-1] = common_to_current[-1] + 1
+                    level_change(ms_map, m, current_to_common, common_to_current)
+                else:
+                    current_to_common.append(0)
+                    common_to_current.append(0)
+            current_to_common[-2] = current_to_common[-2] + 1
+
+        common_to_current = []
+        common_to_current.append(1)
+        current_to_common = []
+        current_to_common.append(0)
+        level_change(self.ms_map, meaning, current_to_common, common_to_current)
+        current_to_common = current_to_common[:-1]
+        common_to_current = common_to_current[:-1]
+        return current_to_common, common_to_current
+
+    # From the meaning's hierarchy, return a 2-D array marking whether each token is related to the tokens before it
+    def get_relation_table(self, meaning):
+        current_to_common, common_to_current = self.get_level_change(meaning)
+        width = len(current_to_common)
+        relation = np.zeros((width, width), dtype=int)
+        relation[0, 0] = 1
+        for i in range(1, width, 1):
+            ori = current_to_common[i] - common_to_current[i]
+            start_index = width
+            for s in range(i - 1, -1, -1):
+                if ori < 0:
+                    break
+                ori = ori - common_to_current[s] + current_to_common[s]
+                start_index = s
+            relation[i, start_index : i + 1] = 1.0
+        return relation
 
     def max_length(self):
         return max(self.ms_len)
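To make the two lists concrete, here is a standalone toy walk of the same recursion — a sketch mirroring the code above, with an invented map where vocab_size=10, meaning 12 -> [11, 3] and meaning 11 -> [1, 2]:

```
ms_map = {12: [11, 3], 11: [1, 2]}
vocab_size = 10

def level_change(meaning, current_to_common, common_to_current):
    for m in ms_map[meaning]:
        if m >= vocab_size:
            common_to_current[-1] += 1  # descend into a sub-meaning
            level_change(m, current_to_common, common_to_current)
        else:
            current_to_common.append(0)  # each leaf token opens a slot in both lists
            common_to_current.append(0)
    current_to_common[-2] += 1  # the subtree's last token ascends one level on exit

common_to_current, current_to_common = [1], [0]
level_change(12, current_to_common, common_to_current)
print(current_to_common[:-1], common_to_current[:-1])  # [0, 1, 1] [2, 0, 0]
# Fed into the relation logic, these lists mark token 2 as related to token 1
# (both under meaning 11) and token 3 as related to all three (shared root 12).
```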
@@ -430,6 +457,7 @@ class BatchGroupMeaningDataloader(Dataset):
             self, batch_size=1, num_workers=num_workers, collate_fn=BatchGroupMeaningDataloader.detection_collate
         )
 
+
 if __name__ == "__main__":
 
     md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, use_cache=False)
@@ -12,13 +12,20 @@ The meaning dataset mimics natural language as well as abstract expression.
 6. level is the distance of the current token from the root meaning
 7. rank
 8. rank_idx is the ordinal of the current token at each level; every 4 bits hold the ordinal within one level, the lowest 4 bits hold the lowest level's rank_idx, and unused high bits are filled with 1
-9. rank_all is the per-level element count of the current token; every 4 bits hold the ordinal within one level, the lowest 4 bits hold the lowest level's rank_all, and unused high bits are filled with 1
+9. rank_all is the total branch count of each level the current token sits in; every 4 bits hold the count within one level, the lowest 4 bits hold the lowest level's rank_all, and unused high bits are filled with 1
 10. tree stores the decomposition data of each meaning, expressing the tree structure as a dict
 11. get_seq_mask returns, for each token of a sequence, whether it sits at the given index of the given level; level=0: the lowest level, index=-1: the last one, index=0: the first one
 12. meaning_height: the total height of the current meaning
 13. meaning_weight: the total width of the current meaning
 14. val_mask_level and val_mask_idx: the mask of tokens used for training; val_mask_level=[0, 1, 2] with val_mask_idx=[0, 0, -1] means that only a token which is the 0th at level 0, the 0th at level 1, and the last at level 2 takes part in training
 
+## code
+
+```
+vocab = config.model_config.vocab_size
+start  first meaning of the dataset samples
+size   number of samples in the dataset
+```
 
 ```
 vocab_size = 256 meaning = 115200
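As an aside on points 8, 9 and 14: the 4-bit packing and the mask test can be sketched in a few lines of standalone Python. This is an illustration of the described layout, not the repository implementation, and all concrete values are invented:

```
# Nibble layout from points 8-9 (invented values).
idx_at0, idx_at1 = 1, 2                  # per-level ordinals of one token
cnt_at0, cnt_at1 = 2, 3                  # per-level branch counts
rank_idx = (idx_at1 << 4) | idx_at0      # low nibble = lowest level -> 0x21 == 33
rank_all = (cnt_at1 << 4) | cnt_at0      # -> 0x32 == 50

def at_level(packed, level):
    return (packed >> (4 * level)) & 0xF  # decode one level's nibble

# Mask test in the spirit of point 14: is this token at `index` of `level`?
def seq_mask(rank_idx, rank_all, level, index):
    cnt = at_level(rank_all, level)
    want = index if index >= 0 else cnt + index
    return at_level(rank_idx, level) == want

print(seq_mask(rank_idx, rank_all, level=0, index=-1))  # True: 1 is the last of 2
```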
@@ -33,11 +40,13 @@ vocab_size = 256 meaning = 115200
   /  \     /  \     /  \     /  \
 176  11  255 129  129  99  211 111
 
 sequence = 123 42 32 176 11 255 129 245 233 129 99 23 211 111 93 176
 level    = 3 3 2 4 4 4 4 2 2 4 4 3 4 4 3 3
-idx at 0 = 0 1 1 0 1 0 1 0 1 0 1 2 0 1 0 1
-idx at 1 = 0 0 0 0 0 1 1 1 1 0 0 0 0 0 2 2
-idx      = 0 1 1 0 1 16 17 16 17 0 1 2 0 1 32 33
+
+rank_idx = 0 1 1 0 1 16 17 16 17 0 1 2 0 1 32 33
+idx at0  = 0 1 1 0 1 0 1 0 1 0 1 2 0 1 0 1
+idx at1  = 0 0 0 0 0 1 1 1 1 0 0 0 0 0 2 2
+
+rank_all =
 ```
Binary file not shown. After: 3.8 KiB (new image)
@@ -10,4 +10,11 @@
 ## Effect of different model depths on the results
 
 The reason 6 layers give no improvement over 3 layers may be that the dataset is too small, and 3 layers can already fit it completely
 ![vocab_ratio_level_loss](vocab_ratio_level_loss.png)
+
+## Interpreting the qk image
+
+1. key[10] = 1000.0
+2. Each row of data (pixels) represents one new token and its relation to all preceding tokens
+
+![q@k](q@k.png)
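Given point 2, each pixel row of the image can be compared against the matching row of get_relation_table. A small hypothetical reading aid (the 3x3 table is invented):

```
import numpy as np

relation = np.array([[1, 0, 0],
                     [1, 1, 0],
                     [1, 1, 1]])
for i, row in enumerate(relation):
    related = np.nonzero(row[: i + 1])[0].tolist()
    print(f"token {i} should attend to tokens {related}")
```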
@@ -8,7 +8,9 @@ import dataset.dataset as ds
 
 if __name__ == "__main__":
 
-    checkpoint_path = "log/bigger/version_0/checkpoints/epoch=72-step=360328.ckpt"
+    # checkpoint_path = "log/bigger/version_0/checkpoints/epoch=72-step=360328.ckpt"
+    # checkpoint_path = "log/bigger/version_4/checkpoints/epoch=81-step=64288.ckpt"
+    checkpoint_path = "log/bigger/version_8/checkpoints/epoch=14-step=67455.ckpt"
 
     qwen = LightModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
     qwen.eval()
@@ -19,18 +21,18 @@ if __name__ == "__main__":
 
     runner = ModelRunner(qwen.llm)
 
-    # batch = torch.tensor([[11, 0, 3, 7, 15, 8, 10, 7, 14, 13, 1, 12, 13]], dtype=torch.int64)
-    # sorted_logits, sorted_indices = runner.ChatTokens(batch, sample=False)
-    # print(sorted_logits.detach().cpu().numpy())
-    # print(sorted_indices.detach().cpu().numpy())
 
     val = ds.InitValDataset(conf).dataset
     md = val.meaning_dataset
     map = md.get_meaning_map()
-    item = md.get_token(0)
 
-    node = map.get_nodetree(md.get_meaning(0))
-    # node.print()
+    # seq:844
+    # seq:849
+    # seq:991
+    # seq:995
+
+    node = map.get_nodetree(995)
+    item, l, rank_idx, rank_all = map.get_sequence(995)
+    print("len of seq:" + str(len(item)))
 
     for i in range(1, len(item)):
         itemm = [item[:i]]
@@ -15,7 +15,7 @@ import dataset.dataset as ds
 
 if __name__ == "__main__":
 
-    checkpoint_path = "log/bigger/version_0/checkpoints/epoch=72-step=360328.ckpt"
+    checkpoint_path = "log/bigger/version_8/checkpoints/epoch=14-step=67455.ckpt"
 
     qwen = LightModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
     qwen.eval()
@@ -25,6 +25,7 @@ if __name__ == "__main__":
     runner = ModelRunner(qwen.llm)
 
     def DumpQK(query, key, causal_mask, index):
+        global relation_table
         size = query.shape[2]
         scale_factor = 1 / math.sqrt(query.size(-1))
         attn_weight = query @ key.transpose(-2, -1) * scale_factor
@@ -35,28 +36,44 @@ if __name__ == "__main__":
         attn_weight = attn_weight * attn_mask
         qk = attn_weight[0]
         prePath = "./temp/" + "q@k_seq_" + str(size) + "_layer_" + str(index) + ".png"
+        qk = qk.cpu()
+        qk = torch.cat((qk, relation_table.unsqueeze(0)), dim=0)
         show.DumpTensorToImage(qk, prePath, GridValue=255)
         # qk_seq.append(qk)
         # qk_index = size
 
     qwen.llm.hook_attention = DumpQK
 
     val = ds.InitValDataset(conf).dataset
     md = val.meaning_dataset
     map = md.get_meaning_map()
-    item = md.get_token(0)
 
-    node = map.get_nodetree(md.get_meaning(0))
-    # node.print()
+    # seq:844
+    # seq:849
+    # seq:991
+    # seq:995
+
+    meaning = 995
+    node = map.get_nodetree(meaning)
+    current_to_common, common_to_current = map.get_level_change(meaning)
+
+    node.print()
+    print(current_to_common)
+    print(common_to_current)
+    relation_table = map.get_relation_table(meaning)
+    # prePath = "./temp/" + "q@k_seq_" + "_layer_" + ".png"
+    # show.DumpTensorToImage(relation_table, prePath, GridValue=255)
+    relation_table = torch.tensor(relation_table)
+
+    item, level, rank_idx, rank_all = map.get_sequence(meaning)
+    print(item)
+    print(level)
+    print(rank_idx)
+    print(rank_all)
+    print("len of seq:" + str(len(item)))
 
     batch = torch.tensor([item], dtype=torch.int64)
     sorted_logits, sorted_indices = runner.ChatTokens(batch, sample=False)
     next_token = sorted_indices.detach().cpu().numpy()[0][0]
-    node.print()
 
 
     # batch = torch.tensor([[11, 0, 3, 7, 15, 8, 10, 7]], dtype=torch.int64)
     # sorted_logits, sorted_indices = runner.ChatTokens(batch, sample=False)
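The DumpQK change appends the dataset's relation table as one extra channel, so the dumped grid shows the ideal relation pattern as an additional tile next to the per-head attention maps. A hedged shape sketch (head count and sizes invented):

```
import torch

heads, seq = 4, 16
qk = torch.rand(heads, seq, seq)                      # per-head attention maps
relation_table = torch.randint(0, 2, (seq, seq)).float()
qk = torch.cat((qk, relation_table.unsqueeze(0)), dim=0)
print(qk.shape)  # torch.Size([5, 16, 16]) -- heads + 1 tiles in the image grid
```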
@@ -25,13 +25,13 @@ if __name__ == "__main__":
     loader = train_dataloader.dataset
 
     map = loader.meaning_dataset.get_meaning_map()
-    trees = {}
+    seqs = {}
     for batch in loader:
         for m in batch["meaning"]:
-            trees[m] = map.get_tree(m)
+            seqs[m] = map.get_sequence(m)
     while True:
         m = int(input("input meaning: "))
         total = 0
-        for tree in trees.values():
-            total = total + tree.count(m)
+        for seq in seqs.values():
+            total = total + seq.count(m)
         print(f"meaning of {m} count as {total}")
wit/train.py (16 lines changed)
@@ -20,7 +20,7 @@ if __name__ == "__main__":
     conf.learning_rate = 0.001
     conf.use_tril_attention_mask = None
     conf.precision = "16-mixed"  # "precision:bf16-mixed,16-mixed,32-true"
-    conf.train_batch_size = 16
+    conf.train_batch_size = 32
     conf.val_batch_size = 2
     conf.num_proc = 8
     conf.max_epochs = 1000
@@ -29,18 +29,18 @@ if __name__ == "__main__":
     conf.seed = 42
     conf.dataloader_works = 2
 
-    conf.dataset.meaning.level_ratio = 5
-    conf.dataset.meaning.level = 2
-    conf.dataset.meaning.dataset_level = 5
+    conf.dataset.meaning.start = 800
+    conf.dataset.meaning.size = 200000
     conf.dataset.meaning.min_subitem = 2
+    conf.dataset.meaning.max_subitem = 4
     conf.dataset.meaning.val_mask_level = [0, 1, 2]
     conf.dataset.meaning.val_mask_idx = [0, 0, -1]
 
     config.vocab_size = 32
-    config.hidden_size = 128  # 128 1024 2048 32
+    config.hidden_size = 256  # 128 1024 2048 32
-    config.intermediate_size = 256
+    config.intermediate_size = 512
-    config.num_hidden_layers = 3  # 6 12 24 3
+    config.num_hidden_layers = 4  # 6 12 24 3
-    config.num_attention_heads = 8  # 8 8 16
+    config.num_attention_heads = 4  # 8 8 16
 
     torch.manual_seed(conf.seed)
     np.random.seed(conf.seed)