Add stride support in meaning_dataset.
This commit is contained in:
parent
d7191003e0
commit
9c75b8920d
|
@ -11,7 +11,15 @@ from node_tree import NodeTree
|
||||||
|
|
||||||
class MeaningMap:
|
class MeaningMap:
|
||||||
def __init__(
|
def __init__(
|
||||||
self, size=1048576, vocab_size=4096, max_subitem=10, min_subitem=1, with_parent=False, use_cache=True, seed=42
|
self,
|
||||||
|
size=1048576,
|
||||||
|
vocab_size=4096,
|
||||||
|
max_subitem=10,
|
||||||
|
min_subitem=1,
|
||||||
|
stride=1,
|
||||||
|
with_tree=False,
|
||||||
|
use_cache=True,
|
||||||
|
seed=42,
|
||||||
):
|
):
|
||||||
assert size > 0 and vocab_size > 0 and max_subitem > 0 and min_subitem > 0, "Invalid input"
|
assert size > 0 and vocab_size > 0 and max_subitem > 0 and min_subitem > 0, "Invalid input"
|
||||||
assert min_subitem <= max_subitem, "Invalid input"
|
assert min_subitem <= max_subitem, "Invalid input"
|
||||||
|
@ -20,7 +28,8 @@ class MeaningMap:
|
||||||
self.vocab_size = vocab_size
|
self.vocab_size = vocab_size
|
||||||
self.max_subitem = max_subitem
|
self.max_subitem = max_subitem
|
||||||
self.min_subitem = min_subitem
|
self.min_subitem = min_subitem
|
||||||
self.with_parent = with_parent
|
self.with_tree = with_tree
|
||||||
|
self.stride = stride
|
||||||
datastep = 0x8000000
|
datastep = 0x8000000
|
||||||
|
|
||||||
path = "./data/"
|
path = "./data/"
|
||||||
|
@ -93,22 +102,27 @@ class MeaningMap:
|
||||||
|
|
||||||
index = 0
|
index = 0
|
||||||
for i in range(self.vocab_size):
|
for i in range(self.vocab_size):
|
||||||
ms_data[i] = i
|
ms_data[index] = i
|
||||||
|
ms_level[index] = 0
|
||||||
|
ms_rank_idx[index] = 0xFFFFFFF
|
||||||
|
ms_rank_all[index] = 0xFFFFFFF
|
||||||
|
for ind in range(index + 1, index + stride):
|
||||||
|
ms_data[ind] = -1
|
||||||
|
ms_level[ind] = 511
|
||||||
|
ms_rank_idx[ind] = 0xFFFFFFF
|
||||||
|
ms_rank_all[ind] = 0xFFFFFFF
|
||||||
ms_start[i] = index
|
ms_start[i] = index
|
||||||
ms_end[i] = index + 1
|
ms_end[i] = index + stride
|
||||||
ms_len[i] = 1
|
ms_len[i] = stride
|
||||||
ms_level[i] = 0
|
|
||||||
ms_rank_idx[i] = 0xFFFFFFF
|
|
||||||
ms_rank_all[i] = 0xFFFFFFF
|
|
||||||
ms_height[i] = 0
|
ms_height[i] = 0
|
||||||
ms_weight[i] = 1
|
ms_weight[i] = 1
|
||||||
index = index + 1
|
index = index + stride
|
||||||
|
|
||||||
for i in range(self.vocab_size, size):
|
for i in range(self.vocab_size, size):
|
||||||
m = map[i] # 当前meaning的拆分的分支
|
m = map[i] # 当前meaning的拆分的分支
|
||||||
m = m[m >= 0] # donot cut off the map such as [0]
|
m = m[m >= 0] # donot cut off the map such as [0]
|
||||||
m_len = len(m) # 当前meaning的拆分的分支个数
|
m_len = len(m) # 当前meaning的拆分的分支个数
|
||||||
if self.with_parent:
|
if self.with_tree:
|
||||||
m_len = m_len + 1
|
m_len = m_len + 1
|
||||||
m_list = m.tolist()
|
m_list = m.tolist()
|
||||||
assert m_list, "map list can not be empty list"
|
assert m_list, "map list can not be empty list"
|
||||||
|
@ -119,7 +133,7 @@ class MeaningMap:
|
||||||
[np.ones(l, dtype=np.uint32) * i for i, l in enumerate(ms_end[m_list] - ms_start[m_list])]
|
[np.ones(l, dtype=np.uint32) * i for i, l in enumerate(ms_end[m_list] - ms_start[m_list])]
|
||||||
)
|
)
|
||||||
len_ma = len(idx)
|
len_ma = len(idx)
|
||||||
if self.with_parent:
|
if self.with_tree:
|
||||||
len_ma = len_ma + 1
|
len_ma = len_ma + 1
|
||||||
|
|
||||||
end = index + len_ma
|
end = index + len_ma
|
||||||
|
@ -131,22 +145,22 @@ class MeaningMap:
|
||||||
|
|
||||||
# 拼接当前meaning的所有token到data数据结构里面
|
# 拼接当前meaning的所有token到data数据结构里面
|
||||||
new_data = ms_data[idx]
|
new_data = ms_data[idx]
|
||||||
if self.with_parent:
|
if self.with_tree:
|
||||||
new_data = np.concatenate([new_data, np.array([i])])
|
new_data = np.concatenate([new_data, np.array([i])])
|
||||||
ms_data[index:end] = new_data
|
ms_data[index:end] = new_data
|
||||||
# 处理level
|
# 处理level
|
||||||
new_level = ms_level[idx] + 1
|
new_level = ms_level[idx] + 1
|
||||||
if self.with_parent:
|
if self.with_tree:
|
||||||
new_level = np.concatenate([new_level, np.array([256])])
|
new_level = np.concatenate([new_level, np.array([255 + 1])])
|
||||||
ms_level[index:end] = new_level
|
ms_level[index:end] = new_level
|
||||||
# 处理rank_idx
|
# 处理rank_idx
|
||||||
new_idx = (ms_rank_idx[idx] * 16 + idxidx).astype(np.uint32)
|
new_idx = (ms_rank_idx[idx] * 16 + idxidx).astype(np.uint32)
|
||||||
if self.with_parent:
|
if self.with_tree:
|
||||||
new_idx = np.concatenate([new_idx, np.array([0xFFFFFFF0 + m_len - 1])])
|
new_idx = np.concatenate([new_idx, np.array([0xFFFFFFF0 + m_len - 1])])
|
||||||
ms_rank_idx[index:end] = new_idx
|
ms_rank_idx[index:end] = new_idx
|
||||||
# 处理rank_all
|
# 处理rank_all
|
||||||
new_all = (ms_rank_all[idx] * 16 + m_len).astype(np.uint32)
|
new_all = (ms_rank_all[idx] * 16 + m_len).astype(np.uint32)
|
||||||
if self.with_parent:
|
if self.with_tree:
|
||||||
new_all = np.concatenate([new_all, np.array([0xFFFFFFF0 + m_len])])
|
new_all = np.concatenate([new_all, np.array([0xFFFFFFF0 + m_len])])
|
||||||
ms_rank_all[index:end] = new_all
|
ms_rank_all[index:end] = new_all
|
||||||
|
|
||||||
|
@ -295,7 +309,8 @@ class MeaningDataset(Dataset):
|
||||||
min_subitem=1,
|
min_subitem=1,
|
||||||
min_seq_len=2,
|
min_seq_len=2,
|
||||||
seed=42,
|
seed=42,
|
||||||
with_parent=False,
|
stride=1,
|
||||||
|
with_tree=False,
|
||||||
use_cache=True,
|
use_cache=True,
|
||||||
):
|
):
|
||||||
np.random.seed(seed)
|
np.random.seed(seed)
|
||||||
|
@ -304,7 +319,8 @@ class MeaningDataset(Dataset):
|
||||||
self.vocab_size = vocab_size
|
self.vocab_size = vocab_size
|
||||||
self.max_subitem = max_subitem
|
self.max_subitem = max_subitem
|
||||||
self.min_subitem = min_subitem
|
self.min_subitem = min_subitem
|
||||||
self.with_parent = with_parent
|
self.stride = stride
|
||||||
|
self.with_tree = with_tree
|
||||||
self.use_cache = use_cache
|
self.use_cache = use_cache
|
||||||
self.min_seq_len = min_seq_len
|
self.min_seq_len = min_seq_len
|
||||||
print("Build MeaningDataset from MeaningMap.")
|
print("Build MeaningDataset from MeaningMap.")
|
||||||
|
@ -376,7 +392,7 @@ class MeaningDataset(Dataset):
|
||||||
|
|
||||||
def get_meaning_map(self):
|
def get_meaning_map(self):
|
||||||
return MeaningMap(
|
return MeaningMap(
|
||||||
self.end, self.vocab_size, self.max_subitem, self.min_subitem, self.with_parent, self.use_cache
|
self.end, self.vocab_size, self.max_subitem, self.min_subitem, self.stride, self.with_tree, self.use_cache
|
||||||
)
|
)
|
||||||
|
|
||||||
def set_mask(self, level=None, idx=None):
|
def set_mask(self, level=None, idx=None):
|
||||||
|
@ -502,7 +518,16 @@ class BatchGroupMeaningDataloader(Dataset):
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
||||||
md = MeaningDataset(100000, 115200, vocab_size=32, size=1024, with_parent=True, use_cache=False)
|
md = MeaningDataset(100000, 115200, vocab_size=32, size=1024, stride=2, with_tree=False, use_cache=False)
|
||||||
|
item = md.__getitem__(920)
|
||||||
|
mm = md.get_meaning_map()
|
||||||
|
mm.get_nodetree(int(item["meaning"][0])).print()
|
||||||
|
item_seq = item["input_ids"].numpy().tolist()
|
||||||
|
item_mask = item["val_mask"].numpy().tolist()
|
||||||
|
print(item_seq)
|
||||||
|
print(item_mask)
|
||||||
|
|
||||||
|
md = MeaningDataset(100000, 115200, vocab_size=32, size=1024, stride=1, with_tree=True, use_cache=False)
|
||||||
train, val = md.split(0.95)
|
train, val = md.split(0.95)
|
||||||
item = md.__getitem__(920)
|
item = md.__getitem__(920)
|
||||||
|
|
||||||
|
|
|
@ -18,8 +18,8 @@ meaning数据集是一个模仿自然语言,以及抽象表达的数据集。
|
||||||
12. meaning_height 当前meaning的总高度
|
12. meaning_height 当前meaning的总高度
|
||||||
13. meaning_weight 当前meaning的总宽度
|
13. meaning_weight 当前meaning的总宽度
|
||||||
14. val_mask_level val_mask_idx: 表示用于训练的token的mask,val_mask_level=[0, 1, 2] val_mask_idx=[0, 0, -1]表示只有是第0层第0个,而且是第1层第0个,第2层最后一个的token,才参与训练
|
14. val_mask_level val_mask_idx: 表示用于训练的token的mask,val_mask_level=[0, 1, 2] val_mask_idx=[0, 0, -1]表示只有是第0层第0个,而且是第1层第0个,第2层最后一个的token,才参与训练
|
||||||
15. with_parent 表示每个meaning的拆分都以当前meaning编号结束,sequence会插入很多枝节点,而不仅仅是叶节点
|
15. with_tree 表示每个meaning的拆分都以当前meaning编号结束,sequence会插入很多枝节点,而不仅仅是叶节点
|
||||||
16. ms_level 大于255表示这个token是parent(level 0用256表示),生成val mask的时候是False
|
16. ms_level 大于255表示这个token是tree(level 0用255表示),大于511表示这个token是stride(level 0用511表示),生成val mask的时候是False
|
||||||
17. ms_map 表示每个meaning拆解的sub meaning
|
17. ms_map 表示每个meaning拆解的sub meaning
|
||||||
18. index must < 15,level must < 8
|
18. index must < 15,level must < 8
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue