Reduce memory usage by building the meaning dataset with preallocated NumPy arrays.
This commit is contained in:
		
							parent
							
								
									a524d01ac3
								
							
						
					
					
						commit
						b062cc9c94
					
				|  | @ -7,6 +7,9 @@ from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split | |||
| import numpy as np | ||||
| from torch.utils.data import BatchSampler | ||||
| 
 | ||||
| # import warnings | ||||
| # warnings.filterwarnings("ignore", ".*does not have many workers.*") | ||||
| 
 | ||||
| 
 | ||||
| class MeaningMap: | ||||
|     def __init__(self, size=1048576, vocab_size=4096, max_subitem=10, min_subitem=1, use_cache=True): | ||||
|  | @ -63,63 +66,76 @@ class MeaningMap: | |||
|             map[:vocab_size, 0] = np.arange(0, vocab_size) | ||||
|             map[:vocab_size, 1:] = -1 | ||||
| 
 | ||||
|             ms_data = []  # meaning sequence | ||||
|             ms_level = []  # meaning level, vocab's level is 0 | ||||
|             ms_rank_idx = []  # meaning index of all level | ||||
|             ms_rank_all = []  # meaning all of all level | ||||
|             ms_start = []  # meaning sequence start | ||||
|             ms_len = []  # meaning sequence length | ||||
|             ms_height = []  # meaning tree height | ||||
|             ms_weight = []  # meaning tree weight | ||||
|             ms_start = np.zeros((size), dtype=np.int32)  # meaning sequence start | ||||
|             ms_end = np.zeros((size), dtype=np.int32)  # meaning sequence end | ||||
|             ms_len = np.zeros((size), dtype=np.int32)  # meaning sequence len | ||||
|             ms_height = np.zeros((size), dtype=np.int32)  # meaning tree height | ||||
|             ms_weight = np.zeros((size), dtype=np.int32)  # meaning tree weight | ||||
|             ms_data = np.zeros((268435456), dtype=np.int32)  # meaning sequence | ||||
|             ms_level = np.zeros((268435456), dtype=np.int32)  # meaning level, vocab's level is 0 | ||||
|             ms_rank_idx = np.zeros((268435456), dtype=np.uint32)  # meaning index of all level | ||||
|             ms_rank_all = np.zeros((268435456), dtype=np.uint32)  # meaning all of all level | ||||
| 
 | ||||
|             index = 0 | ||||
|             for i in range(self.vocab_size): | ||||
|                 ms_data.append(np.array([i], dtype=np.int32)) | ||||
|                 ms_level.append(np.array([0], dtype=np.int32)) | ||||
|                 ms_rank_idx.append(np.array([0], dtype=np.uint32)) | ||||
|                 ms_rank_all.append(np.array([0], dtype=np.uint32)) | ||||
|                 ms_start.append(index) | ||||
|                 ms_len.append(1) | ||||
|                 ms_height.append(0) | ||||
|                 ms_weight.append(1) | ||||
|                 ms_data[i] = i | ||||
|                 ms_start[i] = index | ||||
|                 ms_end[i] = index + 1 | ||||
|                 ms_len[i] = 1 | ||||
|                 ms_height[i] = 0 | ||||
|                 ms_weight[i] = 1 | ||||
|                 index = index + 1 | ||||
| 
 | ||||
|             for i in range(self.vocab_size, size): | ||||
|                 m = map[i] | ||||
|                 m = m[m >= 0]  # donot cut off the map such as [0] | ||||
| 
 | ||||
|                 m_len = len(m) | ||||
|                 m_list = m.tolist() | ||||
|                 m_len = len(m_list) | ||||
|                 assert m_list, "map list can not be empty list" | ||||
|                 ma = np.concatenate([ms_data[ms_start[newm] : ms_end[newm]] for newm in m_list]) | ||||
|                 len_ma = len(ma) | ||||
|                 end = index + len_ma | ||||
|                 if ms_data.size < end: | ||||
|                     ms_data = np.concatenate([ms_data, np.zeros((268435456), dtype=np.int32)]) | ||||
|                     ms_level = np.concatenate([ms_level, np.zeros((268435456), dtype=np.int32)]) | ||||
|                     ms_rank_idx = np.concatenate([ms_rank_idx, np.zeros((268435456), dtype=np.uint32)]) | ||||
|                     ms_rank_all = np.concatenate([ms_rank_all, np.zeros((268435456), dtype=np.uint32)]) | ||||
| 
 | ||||
