Speed up dataset generation.
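Build the per-meaning sequences as NumPy arrays instead of Python lists: sub-sequences are now merged with a single np.concatenate call rather than repeated list concatenation, and the height/weight bookkeeping reuses the filtered m_list. The map dtype also changes from uint32 to int32, with -1 (instead of 0) marking empty slots, so token id 0 is no longer dropped as padding.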

Colin 2024-04-07 17:03:35 +08:00
parent 33d1e22655
commit 9d3b9a210a
1 changed file with 35 additions and 28 deletions


@@ -40,7 +40,7 @@ class MeaningMap:
         else:
             print("Disk cache miss, build new one.")
-            map = np.empty((size, max_subitem), dtype=np.uint32)
+            map = np.empty((size, max_subitem), dtype=np.int32)
             index = np.arange(0, size)
             map = np.random.random((size, max_subitem))
@@ -53,13 +53,12 @@ class MeaningMap:
             item_sum = map.sum(axis=1)
             scale = (index / item_sum).reshape(-1, 1).repeat(max_subitem, axis=1)
-            map = map * scale
+            map = (map * scale).astype(np.int32)
-            map[mask_zero] = 0
+            map[mask_zero] = -1
             map[:vocab_size, 0] = np.arange(0, vocab_size)
-            map[:vocab_size, 1:] = 0
-            map = map.astype(np.uint32)
+            map[:vocab_size, 1:] = -1
         ms_data = []  # meaning sequence
         ms_level = []  # meaning level, vocab's level is 0
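The dtype and sentinel change above is what the rest of the diff builds on: with an unsigned map and 0 marking empty slots, token id 0 could not be told apart from padding. A minimal sketch of the idea (toy values, not from the repository):

    import numpy as np

    # uint32 with 0 as the "empty" marker: slot 1 is ambiguous --
    # it could be token 0 or an unused slot.
    row = np.array([5, 0, 0], dtype=np.uint32)
    tokens = row[row > 0]    # token 0 would be dropped here

    # int32 with -1 as the marker keeps token 0 usable.
    row = np.array([5, 0, -1], dtype=np.int32)
    tokens = row[row >= 0]   # -> array([5, 0], dtype=int32)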
@@ -70,33 +69,42 @@ class MeaningMap:
         ms_weight = []  # meaning tree weight
         index = 0
         for i in range(self.vocab_size):
-            ms_data.append([i])
-            ms_level.append([0])
-            ms_idx.append([0])
+            ms_data.append(np.array([i]))
+            ms_level.append(np.array([0]))
+            ms_idx.append(np.array([0]))
             ms_start.append(index)
             ms_len.append(1)
-            index = index + 1
             ms_height.append(0)
             ms_weight.append(1)
+            index = index + 1
         for i in range(self.vocab_size, size):
             m = map[i]
-            m = m[m > 0]
-            ma = []
-            ml = []
-            mi = []
-            for i, newm in enumerate(m.tolist()):
-                ma = ma + ms_data[newm]
-                ml = ml + [x + 1 for x in ms_level[newm]]
-                mi = mi + ([0xFFFFFFF0 + i] if newm < self.vocab_size else [n * 16 + i for n in ms_idx[newm]])
+            m = m[m >= 0]
+            m_list = m.tolist()
+            assert m_list, "map list can not be empty list"
+            ma = np.concatenate([ms_data[newm] for newm in m_list])
+            ml = np.concatenate([ms_level[newm] + 1 for newm in m_list])
+            mi = np.concatenate(
+                [
+                    ([0xFFFFFFF0 + i] if newm < self.vocab_size else ms_idx[newm] * 16 + i)
+                    for i, newm in enumerate(m_list)
+                ]
+            )
+            ml = ml[ma > 0]
+            mi = mi[ma > 0]
+            ma = ma[ma > 0]
             ms_data.append(ma)
-            ms_start.append(index)
-            ms_len.append(len(ma))
             ms_level.append(ml)
             ms_idx.append(mi)
+            ms_start.append(index)
+            ms_len.append(len(ma))
+            ms_height.append(max([ms_height[sub_m] for sub_m in m_list]) + 1)
+            ms_weight.append(sum(ms_weight[sub_m] for sub_m in m_list))
             index = index + len(ma)
-            ms_height.append(max([-1] + [ms_height[sub_m] for sub_m in m.tolist()]) + 1)
-            ms_weight.append(sum(ms_weight[sub_m] for sub_m in m.tolist()))
         # offsets = [0, 0, 4, 8, 12, 16, 20, 24, 28]
         # for idxmi, mi in enumerate(ms_idx):
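The inner loop above was the hot spot: `ma = ma + ms_data[newm]` re-copies the entire accumulated list on every iteration, while keeping per-meaning np.ndarray objects and merging them with one np.concatenate call does a single bulk copy in C. A rough illustration of the pattern (hypothetical sizes, not a benchmark of this repository):

    import numpy as np

    chunks = [np.arange(100, dtype=np.int64) for _ in range(2000)]

    # Old pattern: list "+" copies everything accumulated so far,
    # so total work grows quadratically with the number of chunks.
    acc = []
    for c in chunks:
        acc = acc + c.tolist()

    # New pattern: one C-level bulk copy over all chunks.
    out = np.concatenate(chunks)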
@@ -126,11 +134,10 @@ class MeaningMap:
         )
         ms_idx = ((ms_idx >> ((8 - ms_level) * 4)) + d).astype(np.uint32)
-        ms_start = np.array(ms_start).astype(np.uint32)
-        ms_height = np.array(ms_height).astype(np.uint32)
-        ms_weight = np.array(ms_weight).astype(np.uint32)
-        ms_len = np.array(ms_len).astype(np.uint32)
-        ms_map = map.astype(np.uint32)
+        ms_start = np.array(ms_start).astype(np.int32)
+        ms_height = np.array(ms_height).astype(np.int32)
+        ms_weight = np.array(ms_weight).astype(np.int32)
+        ms_len = np.array(ms_len).astype(np.int32)
         slhwm = np.concatenate(
             (
@@ -138,7 +145,7 @@ class MeaningMap:
                 ms_len.reshape((-1, 1)),
                 ms_height.reshape((-1, 1)),
                 ms_weight.reshape((-1, 1)),
-                ms_map,
+                map,
             ),
             axis=1,
         )
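With every bookkeeping column already int32, the separate `ms_map = map.astype(np.uint32)` copy becomes unnecessary and `map` is stacked directly. A small sketch of the resulting layout (shapes are illustrative, not from the repository):

    import numpy as np

    n, max_subitem = 4, 3
    ms_start = np.zeros(n, dtype=np.int32)
    ms_len = np.ones(n, dtype=np.int32)
    ms_height = np.zeros(n, dtype=np.int32)
    ms_weight = np.ones(n, dtype=np.int32)
    map_arr = np.full((n, max_subitem), -1, dtype=np.int32)  # stands in for map

    # One table per meaning: [start | len | height | weight | map columns...]
    slhwm = np.concatenate(
        (
            ms_start.reshape((-1, 1)),
            ms_len.reshape((-1, 1)),
            ms_height.reshape((-1, 1)),
            ms_weight.reshape((-1, 1)),
            map_arr,
        ),
        axis=1,
    )
    assert slhwm.shape == (n, 4 + max_subitem) and slhwm.dtype == np.int32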
@@ -383,7 +390,7 @@ class BatchGroupMeaningDataloader(Dataset):
 if __name__ == "__main__":
-    md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, use_cache=False)
+    md = MeaningDataset(100000, 115200, vocab_size=128, size=1024, use_cache=False)
     train, val = md.split(0.95)
     fdaf = md.__getitem__(920)
     print(md.print_tree(920))