diff --git a/wit/meaning_dataset.py b/wit/meaning_dataset.py index c79f408..65a85db 100644 --- a/wit/meaning_dataset.py +++ b/wit/meaning_dataset.py @@ -85,6 +85,9 @@ class MeaningMap: ms_start[i] = index ms_end[i] = index + 1 ms_len[i] = 1 + ms_level[i] = 0 + ms_rank_idx[i] = 0xFFFFFFF + ms_rank_all[i] = 0xFFFFFFF ms_height[i] = 0 ms_weight[i] = 1 index = index + 1 @@ -95,8 +98,13 @@ class MeaningMap: m_len = len(m) m_list = m.tolist() assert m_list, "map list can not be empty list" - ma = np.concatenate([ms_data[ms_start[newm] : ms_end[newm]] for newm in m_list]) - len_ma = len(ma) + + idx = np.concatenate([np.arange(ms_start[m], ms_end[m]) for m in m_list]) + idxidx = np.concatenate( + [np.ones(l, dtype=np.uint32) * i for i, l in enumerate(ms_end[m_list] - ms_start[m_list])] + ) + len_ma = len(idx) + end = index + len_ma if ms_data.size < end: ms_data = np.concatenate([ms_data, np.zeros((268435456), dtype=np.int32)]) @@ -104,33 +112,16 @@ class MeaningMap: ms_rank_idx = np.concatenate([ms_rank_idx, np.zeros((268435456), dtype=np.uint32)]) ms_rank_all = np.concatenate([ms_rank_all, np.zeros((268435456), dtype=np.uint32)]) - ms_data[index:end] = ma - ms_level[index:end] = np.concatenate([ms_level[ms_start[newm] : ms_end[newm]] + 1 for newm in m_list]) - ms_rank_idx[index:end] = np.concatenate( - [ - ( - [0xFFFFFFF0 + i] - if newm < self.vocab_size - else ms_rank_idx[ms_start[newm] : ms_end[newm]] * 16 + i - ) - for i, newm in enumerate(m_list) - ] - ).astype(np.uint32) - ms_rank_all[index:end] = np.concatenate( - [ - ( - [0xFFFFFFF0 + m_len] - if newm < self.vocab_size - else ms_rank_all[ms_start[newm] : ms_end[newm]] * 16 + m_len - ) - for i, newm in enumerate(m_list) - ] - ).astype(np.uint32) + ms_data[index:end] = ms_data[idx] + ms_level[index:end] = ms_level[idx] + 1 + ms_rank_idx[index:end] = (ms_rank_idx[idx] * 16 + idxidx).astype(np.uint32) + ms_rank_all[index:end] = (ms_rank_all[idx] * 16 + m_len).astype(np.uint32) + ms_start[i] = index ms_end[i] = end ms_len[i] = len_ma - ms_height[i] = max([ms_height[sub_m] for sub_m in m_list]) + 1 - ms_weight[i] = sum(ms_weight[sub_m] for sub_m in m_list) + ms_height[i] = max(ms_height[m_list]) + 1 + ms_weight[i] = sum(ms_weight[m_list]) index = index + len_ma if i % 10000 == 0: print(i) @@ -139,6 +130,7 @@ class MeaningMap: d = np.ones(ms_rank_idx.shape, dtype=np.uint32) d = ((d * 0xFFFFFFFF) << (ms_level * 4)).astype(np.uint32) + shift = (8 - ms_level) * 4 ms_rank_idx = ( ((ms_rank_idx & 0xF) << 28) + ((ms_rank_idx & 0xF0) << 20) @@ -149,7 +141,7 @@ class MeaningMap: + ((ms_rank_idx & 0xF000000) >> 20) + ((ms_rank_idx & 0xF0000000) >> 28) ) - ms_rank_idx = ((ms_rank_idx >> ((8 - ms_level) * 4)) + d).astype(np.uint32) + ms_rank_idx = ((ms_rank_idx >> shift) + d).astype(np.uint32) ms_rank_all = ( ((ms_rank_all & 0xF) << 28) + ((ms_rank_all & 0xF0) << 20) @@ -160,7 +152,7 @@ class MeaningMap: + ((ms_rank_all & 0xF000000) >> 20) + ((ms_rank_all & 0xF0000000) >> 28) ) - ms_rank_all = ((ms_rank_all >> ((8 - ms_level) * 4)) + d).astype(np.uint32) + ms_rank_all = ((ms_rank_all >> shift) + d).astype(np.uint32) ms_start = np.array(ms_start).astype(np.int32) ms_height = np.array(ms_height).astype(np.int32)