|                 ma = np.concatenate([ms_data[newm] for newm in m_list]) | ||||
|                 ml = np.concatenate([ms_level[newm] + 1 for newm in m_list]) | ||||
|                 mr = np.concatenate( | ||||
|                 ms_data[index:end] = ma | ||||
|                 ms_level[index:end] = np.concatenate([ms_level[ms_start[newm] : ms_end[newm]] + 1 for newm in m_list]) | ||||
|                 ms_rank_idx[index:end] = np.concatenate( | ||||
|                     [ | ||||
|                         ([0xFFFFFFF0 + i] if newm < self.vocab_size else ms_rank_idx[newm] * 16 + i) | ||||
|                         ( | ||||
|                             [0xFFFFFFF0 + i] | ||||
|                             if newm < self.vocab_size | ||||
|                             else ms_rank_idx[ms_start[newm] : ms_end[newm]] * 16 + i | ||||
|                         ) | ||||
|                         for i, newm in enumerate(m_list) | ||||
|                     ] | ||||
|                 ).astype(np.uint32) | ||||
|                 mrl = np.concatenate( | ||||
|                 ms_rank_all[index:end] = np.concatenate( | ||||
|                     [ | ||||
|                         ([0xFFFFFFF0 + m_len] if newm < self.vocab_size else ms_rank_all[newm] * 16 + m_len) | ||||
|                         ( | ||||
|                             [0xFFFFFFF0 + m_len] | ||||
|                             if newm < self.vocab_size | ||||
|                             else ms_rank_all[ms_start[newm] : ms_end[newm]] * 16 + m_len | ||||
|                         ) | ||||
|                         for i, newm in enumerate(m_list) | ||||
|                     ] | ||||
|                 ).astype(np.uint32) | ||||
|                 ms_data.append(ma) | ||||
|                 ms_level.append(ml) | ||||
|                 ms_rank_idx.append(mr) | ||||
|                 ms_rank_all.append(mrl) | ||||
|                 ms_start.append(index) | ||||
|                 ms_len.append(len(ma)) | ||||
|                 ms_height.append(max([ms_height[sub_m] for sub_m in m_list]) + 1) | ||||
|                 ms_weight.append(sum(ms_weight[sub_m] for sub_m in m_list)) | ||||
|                 index = index + len(ma) | ||||
|                 ms_start[i] = index | ||||
|                 ms_end[i] = end | ||||
|                 ms_len[i] = len_ma | ||||
|                 ms_height[i] = max([ms_height[sub_m] for sub_m in m_list]) + 1 | ||||
|                 ms_weight[i] = sum(ms_weight[sub_m] for sub_m in m_list) | ||||
|                 index = index + len_ma | ||||
|                 if i % 10000 == 0: | ||||
|                     print(i) | ||||
| 
 | ||||
|             print("Mapping end, elapsed:" + str(time.time() - start_time) + "s") | ||||
|             ms_data = np.array(list(chain(*ms_data))).astype(np.int32) | ||||
|             ms_level = np.array(list(chain(*ms_level))).astype(np.int32) | ||||
|             ms_rank_idx = np.array(list(chain(*ms_rank_idx))).astype(np.uint32) | ||||
|             ms_rank_all = np.array(list(chain(*ms_rank_all))).astype(np.uint32) | ||||
| 
 | ||||
|             d = np.ones(ms_rank_idx.shape, dtype=np.uint32) | ||||
|             d = ((d * 0xFFFFFFFF) << (ms_level * 4)).astype(np.uint32) | ||||
|  | @ -458,7 +474,7 @@ class BatchGroupMeaningDataloader(Dataset): | |||
| 
 | ||||
| if __name__ == "__main__": | ||||
| 
 | ||||
|     md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, use_cache=True) | ||||
|     md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, use_cache=False) | ||||
|     md.set_mask([1], [-1]) | ||||
|     train, val = md.split(0.95) | ||||
|     fdaf = md.__getitem__(920) | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue