Speedup dataset generate.
This commit is contained in:
parent
33d1e22655
commit
9d3b9a210a
|
@ -40,7 +40,7 @@ class MeaningMap:
|
|||
else:
|
||||
print("Disk cache miss, build new one.")
|
||||
|
||||
map = np.empty((size, max_subitem), dtype=np.uint32)
|
||||
map = np.empty((size, max_subitem), dtype=np.int32)
|
||||
|
||||
index = np.arange(0, size)
|
||||
map = np.random.random((size, max_subitem))
|
||||
|
@ -53,13 +53,12 @@ class MeaningMap:
|
|||
|
||||
item_sum = map.sum(axis=1)
|
||||
scale = (index / item_sum).reshape(-1, 1).repeat(max_subitem, axis=1)
|
||||
map = map * scale
|
||||
map = (map * scale).astype(np.int32)
|
||||
|
||||
map[mask_zero] = 0
|
||||
map[mask_zero] = -1
|
||||
|
||||
map[:vocab_size, 0] = np.arange(0, vocab_size)
|
||||
map[:vocab_size, 1:] = 0
|
||||
map = map.astype(np.uint32)
|
||||
map[:vocab_size, 1:] = -1
|
||||
|
||||
ms_data = [] # meaning sequence
|
||||
ms_level = [] # meaning level, vocab's level is 0
|
||||
|
@ -70,33 +69,42 @@ class MeaningMap:
|
|||
ms_weight = [] # meaning tree weight
|
||||
index = 0
|
||||
for i in range(self.vocab_size):
|
||||
ms_data.append([i])
|
||||
ms_level.append([0])
|
||||
ms_idx.append([0])
|
||||
ms_data.append(np.array([i]))
|
||||
ms_level.append(np.array([0]))
|
||||
ms_idx.append(np.array([0]))
|
||||
ms_start.append(index)
|
||||
ms_len.append(1)
|
||||
index = index + 1
|
||||
ms_height.append(0)
|
||||
ms_weight.append(1)
|
||||
index = index + 1
|
||||
|
||||
for i in range(self.vocab_size, size):
|
||||
m = map[i]
|
||||
m = m[m > 0]
|
||||
ma = []
|
||||
ml = []
|
||||
mi = []
|
||||
for i, newm in enumerate(m.tolist()):
|
||||
ma = ma + ms_data[newm]
|
||||
ml = ml + [x + 1 for x in ms_level[newm]]
|
||||
mi = mi + ([0xFFFFFFF0 + i] if newm < self.vocab_size else [n * 16 + i for n in ms_idx[newm]])
|
||||
m = m[m >= 0]
|
||||
|
||||
m_list = m.tolist()
|
||||
assert m_list, "map list can not be empty list"
|
||||
|
||||
ma = np.concatenate([ms_data[newm] for newm in m_list])
|
||||
ml = np.concatenate([ms_level[newm] + 1 for newm in m_list])
|
||||
mi = np.concatenate(
|
||||
[
|
||||
([0xFFFFFFF0 + i] if newm < self.vocab_size else ms_idx[newm] * 16 + i)
|
||||
for i, newm in enumerate(m_list)
|
||||
]
|
||||
)
|
||||
ml = ml[ma > 0]
|
||||
mi = mi[ma > 0]
|
||||
ma = ma[ma > 0]
|
||||
|
||||
ms_data.append(ma)
|
||||
ms_start.append(index)
|
||||
ms_len.append(len(ma))
|
||||
ms_level.append(ml)
|
||||
ms_idx.append(mi)
|
||||
ms_start.append(index)
|
||||
ms_len.append(len(ma))
|
||||
ms_height.append(max([ms_height[sub_m] for sub_m in m_list]) + 1)
|
||||
ms_weight.append(sum(ms_weight[sub_m] for sub_m in m_list))
|
||||
index = index + len(ma)
|
||||
ms_height.append(max([-1] + [ms_height[sub_m] for sub_m in m.tolist()]) + 1)
|
||||
ms_weight.append(sum(ms_weight[sub_m] for sub_m in m.tolist()))
|
||||
|
||||
# offsets = [0, 0, 4, 8, 12, 16, 20, 24, 28]
|
||||
# for idxmi, mi in enumerate(ms_idx):
|
||||
|
@ -126,11 +134,10 @@ class MeaningMap:
|
|||
)
|
||||
ms_idx = ((ms_idx >> ((8 - ms_level) * 4)) + d).astype(np.uint32)
|
||||
|
||||
ms_start = np.array(ms_start).astype(np.uint32)
|
||||
ms_height = np.array(ms_height).astype(np.uint32)
|
||||
ms_weight = np.array(ms_weight).astype(np.uint32)
|
||||
ms_len = np.array(ms_len).astype(np.uint32)
|
||||
ms_map = map.astype(np.uint32)
|
||||
ms_start = np.array(ms_start).astype(np.int32)
|
||||
ms_height = np.array(ms_height).astype(np.int32)
|
||||
ms_weight = np.array(ms_weight).astype(np.int32)
|
||||
ms_len = np.array(ms_len).astype(np.int32)
|
||||
|
||||
slhwm = np.concatenate(
|
||||
(
|
||||
|
@ -138,7 +145,7 @@ class MeaningMap:
|
|||
ms_len.reshape((-1, 1)),
|
||||
ms_height.reshape((-1, 1)),
|
||||
ms_weight.reshape((-1, 1)),
|
||||
ms_map,
|
||||
map,
|
||||
),
|
||||
axis=1,
|
||||
)
|
||||
|
@ -383,7 +390,7 @@ class BatchGroupMeaningDataloader(Dataset):
|
|||
|
||||
if __name__ == "__main__":
|
||||
|
||||
md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, use_cache=False)
|
||||
md = MeaningDataset(100000, 115200, vocab_size=128, size=1024, use_cache=False)
|
||||
train, val = md.split(0.95)
|
||||
fdaf = md.__getitem__(920)
|
||||
print(md.print_tree(920))
|
||||
|
|
Loading…
Reference in New Issue