diff --git a/wit/meaning_dataset.py b/wit/meaning_dataset.py index ada1531..4eb658a 100644 --- a/wit/meaning_dataset.py +++ b/wit/meaning_dataset.py @@ -40,7 +40,7 @@ class MeaningMap: else: print("Disk cache miss, build new one.") - map = np.empty((size, max_subitem), dtype=np.uint32) + map = np.empty((size, max_subitem), dtype=np.int32) index = np.arange(0, size) map = np.random.random((size, max_subitem)) @@ -53,13 +53,12 @@ class MeaningMap: item_sum = map.sum(axis=1) scale = (index / item_sum).reshape(-1, 1).repeat(max_subitem, axis=1) - map = map * scale + map = (map * scale).astype(np.int32) - map[mask_zero] = 0 + map[mask_zero] = -1 map[:vocab_size, 0] = np.arange(0, vocab_size) - map[:vocab_size, 1:] = 0 - map = map.astype(np.uint32) + map[:vocab_size, 1:] = -1 ms_data = [] # meaning sequence ms_level = [] # meaning level, vocab's level is 0 @@ -70,33 +69,42 @@ class MeaningMap: ms_weight = [] # meaning tree weight index = 0 for i in range(self.vocab_size): - ms_data.append([i]) - ms_level.append([0]) - ms_idx.append([0]) + ms_data.append(np.array([i])) + ms_level.append(np.array([0])) + ms_idx.append(np.array([0])) ms_start.append(index) ms_len.append(1) - index = index + 1 ms_height.append(0) ms_weight.append(1) + index = index + 1 for i in range(self.vocab_size, size): m = map[i] - m = m[m > 0] - ma = [] - ml = [] - mi = [] - for i, newm in enumerate(m.tolist()): - ma = ma + ms_data[newm] - ml = ml + [x + 1 for x in ms_level[newm]] - mi = mi + ([0xFFFFFFF0 + i] if newm < self.vocab_size else [n * 16 + i for n in ms_idx[newm]]) + m = m[m >= 0] + + m_list = m.tolist() + assert m_list, "map list can not be empty list" + + ma = np.concatenate([ms_data[newm] for newm in m_list]) + ml = np.concatenate([ms_level[newm] + 1 for newm in m_list]) + mi = np.concatenate( + [ + ([0xFFFFFFF0 + i] if newm < self.vocab_size else ms_idx[newm] * 16 + i) + for i, newm in enumerate(m_list) + ] + ) + ml = ml[ma > 0] + mi = mi[ma > 0] + ma = ma[ma > 0] + ms_data.append(ma) - ms_start.append(index) - ms_len.append(len(ma)) ms_level.append(ml) ms_idx.append(mi) + ms_start.append(index) + ms_len.append(len(ma)) + ms_height.append(max([ms_height[sub_m] for sub_m in m_list]) + 1) + ms_weight.append(sum(ms_weight[sub_m] for sub_m in m_list)) index = index + len(ma) - ms_height.append(max([-1] + [ms_height[sub_m] for sub_m in m.tolist()]) + 1) - ms_weight.append(sum(ms_weight[sub_m] for sub_m in m.tolist())) # offsets = [0, 0, 4, 8, 12, 16, 20, 24, 28] # for idxmi, mi in enumerate(ms_idx): @@ -126,11 +134,10 @@ class MeaningMap: ) ms_idx = ((ms_idx >> ((8 - ms_level) * 4)) + d).astype(np.uint32) - ms_start = np.array(ms_start).astype(np.uint32) - ms_height = np.array(ms_height).astype(np.uint32) - ms_weight = np.array(ms_weight).astype(np.uint32) - ms_len = np.array(ms_len).astype(np.uint32) - ms_map = map.astype(np.uint32) + ms_start = np.array(ms_start).astype(np.int32) + ms_height = np.array(ms_height).astype(np.int32) + ms_weight = np.array(ms_weight).astype(np.int32) + ms_len = np.array(ms_len).astype(np.int32) slhwm = np.concatenate( ( @@ -138,7 +145,7 @@ class MeaningMap: ms_len.reshape((-1, 1)), ms_height.reshape((-1, 1)), ms_weight.reshape((-1, 1)), - ms_map, + map, ), axis=1, ) @@ -383,7 +390,7 @@ class BatchGroupMeaningDataloader(Dataset): if __name__ == "__main__": - md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, use_cache=False) + md = MeaningDataset(100000, 115200, vocab_size=128, size=1024, use_cache=False) train, val = md.split(0.95) fdaf = md.__getitem__(920) print(md.print_tree(920))