From b062cc9c948843921439283edac49924a2d805a9 Mon Sep 17 00:00:00 2001
From: Colin
Date: Fri, 19 Apr 2024 00:50:40 +0800
Subject: [PATCH] Save memory when building the meaning dataset with np arrays.

---
 wit/meaning_dataset.py | 86 +++++++++++++++++++++++++-----------------
 1 file changed, 51 insertions(+), 35 deletions(-)

diff --git a/wit/meaning_dataset.py b/wit/meaning_dataset.py
index b60be4c..c79f408 100644
--- a/wit/meaning_dataset.py
+++ b/wit/meaning_dataset.py
@@ -7,6 +7,9 @@ from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
 import numpy as np
 from torch.utils.data import BatchSampler
 
+# import warnings
+# warnings.filterwarnings("ignore", ".*does not have many workers.*")
+
 
 class MeaningMap:
     def __init__(self, size=1048576, vocab_size=4096, max_subitem=10, min_subitem=1, use_cache=True):
@@ -63,63 +66,76 @@ class MeaningMap:
             map[:vocab_size, 0] = np.arange(0, vocab_size)
             map[:vocab_size, 1:] = -1
 
-            ms_data = []  # meaning sequence
             ms_level = []  # meaning level, vocab's level is 0
             ms_rank_idx = []  # meaning index of all level
             ms_rank_all = []  # meaning all of all level
-            ms_start = []  # meaning sequence start
-            ms_len = []  # meaning sequence length
-            ms_height = []  # meaning tree height
-            ms_weight = []  # meaning tree weight
+            ms_start = np.zeros((size), dtype=np.int32)  # meaning sequence start
+            ms_end = np.zeros((size), dtype=np.int32)  # meaning sequence end
+            ms_len = np.zeros((size), dtype=np.int32)  # meaning sequence len
+            ms_height = np.zeros((size), dtype=np.int32)  # meaning tree height
+            ms_weight = np.zeros((size), dtype=np.int32)  # meaning tree weight
+            ms_data = np.zeros((268435456), dtype=np.int32)  # meaning sequence
+            ms_level = np.zeros((268435456), dtype=np.int32)  # meaning level, vocab's level is 0
+            ms_rank_idx = np.zeros((268435456), dtype=np.uint32)  # meaning index of all level
+            ms_rank_all = np.zeros((268435456), dtype=np.uint32)  # meaning all of all level
+
             index = 0
             for i in range(self.vocab_size):
-                ms_data.append(np.array([i], dtype=np.int32))
-                ms_level.append(np.array([0], dtype=np.int32))
-                ms_rank_idx.append(np.array([0], dtype=np.uint32))
-                ms_rank_all.append(np.array([0], dtype=np.uint32))
-                ms_start.append(index)
-                ms_len.append(1)
-                ms_height.append(0)
-                ms_weight.append(1)
+                ms_data[i] = i
+                ms_start[i] = index
+                ms_end[i] = index + 1
+                ms_len[i] = 1
+                ms_height[i] = 0
+                ms_weight[i] = 1
                 index = index + 1
 
             for i in range(self.vocab_size, size):
                 m = map[i]
                 m = m[m >= 0]  # donot cut off the map such as [0]
-
+                m_len = len(m)
                 m_list = m.tolist()
-                m_len = len(m_list)
                 assert m_list, "map list can not be empty list"
+                ma = np.concatenate([ms_data[ms_start[newm] : ms_end[newm]] for newm in m_list])
+                len_ma = len(ma)
+                end = index + len_ma
+                if ms_data.size < end:
+                    ms_data = np.concatenate([ms_data, np.zeros((268435456), dtype=np.int32)])
+                    ms_level = np.concatenate([ms_level, np.zeros((268435456), dtype=np.int32)])
+                    ms_rank_idx = np.concatenate([ms_rank_idx, np.zeros((268435456), dtype=np.uint32)])
+                    ms_rank_all = np.concatenate([ms_rank_all, np.zeros((268435456), dtype=np.uint32)])
 
-                ma = np.concatenate([ms_data[newm] for newm in m_list])
-                ml = np.concatenate([ms_level[newm] + 1 for newm in m_list])
-                mr = np.concatenate(
+                ms_data[index:end] = ma
+                ms_level[index:end] = np.concatenate([ms_level[ms_start[newm] : ms_end[newm]] + 1 for newm in m_list])
+                ms_rank_idx[index:end] = np.concatenate(
                     [
-                        ([0xFFFFFFF0 + i] if newm < self.vocab_size else ms_rank_idx[newm] * 16 + i)
+                        (
+                            [0xFFFFFFF0 + i]
+                            if newm < self.vocab_size
+                            else ms_rank_idx[ms_start[newm] : ms_end[newm]] * 16 + i
+                        )
                         for i, newm in enumerate(m_list)
                     ]
                 ).astype(np.uint32)
-                mrl = np.concatenate(
+                ms_rank_all[index:end] = np.concatenate(
                     [
-                        ([0xFFFFFFF0 + m_len] if newm < self.vocab_size else ms_rank_all[newm] * 16 + m_len)
+                        (
+                            [0xFFFFFFF0 + m_len]
+                            if newm < self.vocab_size
+                            else ms_rank_all[ms_start[newm] : ms_end[newm]] * 16 + m_len
+                        )
                         for i, newm in enumerate(m_list)
                     ]
                 ).astype(np.uint32)
-                ms_data.append(ma)
-                ms_level.append(ml)
-                ms_rank_idx.append(mr)
-                ms_rank_all.append(mrl)
-                ms_start.append(index)
-                ms_len.append(len(ma))
-                ms_height.append(max([ms_height[sub_m] for sub_m in m_list]) + 1)
-                ms_weight.append(sum(ms_weight[sub_m] for sub_m in m_list))
-                index = index + len(ma)
+                ms_start[i] = index
+                ms_end[i] = end
+                ms_len[i] = len_ma
+                ms_height[i] = max([ms_height[sub_m] for sub_m in m_list]) + 1
+                ms_weight[i] = sum(ms_weight[sub_m] for sub_m in m_list)
+                index = index + len_ma
+                if i % 10000 == 0:
+                    print(i)
 
             print("Mapping end, elapsed:" + str(time.time() - start_time) + "s")
-            ms_data = np.array(list(chain(*ms_data))).astype(np.int32)
-            ms_level = np.array(list(chain(*ms_level))).astype(np.int32)
-            ms_rank_idx = np.array(list(chain(*ms_rank_idx))).astype(np.uint32)
-            ms_rank_all = np.array(list(chain(*ms_rank_all))).astype(np.uint32)
 
             d = np.ones(ms_rank_idx.shape, dtype=np.uint32)
             d = ((d * 0xFFFFFFFF) << (ms_level * 4)).astype(np.uint32)
@@ -458,7 +474,7 @@ class BatchGroupMeaningDataloader(Dataset):
 
 
 if __name__ == "__main__":
-    md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, use_cache=True)
+    md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, use_cache=False)
     md.set_mask([1], [-1])
     train, val = md.split(0.95)
     fdaf = md.__getitem__(920)
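
The core of the change above is a grow-on-demand flat buffer: instead of keeping one small np.array per meaning in a Python list and flattening everything with chain/np.concatenate at the end, sequences are written back to back into one preallocated int32 array, per-meaning start/end offsets replace the list of arrays, and the buffer is extended by a fixed chunk only when the next write would overflow it. The sketch below shows that pattern in isolation; it is a minimal, hypothetical example (FlatSequenceStore, CHUNK and num_items are illustrative names, not part of wit/meaning_dataset.py), and the chunk size is deliberately smaller than the 268435456-entry blocks used in the patch.

import numpy as np

CHUNK = 1 << 20  # growth step; the patch grows by 268435456 (2**28) entries at a time


class FlatSequenceStore:
    """Store variable-length int sequences back to back in one shared buffer."""

    def __init__(self, num_items):
        self.data = np.zeros(CHUNK, dtype=np.int32)        # shared flat buffer
        self.start = np.zeros(num_items, dtype=np.int64)   # per-item start offset
        self.end = np.zeros(num_items, dtype=np.int64)     # per-item end offset
        self.index = 0                                      # next free position

    def append(self, item_id, seq):
        seq = np.asarray(seq, dtype=np.int32)
        end = self.index + len(seq)
        # Grow by whole chunks instead of reallocating on every append.
        while self.data.size < end:
            self.data = np.concatenate([self.data, np.zeros(CHUNK, dtype=np.int32)])
        self.data[self.index:end] = seq
        self.start[item_id] = self.index
        self.end[item_id] = end
        self.index = end

    def get(self, item_id):
        # A view into the shared buffer, no copy.
        return self.data[self.start[item_id]:self.end[item_id]]


if __name__ == "__main__":
    store = FlatSequenceStore(num_items=2)
    store.append(0, [1, 2, 3])
    # New items can be composed from earlier ones via their offsets, as the loop in the patch does.
    store.append(1, np.concatenate([store.get(0), [4]]))
    print(store.get(1))  # [1 2 3 4]

Growing in large fixed chunks bounds the number of reallocations by the final length divided by the chunk size, and it avoids both the per-item temporary arrays and the final list(chain(*ms_data)) pass of the old code, which materialized every element as a Python object before converting back to an array.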