Fix rank index and all calculater when tree mode enable.
This commit is contained in:
parent
47ba18a4f9
commit
a983a2b6e6
|
@ -11,8 +11,8 @@ meaning数据集是一个模仿自然语言,以及抽象表达的数据集。
|
||||||
5. meaning通过一层层的向低编号的meaning进行组合替换,最终形成一个最底层是token的树形数据
|
5. meaning通过一层层的向低编号的meaning进行组合替换,最终形成一个最底层是token的树形数据
|
||||||
6. level表示当前token相对于root meaning的距离
|
6. level表示当前token相对于root meaning的距离
|
||||||
7. rank
|
7. rank
|
||||||
8. rank_idx表示当前token在不同层的排序编号,每4位表示在一层里面的编号,低4位表示最低层级的rank_idx,高位无用的位用1填充
|
8. rank_idx表示当前token在不同层的排序编号,每4位表示在一层里面的编号,低4位表示最低层级的rank_idx,高位无用的位用1填充,tree token总是排在最后面,index递增
|
||||||
9. rank_all表示当前token所在的不同层的总的分支个数,每4位表示在一层里面的个数,低4位表示最低层级的rank_all,高位无用的位用1填充
|
9. rank_all表示当前token所在的不同层的总的分支个数,每4位表示在一层里面的个数,低4位表示最低层级的rank_all,高位无用的位用1填充,tree token不计入rank all的总数
|
||||||
10. tree用于存储每个meaning的拆解的数据,使用字典表达一个树形结构
|
10. tree用于存储每个meaning的拆解的数据,使用字典表达一个树形结构
|
||||||
11. get_rank_mask返回一个sequence每个token在对应level是不是对应的index,level=0:最底层,index=-1:最后一个,index=0:第一个
|
11. get_rank_mask返回一个sequence每个token在对应level是不是对应的index,level=0:最底层,index=-1:最后一个,index=0:第一个
|
||||||
12. meaning_height 当前meaning的总高度
|
12. meaning_height 当前meaning的总高度
|
||||||
|
|
|
@ -139,8 +139,6 @@ class MeaningMap:
|
||||||
m = map[i] # 当前meaning的拆分的分支
|
m = map[i] # 当前meaning的拆分的分支
|
||||||
m = m[m >= 0] # donot cut off the map such as [0]
|
m = m[m >= 0] # donot cut off the map such as [0]
|
||||||
m_len = len(m) # 当前meaning的拆分的分支个数
|
m_len = len(m) # 当前meaning的拆分的分支个数
|
||||||
if self.with_tree:
|
|
||||||
m_len = m_len + 1
|
|
||||||
m_list = m.tolist()
|
m_list = m.tolist()
|
||||||
assert m_list, "map list can not be empty list"
|
assert m_list, "map list can not be empty list"
|
||||||
|
|
||||||
|
@ -173,7 +171,7 @@ class MeaningMap:
|
||||||
# 处理rank_idx
|
# 处理rank_idx
|
||||||
new_idx = (ms_rank_idx[idx] * 16 + idxidx).astype(np.uint32)
|
new_idx = (ms_rank_idx[idx] * 16 + idxidx).astype(np.uint32)
|
||||||
if self.with_tree:
|
if self.with_tree:
|
||||||
new_idx = np.concatenate([new_idx, np.array([0xFFFFFFF0 + m_len - 1])])
|
new_idx = np.concatenate([new_idx, np.array([0xFFFFFFF0 + m_len])])
|
||||||
ms_rank_idx[index:end] = new_idx
|
ms_rank_idx[index:end] = new_idx
|
||||||
# 处理rank_all
|
# 处理rank_all
|
||||||
new_all = (ms_rank_all[idx] * 16 + m_len).astype(np.uint32)
|
new_all = (ms_rank_all[idx] * 16 + m_len).astype(np.uint32)
|
||||||
|
@ -430,10 +428,10 @@ class MeaningDataset(Dataset):
|
||||||
data = torch.tensor(np.stack(data, axis=0)).long()
|
data = torch.tensor(np.stack(data, axis=0)).long()
|
||||||
output["input_ids"] = data
|
output["input_ids"] = data
|
||||||
output["token_type_ids"] = torch.zeros(data.shape)
|
output["token_type_ids"] = torch.zeros(data.shape)
|
||||||
val_mask, stride_mask = self.get_seq_mask_tensor(idx_list)
|
val_mask, token_mask = self.get_seq_mask_tensor(idx_list)
|
||||||
output["val_mask"] = val_mask
|
output["val_mask"] = val_mask
|
||||||
labels = data.clone()
|
labels = data.clone()
|
||||||
# labels[~stride_mask] = self.vocab_size # set to vocab_size will be masked in label
|
# labels[~token_mask] = self.vocab_size # set to vocab_size will be masked in label
|
||||||
output["labels"] = labels
|
output["labels"] = labels
|
||||||
output["meaning"] = [self.seq_meaning[i] for i in idx_list]
|
output["meaning"] = [self.seq_meaning[i] for i in idx_list]
|
||||||
return output
|
return output
|
||||||
|
@ -468,14 +466,13 @@ class MeaningDataset(Dataset):
|
||||||
return rank_idx == (rank_all + index if index < 0 else index)
|
return rank_idx == (rank_all + index if index < 0 else index)
|
||||||
|
|
||||||
def get_seq_mask_tensor(self, idx_list):
|
def get_seq_mask_tensor(self, idx_list):
|
||||||
stride_mask = torch.tensor(np.stack([(self.level[idx] < 255) for idx in idx_list], axis=0))
|
token_mask = torch.tensor(np.stack([(self.level[idx] < 255) for idx in idx_list], axis=0))
|
||||||
val_mask = stride_mask.clone()
|
val_mask = token_mask.clone()
|
||||||
if self.val_mask_level and self.val_mask_idx:
|
if self.val_mask_level and self.val_mask_idx:
|
||||||
for i, l in enumerate(self.val_mask_level):
|
for i, l in enumerate(self.val_mask_level):
|
||||||
val_mask = val_mask & torch.tensor(
|
mask = [self.get_rank_mask(idx, l, self.val_mask_idx[i]) for idx in idx_list]
|
||||||
np.stack([self.get_rank_mask(idx, l, self.val_mask_idx[i]) for idx in idx_list], axis=0)
|
val_mask = val_mask & torch.tensor(np.stack(mask, axis=0))
|
||||||
)
|
return (val_mask, token_mask)
|
||||||
return (val_mask, stride_mask)
|
|
||||||
|
|
||||||
|
|
||||||
class BatchGroupMeaningDataloader(Dataset):
|
class BatchGroupMeaningDataloader(Dataset):
|
||||||
|
|
Loading…
Reference in New Issue