Add meaning dataset get_relation_table support and paint to qk image.
This commit is contained in:
parent
d8539b6b2b
commit
927c98e823
|
@ -0,0 +1,34 @@
|
|||
from transformers import AutoTokenizer, AutoModel
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
#Mean Pooling - Take attention mask into account for correct averaging
|
||||
def mean_pooling(model_output, attention_mask):
|
||||
token_embeddings = model_output[0] #First element of model_output contains all token embeddings
|
||||
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
|
||||
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
|
||||
|
||||
|
||||
# Sentences we want sentence embeddings for
|
||||
sentences = ['This is an example sentence', 'Each sentence is converted']
|
||||
|
||||
# Load model from HuggingFace Hub
|
||||
tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
|
||||
model = AutoModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
|
||||
|
||||
# Tokenize sentences
|
||||
encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
|
||||
|
||||
# Compute token embeddings
|
||||
with torch.no_grad():
|
||||
model_output = model(**encoded_input)
|
||||
|
||||
# Perform pooling
|
||||
sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
|
||||
|
||||
# Normalize embeddings
|
||||
sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
|
||||
|
||||
print("Sentence embeddings:")
|
||||
print(sentence_embeddings)
|
||||
print(sentence_embeddings.cpu().numpy())
|
|
@ -9,7 +9,14 @@ import os
|
|||
from pathlib import Path
|
||||
|
||||
|
||||
def toTensor(tensor):
|
||||
if not torch.is_tensor(tensor):
|
||||
tensor = torch.tensor(tensor)
|
||||
return tensor
|
||||
|
||||
|
||||
def DumpTensorToImage(tensor, name, forceSquare=False, scale=1.0, Contrast=None, GridValue=None, Value2Log=False):
|
||||
tensor = toTensor(tensor)
|
||||
if len(tensor.shape) != 2 and len(tensor.shape) != 1 and len(tensor.shape) != 3:
|
||||
raise ("Error input dims")
|
||||
if ("." not in name) or (name.split(".")[-1] not in {"jpg", "png", "bmp"}):
|
||||
|
@ -20,6 +27,7 @@ def DumpTensorToImage(tensor, name, forceSquare=False, scale=1.0, Contrast=None,
|
|||
if len(tensor.shape) == 3:
|
||||
channel = tensor.shape[0]
|
||||
x = math.ceil((channel) ** 0.5)
|
||||
y = math.ceil((x * x) / channel)
|
||||
calc = tensor.reshape((channel, tensor.shape[1] * tensor.shape[2]))
|
||||
if not Contrast:
|
||||
tensormax = calc.max(1)[0]
|
||||
|
@ -33,11 +41,11 @@ def DumpTensorToImage(tensor, name, forceSquare=False, scale=1.0, Contrast=None,
|
|||
calc = calc.reshape((channel, tensor.shape[1], tensor.shape[2]))
|
||||
if not GridValue:
|
||||
GridValue = 128.0
|
||||
calc = F.pad(calc, (0, 0, 0, 0, 0, x * x - channel), mode="constant", value=GridValue)
|
||||
calc = calc.reshape((x, x, tensor.shape[1], tensor.shape[2]))
|
||||
calc = F.pad(calc, (0, 0, 0, 0, 0, x * y - channel), mode="constant", value=GridValue)
|
||||
calc = calc.reshape((y, x, tensor.shape[1], tensor.shape[2]))
|
||||
calc = F.pad(calc, (0, 1, 0, 1, 0, 0), mode="constant", value=GridValue)
|
||||
tensor = calc.permute((0, 2, 1, 3))
|
||||
tensor = tensor.reshape((x * tensor.shape[1], x * tensor.shape[3]))
|
||||
tensor = tensor.reshape((y * tensor.shape[1], x * tensor.shape[3]))
|
||||
DumpTensorToImage(tensor, name, forceSquare=False, scale=scale, Contrast=[0.0, 255.0], GridValue=GridValue)
|
||||
return
|
||||
|
||||
|
@ -78,6 +86,7 @@ def DumpTensorToImage(tensor, name, forceSquare=False, scale=1.0, Contrast=None,
|
|||
|
||||
|
||||
def DumpTensorToLog(tensor, name="log"):
|
||||
tensor = toTensor(tensor)
|
||||
tensor_mean = torch.mean(tensor).cpu().detach().numpy()
|
||||
tensor_abs_mean = torch.mean(torch.abs(tensor)).cpu().detach().numpy()
|
||||
tensor_range = (torch.max(tensor) - torch.min(tensor)).cpu().detach().numpy()
|
||||
|
@ -92,6 +101,7 @@ def DumpTensorToLog(tensor, name="log"):
|
|||
|
||||
|
||||
def DumpTensorToFile(tensor, name="tensor.pt"):
|
||||
tensor = toTensor(tensor)
|
||||
torch.save(tensor.cpu(), name)
|
||||
|
||||
|
||||
|
|
|
@ -38,10 +38,10 @@ class ModelConfig:
|
|||
|
||||
class MeaningDatasetConfig:
|
||||
def __init__(self):
|
||||
self.level_ratio = 5
|
||||
self.level = 5
|
||||
self.dataset_level = 3
|
||||
self.start = 10000
|
||||
self.size = 4000
|
||||
self.min_subitem = 2
|
||||
self.max_subitem = 10
|
||||
self.val_mask_level = None
|
||||
self.val_mask_idx = None
|
||||
|
||||
|
|
|
@ -31,12 +31,12 @@ def InitDataset(config):
|
|||
if config.dataset.name == "meaning":
|
||||
c = config.dataset.meaning
|
||||
vocab = config.model_config.vocab_size
|
||||
start = vocab * (c.level_ratio**c.level)
|
||||
size = vocab * int((c.level_ratio**c.dataset_level))
|
||||
start = c.start
|
||||
size = c.size
|
||||
|
||||
path = "./data/"
|
||||
trainfile = path + f"MeaningDataset_train_v{size}_s{start}_s{size}_lr{c.level_ratio}_ms{c.min_subitem}.pt"
|
||||
valfile = path + f"MeaningDataset_val_v{size}_s{start}_s{size}_lr{c.level_ratio}_ms{c.min_subitem}.pt"
|
||||
trainfile = path + f"MeaningDataset_train_v{size}_s{start}_s{size}_ms{c.min_subitem}_maxs{c.max_subitem}.pt"
|
||||
valfile = path + f"MeaningDataset_val_v{size}_s{start}_s{size}_ms{c.min_subitem}_maxs{c.max_subitem}.pt"
|
||||
if not os.path.exists(path):
|
||||
os.mkdir(path)
|
||||
if os.path.exists(trainfile) and os.path.exists(valfile):
|
||||
|
@ -48,7 +48,7 @@ def InitDataset(config):
|
|||
val_dataset.set_mask(c.val_mask_level, c.val_mask_idx)
|
||||
print(f"INFO: Load dataset end")
|
||||
else:
|
||||
raw_dataset = MeaningDataset(start, start + size, vocab, None, c.level_ratio, c.min_subitem)
|
||||
raw_dataset = MeaningDataset(start, start + size, vocab, None, c.max_subitem, c.min_subitem)
|
||||
raw_dataset.set_mask(c.val_mask_level, c.val_mask_idx)
|
||||
train_dataset, val_dataset = raw_dataset.split(0.9)
|
||||
torch.save(train_dataset, trainfile)
|
||||
|
@ -80,11 +80,11 @@ def InitValDataset(config):
|
|||
if config.dataset.name == "meaning":
|
||||
c = config.dataset.meaning
|
||||
vocab = config.model_config.vocab_size
|
||||
start = vocab * (c.level_ratio**c.level)
|
||||
size = vocab * int((c.level_ratio**c.dataset_level))
|
||||
start = c.start
|
||||
size = c.size
|
||||
|
||||
path = "./data/"
|
||||
valfile = path + f"MeaningDataset_val_v{size}_s{start}_s{size}_lr{c.level_ratio}_ms{c.min_subitem}.pt"
|
||||
valfile = path + f"MeaningDataset_val_v{size}_s{start}_s{size}_ms{c.min_subitem}_maxs{c.max_subitem}.pt"
|
||||
if not os.path.exists(path):
|
||||
os.mkdir(path)
|
||||
if os.path.exists(valfile):
|
||||
|
@ -93,7 +93,7 @@ def InitValDataset(config):
|
|||
val_dataset.set_mask(c.val_mask_level, c.val_mask_idx)
|
||||
print(f"INFO: Load dataset end")
|
||||
else:
|
||||
raw_dataset = MeaningDataset(start, start + size, vocab, None, c.level_ratio, c.min_subitem)
|
||||
raw_dataset = MeaningDataset(start, start + size, vocab, None, c.max_subitem, c.min_subitem)
|
||||
raw_dataset.set_mask(c.val_mask_level, c.val_mask_idx)
|
||||
train_dataset, val_dataset = raw_dataset.split(0.9)
|
||||
torch.save(val_dataset, valfile)
|
||||
|
|
|
@ -105,12 +105,13 @@ class MeaningMap:
|
|||
index = index + 1
|
||||
|
||||
for i in range(self.vocab_size, size):
|
||||
m = map[i]
|
||||
m = map[i] # 当前meaning的拆分的分支
|
||||
m = m[m >= 0] # donot cut off the map such as [0]
|
||||
m_len = len(m)
|
||||
m_len = len(m) # 当前meaning的拆分的分支个数
|
||||
m_list = m.tolist()
|
||||
assert m_list, "map list can not be empty list"
|
||||
|
||||
# 获取每个子meaning的start和end,并且生成序列组合成当前meaning完整的叶index(<vocab_size)
|
||||
idx = np.concatenate([np.arange(ms_start[m], ms_end[m]) for m in m_list])
|
||||
idxidx = np.concatenate(
|
||||
[np.ones(l, dtype=np.uint32) * i for i, l in enumerate(ms_end[m_list] - ms_start[m_list])]
|
||||
|
@ -118,16 +119,16 @@ class MeaningMap:
|
|||
len_ma = len(idx)
|
||||
|
||||
end = index + len_ma
|
||||
if ms_data.size < end:
|
||||
if ms_data.size < end: # 超过存储数据结构的大小,扩展一个datastep容量
|
||||
ms_data = np.concatenate([ms_data, np.zeros((datastep), dtype=np.int32)])
|
||||
ms_level = np.concatenate([ms_level, np.zeros((datastep), dtype=np.uint32)])
|
||||
ms_rank_idx = np.concatenate([ms_rank_idx, np.zeros((datastep), dtype=np.uint32)])
|
||||
ms_rank_all = np.concatenate([ms_rank_all, np.zeros((datastep), dtype=np.uint32)])
|
||||
|
||||
ms_data[index:end] = ms_data[idx]
|
||||
ms_level[index:end] = ms_level[idx] + 1
|
||||
ms_rank_idx[index:end] = (ms_rank_idx[idx] * 16 + idxidx).astype(np.uint32)
|
||||
ms_rank_all[index:end] = (ms_rank_all[idx] * 16 + m_len).astype(np.uint32)
|
||||
ms_data[index:end] = ms_data[idx] # 拼接当前meaning的所有token到data数据结构里面
|
||||
ms_level[index:end] = ms_level[idx] + 1 # 处理level
|
||||
ms_rank_idx[index:end] = (ms_rank_idx[idx] * 16 + idxidx).astype(np.uint32) # 处理rank_idx
|
||||
ms_rank_all[index:end] = (ms_rank_all[idx] * 16 + m_len).astype(np.uint32) # 处理rank_all
|
||||
|
||||
ms_start[i] = index
|
||||
ms_end[i] = end
|
||||
|
@ -200,21 +201,47 @@ class MeaningMap:
|
|||
root.seq_node = seqlist
|
||||
return root
|
||||
|
||||
def get_tree(self, meaning):
|
||||
def get_tree_list(ms_map, meaning, mlist):
|
||||
# 返回每个token相对于上一个token的level变化
|
||||
# 返回两个list,分别表示 common -> current -> common 两个变化的level距离
|
||||
def get_level_change(self, meaning):
|
||||
def level_change(ms_map, meaning, current_to_common, common_to_current):
|
||||
ms = ms_map[meaning]
|
||||
mlist.append(meaning)
|
||||
mlist.append(-255) # level down marker
|
||||
for m in ms[ms >= 0].tolist():
|
||||
if m >= self.vocab_size:
|
||||
get_tree_list(ms_map, m, mlist)
|
||||
common_to_current[-1] = common_to_current[-1] + 1
|
||||
level_change(ms_map, m, current_to_common, common_to_current)
|
||||
else:
|
||||
mlist.append(m)
|
||||
mlist.append(-1) # level up marker
|
||||
current_to_common.append(0)
|
||||
common_to_current.append(0)
|
||||
current_to_common[-2] = current_to_common[-2] + 1
|
||||
|
||||
meaninglist = []
|
||||
get_tree_list(self.ms_map, meaning, meaninglist)
|
||||
return meaninglist
|
||||
common_to_current = []
|
||||
common_to_current.append(1)
|
||||
current_to_common = []
|
||||
current_to_common.append(0)
|
||||
level_change(self.ms_map, meaning, current_to_common, common_to_current)
|
||||
current_to_common = current_to_common[:-1]
|
||||
common_to_current = common_to_current[:-1]
|
||||
return current_to_common, common_to_current
|
||||
|
||||
# 根据meaning的层级结构范围一个二位的数组,表示所有token跟前面token是否有关系
|
||||
def get_relation_table(self, meaning):
|
||||
current_to_common, common_to_current = self.get_level_change(meaning)
|
||||
width = len(current_to_common)
|
||||
relation = np.zeros((width, width), dtype=int)
|
||||
relation[0, 0] = 1
|
||||
for i in range(1, width, 1):
|
||||
if i == width - 2:
|
||||
print(1)
|
||||
ori = current_to_common[i] - common_to_current[i]
|
||||
start_index = width
|
||||
for s in range(i - 1, -1, -1):
|
||||
if ori < 0:
|
||||
break
|
||||
ori = ori - common_to_current[s] + current_to_common[s]
|
||||
start_index = s
|
||||
relation[i, start_index : i + 1] = 1.0
|
||||
return relation
|
||||
|
||||
def max_length(self):
|
||||
return max(self.ms_len)
|
||||
|
@ -430,6 +457,7 @@ class BatchGroupMeaningDataloader(Dataset):
|
|||
self, batch_size=1, num_workers=num_workers, collate_fn=BatchGroupMeaningDataloader.detection_collate
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, use_cache=False)
|
||||
|
|
|
@ -12,13 +12,20 @@ meaning数据集是一个模仿自然语言,以及抽象表达的数据集。
|
|||
6. level表示当前token相对于root meaning的距离
|
||||
7. rank
|
||||
8. rank_idx表示当前token在不同层的排序编号,每4位表示在一层里面的编号,低4位表示最低层级的rank_idx,高位无用的位用1填充
|
||||
9. rank_all表示当前token在不同层的分子个数,每4位表示在一层里面的编号,低4位表示最低层级的rank_all,高位无用的位用1填充
|
||||
9. rank_all表示当前token所在的不同层的总的分支个数,每4位表示在一层里面的个数,低4位表示最低层级的rank_all,高位无用的位用1填充
|
||||
10. tree用于存储每个meaning的拆解的数据,使用字典表达一个树形结构
|
||||
11. get_seq_mask返回一个sequence每个token在对应level是不是对应的index,level=0:最底层,index=-1:最后一个,index=0:第一个
|
||||
12. meaning_height 当前meaning的总高度
|
||||
13. meaning_weight 当前meaning的总宽度
|
||||
14. val_mask_level val_mask_idx: 表示用于训练的token的mask,val_mask_level=[0, 1, 2] val_mask_idx=[0, 0, -1]表示只有是第0层第0个,而且是第1层第0个,第2层最后一个的token,才参与训练
|
||||
|
||||
## code
|
||||
|
||||
```
|
||||
vocab = config.model_config.vocab_size
|
||||
start 数据集的样本开始meaning
|
||||
size 数据集的样本个数
|
||||
```
|
||||
|
||||
```
|
||||
vocab_size = 256 meaning = 115200
|
||||
|
@ -33,11 +40,13 @@ vocab_size = 256 meaning = 115200
|
|||
/ \ / \ / \ / \
|
||||
176 11 255 129 129 99 211 111
|
||||
|
||||
sequence = 123 42 32 176 11 255 129 245 233 129 99 23 211 111 93 176
|
||||
level = 3 3 2 4 4 4 4 2 2 4 4 3 4 4 3 3
|
||||
idx at 0 = 0 1 1 0 1 0 1 0 1 0 1 2 0 1 0 1
|
||||
idx at 1 = 0 0 0 0 0 1 1 1 1 0 0 0 0 0 2 2
|
||||
idx 0 1 1 0 1 16 17 16 17 0 1 2 0 1 32 33
|
||||
sequence = 123 42 32 176 11 255 129 245 233 129 99 23 211 111 93 176
|
||||
level = 3 3 2 4 4 4 4 2 2 4 4 3 4 4 3 3
|
||||
|
||||
rank_idx = 0 1 1 0 1 16 17 16 17 0 1 2 0 1 32 33
|
||||
idx at0 = 0 1 1 0 1 0 1 0 1 0 1 2 0 1 0 1
|
||||
idx at1 = 0 0 0 0 0 1 1 1 1 0 0 0 0 0 2 2
|
||||
|
||||
rank_all =
|
||||
|
||||
```
|
||||
|
|
Binary file not shown.
After Width: | Height: | Size: 3.8 KiB |
|
@ -10,4 +10,11 @@
|
|||
## 不同模型深度对结果的影响
|
||||
|
||||
6层相对于3层没有提升的原因,可能是数据集太小,3层已经能完全拟合
|
||||

|
||||

|
||||
|
||||
## qk图解释
|
||||
|
||||
1. key[10] = 1000.0
|
||||
2. 每一行数据(像素)表示一个新的token,和前面所有token的关系
|
||||
|
||||

|
|
@ -8,7 +8,9 @@ import dataset.dataset as ds
|
|||
|
||||
if __name__ == "__main__":
|
||||
|
||||
checkpoint_path = "log/bigger/version_0/checkpoints/epoch=72-step=360328.ckpt"
|
||||
# checkpoint_path = "log/bigger/version_0/checkpoints/epoch=72-step=360328.ckpt"
|
||||
# checkpoint_path = "log/bigger/version_4/checkpoints/epoch=81-step=64288.ckpt"
|
||||
checkpoint_path = "log/bigger/version_8/checkpoints/epoch=14-step=67455.ckpt"
|
||||
|
||||
qwen = LightModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
|
||||
qwen.eval()
|
||||
|
@ -19,18 +21,18 @@ if __name__ == "__main__":
|
|||
|
||||
runner = ModelRunner(qwen.llm)
|
||||
|
||||
# batch = torch.tensor([[11, 0, 3, 7, 15, 8, 10, 7, 14, 13, 1, 12, 13]], dtype=torch.int64)
|
||||
# sorted_logits, sorted_indices = runner.ChatTokens(batch, sample=False)
|
||||
# print(sorted_logits.detach().cpu().numpy())
|
||||
# print(sorted_indices.detach().cpu().numpy())
|
||||
|
||||
val = ds.InitValDataset(conf).dataset
|
||||
md = val.meaning_dataset
|
||||
map = md.get_meaning_map()
|
||||
item = md.get_token(0)
|
||||
|
||||
node = map.get_nodetree(md.get_meaning(0))
|
||||
# node.print()
|
||||
# seq:844
|
||||
# seq:849
|
||||
# seq:991
|
||||
# seq:995
|
||||
|
||||
node = map.get_nodetree(995)
|
||||
item, l, rank_idx, rank_all = map.get_sequence(995)
|
||||
print("len of seq:" + str(len(item)))
|
||||
|
||||
for i in range(1, len(item)):
|
||||
itemm = [item[:i]]
|
||||
|
|
|
@ -15,7 +15,7 @@ import dataset.dataset as ds
|
|||
|
||||
if __name__ == "__main__":
|
||||
|
||||
checkpoint_path = "log/bigger/version_0/checkpoints/epoch=72-step=360328.ckpt"
|
||||
checkpoint_path = "log/bigger/version_8/checkpoints/epoch=14-step=67455.ckpt"
|
||||
|
||||
qwen = LightModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
|
||||
qwen.eval()
|
||||
|
@ -25,6 +25,7 @@ if __name__ == "__main__":
|
|||
runner = ModelRunner(qwen.llm)
|
||||
|
||||
def DumpQK(query, key, causal_mask, index):
|
||||
global relation_table
|
||||
size = query.shape[2]
|
||||
scale_factor = 1 / math.sqrt(query.size(-1))
|
||||
attn_weight = query @ key.transpose(-2, -1) * scale_factor
|
||||
|
@ -35,28 +36,44 @@ if __name__ == "__main__":
|
|||
attn_weight = attn_weight * attn_mask
|
||||
qk = attn_weight[0]
|
||||
prePath = "./temp/" + "q@k_seq_" + str(size) + "_layer_" + str(index) + ".png"
|
||||
qk = qk.cpu()
|
||||
qk = torch.cat((qk, relation_table.unsqueeze(0)), dim=0)
|
||||
show.DumpTensorToImage(qk, prePath, GridValue=255)
|
||||
# qk_seq.append(qk)
|
||||
# qk_index = size
|
||||
|
||||
qwen.llm.hook_attention = DumpQK
|
||||
|
||||
|
||||
val = ds.InitValDataset(conf).dataset
|
||||
md = val.meaning_dataset
|
||||
map = md.get_meaning_map()
|
||||
item = md.get_token(0)
|
||||
|
||||
node = map.get_nodetree(md.get_meaning(0))
|
||||
# node.print()
|
||||
# seq:844
|
||||
# seq:849
|
||||
# seq:991
|
||||
# seq:995
|
||||
meaning = 995
|
||||
node = map.get_nodetree(meaning)
|
||||
current_to_common, common_to_current = map.get_level_change(meaning)
|
||||
|
||||
node.print()
|
||||
print(current_to_common)
|
||||
print(common_to_current)
|
||||
relation_table = map.get_relation_table(meaning)
|
||||
# prePath = "./temp/" + "q@k_seq_" + "_layer_" + ".png"
|
||||
# show.DumpTensorToImage(relation_table, prePath, GridValue=255)
|
||||
relation_table = torch.tensor(relation_table)
|
||||
|
||||
item, level, rank_idx, rank_all = map.get_sequence(meaning)
|
||||
print(item)
|
||||
print(level)
|
||||
print(rank_idx)
|
||||
print(rank_all)
|
||||
print("len of seq:" + str(len(item)))
|
||||
|
||||
batch = torch.tensor([item], dtype=torch.int64)
|
||||
sorted_logits, sorted_indices = runner.ChatTokens(batch, sample=False)
|
||||
next_token = sorted_indices.detach().cpu().numpy()[0][0]
|
||||
node.print()
|
||||
|
||||
|
||||
|
||||
|
||||
# batch = torch.tensor([[11, 0, 3, 7, 15, 8, 10, 7]], dtype=torch.int64)
|
||||
# sorted_logits, sorted_indices = runner.ChatTokens(batch, sample=False)
|
||||
|
|
|
@ -25,13 +25,13 @@ if __name__ == "__main__":
|
|||
loader = train_dataloader.dataset
|
||||
|
||||
map = loader.meaning_dataset.get_meaning_map()
|
||||
trees = {}
|
||||
seqs = {}
|
||||
for batch in loader:
|
||||
for m in batch["meaning"]:
|
||||
trees[m] = map.get_tree(m)
|
||||
seqs[m] = map.get_sequence(m)
|
||||
while True:
|
||||
m = int(input("input meaning: "))
|
||||
total = 0
|
||||
for tree in trees.values():
|
||||
total = total + tree.count(m)
|
||||
for seq in seqs.values():
|
||||
total = total + seq.count(m)
|
||||
print(f"meaning of {m} count as {total}")
|
||||
|
|
16
wit/train.py
16
wit/train.py
|
@ -20,7 +20,7 @@ if __name__ == "__main__":
|
|||
conf.learning_rate = 0.001
|
||||
conf.use_tril_attention_mask = None
|
||||
conf.precision = "16-mixed" # "precision:bf16-mixed,16-mixed,32-true"
|
||||
conf.train_batch_size = 16
|
||||
conf.train_batch_size = 32
|
||||
conf.val_batch_size = 2
|
||||
conf.num_proc = 8
|
||||
conf.max_epochs = 1000
|
||||
|
@ -29,18 +29,18 @@ if __name__ == "__main__":
|
|||
conf.seed = 42
|
||||
conf.dataloader_works = 2
|
||||
|
||||
conf.dataset.meaning.level_ratio = 5
|
||||
conf.dataset.meaning.level = 2
|
||||
conf.dataset.meaning.dataset_level = 5
|
||||
conf.dataset.meaning.start = 800
|
||||
conf.dataset.meaning.size = 200000
|
||||
conf.dataset.meaning.min_subitem = 2
|
||||
conf.dataset.meaning.max_subitem = 4
|
||||
conf.dataset.meaning.val_mask_level = [0, 1, 2]
|
||||
conf.dataset.meaning.val_mask_idx = [0, 0, -1]
|
||||
|
||||
config.vocab_size = 32
|
||||
config.hidden_size = 128 # 128 1024 2048 32
|
||||
config.intermediate_size = 256
|
||||
config.num_hidden_layers = 3 # 6 12 24 3
|
||||
config.num_attention_heads = 8 # 8 8 16
|
||||
config.hidden_size = 256 # 128 1024 2048 32
|
||||
config.intermediate_size = 512
|
||||
config.num_hidden_layers = 4 # 6 12 24 3
|
||||
config.num_attention_heads = 4 # 8 8 16
|
||||
|
||||
torch.manual_seed(conf.seed)
|
||||
np.random.seed(conf.seed)
|
||||
|
|
Loading…
Reference in New Issue