Add parent support in meaning_dataset.
commit d7191003e0
parent 8882073978
@@ -6,14 +6,13 @@ from typing import Dict, Tuple
 from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
 import numpy as np
 from torch.utils.data import BatchSampler
-from dataset.node_tree import NodeTree
+from node_tree import NodeTree

-# import warnings
-# warnings.filterwarnings("ignore", ".*does not have many workers.*")


 class MeaningMap:
-    def __init__(self, size=1048576, vocab_size=4096, max_subitem=10, min_subitem=1, use_cache=True, seed=42):
+    def __init__(
+        self, size=1048576, vocab_size=4096, max_subitem=10, min_subitem=1, with_parent=False, use_cache=True, seed=42
+    ):
         assert size > 0 and vocab_size > 0 and max_subitem > 0 and min_subitem > 0, "Invalid input"
         assert min_subitem <= max_subitem, "Invalid input"
         np.random.seed(seed)
@@ -21,6 +20,7 @@ class MeaningMap:
         self.vocab_size = vocab_size
         self.max_subitem = max_subitem
         self.min_subitem = min_subitem
+        self.with_parent = with_parent
         datastep = 0x8000000

         path = "./data/"
@@ -108,6 +108,8 @@ class MeaningMap:
             m = map[i]  # the branches of the current meaning's split
             m = m[m >= 0]  # do not cut off the map, e.g. [0]
             m_len = len(m)  # the number of branches in the current meaning's split
+            if self.with_parent:
+                m_len = m_len + 1
             m_list = m.tolist()
             assert m_list, "map list can not be empty list"

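The new branch reserves one extra slot per expansion when `with_parent` is on. A minimal sketch of the counting, with a hypothetical toy split:

```python
import numpy as np

m = np.array([3, 7, -1, 5])  # hypothetical split of one meaning; -1 is padding
m = m[m >= 0]                # keep [3, 7, 5]
m_len = len(m)               # 3 branches
with_parent = True
if with_parent:
    m_len = m_len + 1        # one extra slot for the parent token appended later
print(m_len)                 # -> 4
```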
@@ -117,6 +119,8 @@ class MeaningMap:
                 [np.ones(l, dtype=np.uint32) * i for i, l in enumerate(ms_end[m_list] - ms_start[m_list])]
             )
             len_ma = len(idx)
+            if self.with_parent:
+                len_ma = len_ma + 1

             end = index + len_ma
             if ms_data.size < end:  # exceeds the size of the storage structure; extend capacity by one datastep
@@ -125,10 +129,26 @@ class MeaningMap:
                 ms_rank_idx = np.concatenate([ms_rank_idx, np.zeros((datastep), dtype=np.uint32)])
                 ms_rank_all = np.concatenate([ms_rank_all, np.zeros((datastep), dtype=np.uint32)])

-            ms_data[index:end] = ms_data[idx]  # concatenate all tokens of the current meaning into the data structure
-            ms_level[index:end] = ms_level[idx] + 1  # handle level
-            ms_rank_idx[index:end] = (ms_rank_idx[idx] * 16 + idxidx).astype(np.uint32)  # handle rank_idx
-            ms_rank_all[index:end] = (ms_rank_all[idx] * 16 + m_len).astype(np.uint32)  # handle rank_all
+            # concatenate all tokens of the current meaning into the data structure
+            new_data = ms_data[idx]
+            if self.with_parent:
+                new_data = np.concatenate([new_data, np.array([i])])
+            ms_data[index:end] = new_data
+            # handle level
+            new_level = ms_level[idx] + 1
+            if self.with_parent:
+                new_level = np.concatenate([new_level, np.array([256])])
+            ms_level[index:end] = new_level
+            # handle rank_idx
+            new_idx = (ms_rank_idx[idx] * 16 + idxidx).astype(np.uint32)
+            if self.with_parent:
+                new_idx = np.concatenate([new_idx, np.array([0xFFFFFFF0 + m_len - 1])])
+            ms_rank_idx[index:end] = new_idx
+            # handle rank_all
+            new_all = (ms_rank_all[idx] * 16 + m_len).astype(np.uint32)
+            if self.with_parent:
+                new_all = np.concatenate([new_all, np.array([0xFFFFFFF0 + m_len])])
+            ms_rank_all[index:end] = new_all

             ms_start[i] = index
             ms_end[i] = end
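Each expansion now builds the four parallel arrays before writing them back; with `with_parent` the parent id is appended with a sentinel level and rank. A minimal sketch of the parent entry's encoding, assuming toy child arrays and `m_len = 3` (branch count including the parent slot):

```python
import numpy as np

i = 42                                       # id of the meaning being expanded (hypothetical)
m_len = 3                                    # branch count including the parent slot
child_data = np.array([7, 9], dtype=np.uint32)
child_level = np.array([0, 0], dtype=np.uint32)

new_data = np.concatenate([child_data, np.array([i])])          # child tokens + parent id
new_level = np.concatenate([child_level + 1, np.array([256])])  # level > 255 marks a parent
new_idx_tail = 0xFFFFFFF0 + m_len - 1        # sentinel rank_idx for the parent token
new_all_tail = 0xFFFFFFF0 + m_len            # sentinel rank_all for the parent token
print(new_data, new_level, hex(new_idx_tail), hex(new_all_tail))
```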
@@ -275,6 +295,7 @@ class MeaningDataset(Dataset):
         min_subitem=1,
         min_seq_len=2,
         seed=42,
+        with_parent=False,
         use_cache=True,
     ):
         np.random.seed(seed)
@@ -283,6 +304,7 @@ class MeaningDataset(Dataset):
         self.vocab_size = vocab_size
         self.max_subitem = max_subitem
         self.min_subitem = min_subitem
+        self.with_parent = with_parent
         self.use_cache = use_cache
         self.min_seq_len = min_seq_len
         print("Build MeaningDataset from MeaningMap.")
@@ -310,9 +332,11 @@ class MeaningDataset(Dataset):
             self.seq_meaning.append(m)
             seq_len.append(len(d))

+            origin_l = l.copy()
+            origin_l[l >= 255] = l[l >= 255] - 255
             dm = np.ones(i.shape, dtype=np.uint32)
-            dm = ((dm * 0xFFFFFFFF) << (l * 4)).astype(np.uint32)
-            shift = (8 - l) * 4
+            dm = ((dm * 0xFFFFFFFF) << (origin_l * 4)).astype(np.uint32)
+            shift = (8 - origin_l) * 4
             rank_idx = (i & 0xF) << 28
             rank_idx = rank_idx + ((i & 0xF0) << 20)
             rank_idx = rank_idx + ((i & 0xF00) << 12)
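`origin_l` exists because parent tokens carry a large level offset (a parent is stored as level >= 255), which would blow up the shift arithmetic. A sketch with hypothetical levels:

```python
import numpy as np

l = np.array([2, 3, 256], dtype=np.uint32)   # hypothetical levels; 256 marks a parent token
origin_l = l.copy()
origin_l[l >= 255] = l[l >= 255] - 255       # strip the parent offset -> [2, 3, 1]

dm = np.ones(l.shape, dtype=np.uint32)
dm = ((dm * 0xFFFFFFFF) << (origin_l * 4)).astype(np.uint32)  # 4 bits of rank per level
shift = (8 - origin_l) * 4
print(origin_l, [hex(x) for x in dm], shift)
```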
@@ -351,14 +375,16 @@ class MeaningDataset(Dataset):
         return len(self.seq)

     def get_meaning_map(self):
-        return MeaningMap(self.end, self.vocab_size, self.max_subitem, self.min_subitem, self.use_cache)
+        return MeaningMap(
+            self.end, self.vocab_size, self.max_subitem, self.min_subitem, self.with_parent, self.use_cache
+        )

     def set_mask(self, level=None, idx=None):
-        if self.val_mask_level is not None and self.val_mask_idx is not None:
-            assert len(self.val_mask_level) > 0, "len must > 0"
-            assert len(self.val_mask_level) == len(self.val_mask_idx), "mask level and mask index must be same length"
-            assert isinstance(self.val_mask_level, list), "mask level must be list"
-            assert isinstance(self.val_mask_idx, list), "mask index must be list"
+        if level is not None and idx is not None:
+            assert len(level) > 0, "len must > 0"
+            assert len(level) == len(idx), "mask level and mask index must be same length"
+            assert isinstance(level, list), "mask level must be list"
+            assert isinstance(idx, list), "mask index must be list"
         self.val_mask_level = level
         self.val_mask_idx = idx

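`set_mask` now validates the arguments themselves rather than the not-yet-assigned attributes (the old checks ran against stale state). A standalone sketch of the new contract, not the class method itself:

```python
def set_mask(level=None, idx=None):
    # mirrors the validation order of the commit (standalone sketch)
    if level is not None and idx is not None:
        assert len(level) > 0, "len must > 0"
        assert len(level) == len(idx), "mask level and mask index must be same length"
        assert isinstance(level, list), "mask level must be list"
        assert isinstance(idx, list), "mask index must be list"
    return level, idx

print(set_mask([0, 1], [0, -1]))  # level-0 rank 0 AND level-1 last child
print(set_mask())                 # clearing the mask is also legal now
```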
@@ -403,7 +429,7 @@ class MeaningDataset(Dataset):
         # assert level < 8, "level must < 8"
         rank_idx = (self.rank_idx[idx] >> (4 * level)).astype(np.int32) & 0xF
         rank_all = (self.rank_all[idx] >> (4 * level)).astype(np.int32) & 0xF
-        return rank_idx == (rank_all + index if index < 0 else index)
+        return (rank_idx == (rank_all + index if index < 0 else index)) & (self.level[idx] < 255)

     def get_seq_mask_tensor(self, idx_list):
         if self.val_mask_level is not None and self.val_mask_idx is not None:
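`get_seq_mask` pulls one 4-bit rank out of the packed words and now also drops parent tokens. A sketch with hypothetical packed values:

```python
import numpy as np

rank_idx_packed = np.array([0x22, 0x10], dtype=np.uint32)  # hypothetical packed ranks
rank_all_packed = np.array([0x33, 0x32], dtype=np.uint32)  # hypothetical packed fan-outs
token_level = np.array([1, 256], dtype=np.uint32)          # second token is a parent

level, index = 0, -1                                       # -1 selects the last child
rank_idx = (rank_idx_packed >> (4 * level)).astype(np.int32) & 0xF
rank_all = (rank_all_packed >> (4 * level)).astype(np.int32) & 0xF
mask = (rank_idx == (rank_all + index if index < 0 else index)) & (token_level < 255)
print(mask)  # [ True False] -- parents never pass, whatever their rank bits say
```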
@@ -418,7 +444,7 @@ class MeaningDataset(Dataset):
             )
             return mask
         else:
-            return None
+            return torch.tensor(np.stack([(self.level[idx] < 255) for idx in idx_list], axis=0))


 class BatchGroupMeaningDataloader(Dataset):
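With no val mask configured, the fallback no longer returns None but still excludes parent tokens. A minimal sketch, assuming hypothetical per-sequence level arrays:

```python
import numpy as np
import torch

level = {0: np.array([1, 2, 256]), 1: np.array([256, 2, 2])}  # hypothetical levels per sequence
idx_list = [0, 1]
fallback = torch.tensor(np.stack([(level[idx] < 255) for idx in idx_list], axis=0))
print(fallback)  # tensor([[ True,  True, False], [False,  True,  True]])
```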
@@ -476,11 +502,19 @@ class BatchGroupMeaningDataloader(Dataset):

 if __name__ == "__main__":

-    md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, use_cache=False)
-    md.set_mask([1], [-1])
+    md = MeaningDataset(100000, 115200, vocab_size=32, size=1024, with_parent=True, use_cache=False)
     train, val = md.split(0.95)
-    fdaf = md.__getitem__(920)
-    print(md.print_tree(920))
+    item = md.__getitem__(920)
+    mm = md.get_meaning_map()
+    mm.get_nodetree(int(item["meaning"][0])).print()
+    item_seq = item["input_ids"].numpy().tolist()
+    item_mask = item["val_mask"].numpy().tolist()
+    print(item_seq)
+    print(item_mask)
+
+    md.set_mask([1], [-1])

     print(md.rank_idx[920])
     print(md.rank_all[920])
     mask = md.get_seq_mask(920, 0, -1)
@@ -491,72 +525,6 @@ if __name__ == "__main__":
     print(mask)
     mask = md.get_seq_mask(920, 1, 1)
     print(mask)
-    assert all(
-        np.equal(
-            mask[0:57],
-            np.array(
-                [False, False, False, False, False, False, False, False, False, False, False,
-                 True, True, True, True, True, True, True, True, True,
-                 False, False, False, False, False, False,
-                 True,
-                 False, False, False, False, False, False, False,
-                 True, True, True, True,
-                 False, False, False, False,
-                 True, True, True, True, True, True, True, True, True,
-                 False, False, False, False, False, False]
-            ),
-        )
-    ), "False"

     md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, min_subitem=2, use_cache=False)
     md.set_mask([0, 1], [0, -1])
@@ -18,6 +18,10 @@ The meaning dataset imitates natural language as well as abstract expression.
 12. meaning_height: the total height of the current meaning
 13. meaning_weight: the total width of the current meaning
 14. val_mask_level val_mask_idx: the mask over tokens used for training; val_mask_level=[0, 1, 2] with val_mask_idx=[0, 0, -1] means only tokens that are the 0th at level 0, the 0th at level 1, and the last at level 2 take part in training
+15. with_parent: every meaning's split ends with the current meaning's id, so the sequence contains many branch nodes rather than only leaf nodes (see the sketch below)
+16. ms_level: a value greater than 255 marks the token as a parent (level 0 is represented as 256); such tokens are False when the val mask is generated
+17. ms_map: the sub meanings each meaning decomposes into
+18. index must be < 15, level must be < 8

 ## code

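A toy illustration of items 15 and 16, with a hypothetical tree and helper (not the real MeaningMap generator): parents are appended post-order, so branch nodes show up in the sequence after their subtree.

```python
def expand(meaning, tree, with_parent=False):
    # hypothetical recursive expansion; leaves emit themselves
    if meaning not in tree:
        return [meaning]
    seq = []
    for child in tree[meaning]:
        seq += expand(child, tree, with_parent)
    if with_parent:
        seq.append(meaning)  # the split ends with the current meaning's id
    return seq

tree = {100: [42, 9], 42: [7, 9]}    # hypothetical meaning tree
print(expand(100, tree))             # [7, 9, 9]          -- leaves only
print(expand(100, tree, True))       # [7, 9, 42, 9, 100] -- parents appended
```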