Reduce memory cost when build dataset.
This commit is contained in:
		
							parent
							
								
									43883be692
								
							
						
					
					
						commit
						2d415d9e44
					
				| 
						 | 
					@ -19,32 +19,42 @@ class MeaningMap:
 | 
				
			||||||
        self.vocab_size = vocab_size
 | 
					        self.vocab_size = vocab_size
 | 
				
			||||||
        self.max_subitem = max_subitem
 | 
					        self.max_subitem = max_subitem
 | 
				
			||||||
        self.min_subitem = min_subitem
 | 
					        self.min_subitem = min_subitem
 | 
				
			||||||
 | 
					        datastep = 0x8000000
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        path = "./data/"
 | 
					        path = "./data/"
 | 
				
			||||||
        file = "structured_language_" + str(size) + "_" + str(vocab_size)
 | 
					        file = "structured_language_" + str(size) + "_" + str(vocab_size)
 | 
				
			||||||
        file += "_" + str(max_subitem) + "_" + str(min_subitem)
 | 
					        file += "_" + str(max_subitem) + "_" + str(min_subitem)
 | 
				
			||||||
        file = path + file + ".npz"
 | 
					        file_prop = path + file + "_prop.npy"
 | 
				
			||||||
 | 
					        file_data = path + file + "_data.npy"
 | 
				
			||||||
 | 
					        file_level = path + file + "_level.npy"
 | 
				
			||||||
 | 
					        file_rank_idx = path + file + "_rank_idx.npy"
 | 
				
			||||||
 | 
					        file_rank_all = path + file + "_rank_all.npy"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        start_time = time.time()
 | 
					        start_time = time.time()
 | 
				
			||||||
        if not os.path.exists(path):
 | 
					        if not os.path.exists(path):
 | 
				
			||||||
            os.mkdir(path)
 | 
					            os.mkdir(path)
 | 
				
			||||||
        if os.path.exists(file) and use_cache:
 | 
					        if (
 | 
				
			||||||
 | 
					            os.path.exists(file_prop)
 | 
				
			||||||
 | 
					            and os.path.exists(file_data)
 | 
				
			||||||
 | 
					            and os.path.exists(file_level)
 | 
				
			||||||
 | 
					            and os.path.exists(file_rank_idx)
 | 
				
			||||||
 | 
					            and os.path.exists(file_rank_all)
 | 
				
			||||||
 | 
					            and use_cache
 | 
				
			||||||
 | 
					        ):
 | 
				
			||||||
            print("Load from disk cache: " + file)
 | 
					            print("Load from disk cache: " + file)
 | 
				
			||||||
            loaded = np.load(file)
 | 
					            slhwm = np.load(file_prop)
 | 
				
			||||||
            slhwm = loaded["slhwm"]
 | 
					 | 
				
			||||||
            dlra = loaded["dlra"]
 | 
					 | 
				
			||||||
            self.ms_map = slhwm[:, 4:]
 | 
					            self.ms_map = slhwm[:, 4:]
 | 
				
			||||||
            self.ms_data = dlra[:, 0]
 | 
					            self.ms_data = np.load(file_data)
 | 
				
			||||||
            self.ms_start = slhwm[:, 0]
 | 
					            self.ms_start = slhwm[:, 0]
 | 
				
			||||||
            self.ms_len = slhwm[:, 1]
 | 
					            self.ms_len = slhwm[:, 1]
 | 
				
			||||||
            self.ms_level = dlra[:, 1]
 | 
					            self.ms_level = np.load(file_level)
 | 
				
			||||||
            self.ms_rank_idx = dlra[:, 2].astype(np.uint32)
 | 
					            self.ms_rank_idx = np.load(file_rank_idx)
 | 
				
			||||||
            self.ms_rank_all = dlra[:, 3].astype(np.uint32)
 | 
					            self.ms_rank_all = np.load(file_rank_all)
 | 
				
			||||||
            self.ms_height = slhwm[:, 2]
 | 
					            self.ms_height = slhwm[:, 2]
 | 
				
			||||||
            self.ms_weight = slhwm[:, 3]
 | 
					            self.ms_weight = slhwm[:, 3]
 | 
				
			||||||
            print("Load end, elapsed:" + str(time.time() - start_time) + "s")
 | 
					            print("Load end, elapsed:" + str(time.time() - start_time) + "s")
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            print("Disk cache miss, build new one.")
 | 
					            print("Disk cache miss, build new one. size:" + str(size))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            map = np.empty((size, max_subitem), dtype=np.int32)
 | 
					            map = np.empty((size, max_subitem), dtype=np.int32)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -74,10 +84,10 @@ class MeaningMap:
 | 
				
			||||||
            ms_len = np.zeros((size), dtype=np.int32)  # meaning sequence len
 | 
					            ms_len = np.zeros((size), dtype=np.int32)  # meaning sequence len
 | 
				
			||||||
            ms_height = np.zeros((size), dtype=np.int32)  # meaning tree height
 | 
					            ms_height = np.zeros((size), dtype=np.int32)  # meaning tree height
 | 
				
			||||||
            ms_weight = np.zeros((size), dtype=np.int32)  # meaning tree weight
 | 
					            ms_weight = np.zeros((size), dtype=np.int32)  # meaning tree weight
 | 
				
			||||||
            ms_data = np.zeros((268435456), dtype=np.int32)  # meaning sequence
 | 
					            ms_data = np.zeros((datastep), dtype=np.int32)  # meaning sequence
 | 
				
			||||||
            ms_level = np.zeros((268435456), dtype=np.int32)  # meaning level, vocab's level is 0
 | 
					            ms_level = np.zeros((datastep), dtype=np.uint32)  # meaning level, vocab's level is 0
 | 
				
			||||||
            ms_rank_idx = np.zeros((268435456), dtype=np.uint32)  # meaning index of all level
 | 
					            ms_rank_idx = np.zeros((datastep), dtype=np.uint32)  # meaning index of all level
 | 
				
			||||||
            ms_rank_all = np.zeros((268435456), dtype=np.uint32)  # meaning all of all level
 | 
					            ms_rank_all = np.zeros((datastep), dtype=np.uint32)  # meaning all of all level
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            index = 0
 | 
					            index = 0
 | 
				
			||||||
            for i in range(self.vocab_size):
 | 
					            for i in range(self.vocab_size):
 | 
				
			||||||
| 
						 | 
					@ -107,10 +117,10 @@ class MeaningMap:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                end = index + len_ma
 | 
					                end = index + len_ma
 | 
				
			||||||
                if ms_data.size < end:
 | 
					                if ms_data.size < end:
 | 
				
			||||||
                    ms_data = np.concatenate([ms_data, np.zeros((268435456), dtype=np.int32)])
 | 
					                    ms_data = np.concatenate([ms_data, np.zeros((datastep), dtype=np.int32)])
 | 
				
			||||||
                    ms_level = np.concatenate([ms_level, np.zeros((268435456), dtype=np.int32)])
 | 
					                    ms_level = np.concatenate([ms_level, np.zeros((datastep), dtype=np.uint32)])
 | 
				
			||||||
                    ms_rank_idx = np.concatenate([ms_rank_idx, np.zeros((268435456), dtype=np.uint32)])
 | 
					                    ms_rank_idx = np.concatenate([ms_rank_idx, np.zeros((datastep), dtype=np.uint32)])
 | 
				
			||||||
                    ms_rank_all = np.concatenate([ms_rank_all, np.zeros((268435456), dtype=np.uint32)])
 | 
					                    ms_rank_all = np.concatenate([ms_rank_all, np.zeros((datastep), dtype=np.uint32)])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                ms_data[index:end] = ms_data[idx]
 | 
					                ms_data[index:end] = ms_data[idx]
 | 
				
			||||||
                ms_level[index:end] = ms_level[idx] + 1
 | 
					                ms_level[index:end] = ms_level[idx] + 1
 | 
				
			||||||
| 
						 | 
					@ -128,37 +138,15 @@ class MeaningMap:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            print("Mapping end, elapsed:" + str(time.time() - start_time) + "s")
 | 
					            print("Mapping end, elapsed:" + str(time.time() - start_time) + "s")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            d = np.ones(ms_rank_idx.shape, dtype=np.uint32)
 | 
					            np.save(file_data, ms_data)
 | 
				
			||||||
            d = ((d * 0xFFFFFFFF) << (ms_level * 4)).astype(np.uint32)
 | 
					            np.save(file_level, ms_level)
 | 
				
			||||||
            shift = (8 - ms_level) * 4
 | 
					            np.save(file_rank_idx, ms_rank_idx)
 | 
				
			||||||
            ms_rank_idx = (
 | 
					            np.save(file_rank_all, ms_rank_all)
 | 
				
			||||||
                ((ms_rank_idx & 0xF) << 28)
 | 
					 | 
				
			||||||
                + ((ms_rank_idx & 0xF0) << 20)
 | 
					 | 
				
			||||||
                + ((ms_rank_idx & 0xF00) << 12)
 | 
					 | 
				
			||||||
                + ((ms_rank_idx & 0xF000) << 4)
 | 
					 | 
				
			||||||
                + ((ms_rank_idx & 0xF0000) >> 4)
 | 
					 | 
				
			||||||
                + ((ms_rank_idx & 0xF00000) >> 12)
 | 
					 | 
				
			||||||
                + ((ms_rank_idx & 0xF000000) >> 20)
 | 
					 | 
				
			||||||
                + ((ms_rank_idx & 0xF0000000) >> 28)
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
            ms_rank_idx = ((ms_rank_idx >> shift) + d).astype(np.uint32)
 | 
					 | 
				
			||||||
            ms_rank_all = (
 | 
					 | 
				
			||||||
                ((ms_rank_all & 0xF) << 28)
 | 
					 | 
				
			||||||
                + ((ms_rank_all & 0xF0) << 20)
 | 
					 | 
				
			||||||
                + ((ms_rank_all & 0xF00) << 12)
 | 
					 | 
				
			||||||
                + ((ms_rank_all & 0xF000) << 4)
 | 
					 | 
				
			||||||
                + ((ms_rank_all & 0xF0000) >> 4)
 | 
					 | 
				
			||||||
                + ((ms_rank_all & 0xF00000) >> 12)
 | 
					 | 
				
			||||||
                + ((ms_rank_all & 0xF000000) >> 20)
 | 
					 | 
				
			||||||
                + ((ms_rank_all & 0xF0000000) >> 28)
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
            ms_rank_all = ((ms_rank_all >> shift) + d).astype(np.uint32)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
            ms_start = np.array(ms_start).astype(np.int32)
 | 
					            ms_start = np.array(ms_start).astype(np.int32)
 | 
				
			||||||
            ms_height = np.array(ms_height).astype(np.int32)
 | 
					            ms_height = np.array(ms_height).astype(np.int32)
 | 
				
			||||||
            ms_weight = np.array(ms_weight).astype(np.int32)
 | 
					            ms_weight = np.array(ms_weight).astype(np.int32)
 | 
				
			||||||
            ms_len = np.array(ms_len).astype(np.int32)
 | 
					            ms_len = np.array(ms_len).astype(np.int32)
 | 
				
			||||||
 | 
					 | 
				
			||||||
            slhwm = np.concatenate(
 | 
					            slhwm = np.concatenate(
 | 
				
			||||||
                (
 | 
					                (
 | 
				
			||||||
                    ms_start.reshape((-1, 1)),
 | 
					                    ms_start.reshape((-1, 1)),
 | 
				
			||||||
| 
						 | 
					@ -169,8 +157,7 @@ class MeaningMap:
 | 
				
			||||||
                ),
 | 
					                ),
 | 
				
			||||||
                axis=1,
 | 
					                axis=1,
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
            dlra = np.stack((ms_data, ms_level, ms_rank_idx.astype(np.int32), ms_rank_all.astype(np.int32)), axis=1)
 | 
					            np.save(file_prop, slhwm)
 | 
				
			||||||
            np.savez(file, slhwm=slhwm, dlra=dlra)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
            self.ms_data = ms_data  # map[i]=ms_data[ms_start[i]:ms_start[i]+ms_len[i]]
 | 
					            self.ms_data = ms_data  # map[i]=ms_data[ms_start[i]:ms_start[i]+ms_len[i]]
 | 
				
			||||||
            self.ms_level = ms_level
 | 
					            self.ms_level = ms_level
 | 
				
			||||||
| 
						 | 
					@ -300,11 +287,34 @@ class MeaningDataset(Dataset):
 | 
				
			||||||
                self.tree.append({m: map.get_tree(m)})
 | 
					                self.tree.append({m: map.get_tree(m)})
 | 
				
			||||||
                self.seq.append(d)
 | 
					                self.seq.append(d)
 | 
				
			||||||
                self.level.append(l)
 | 
					                self.level.append(l)
 | 
				
			||||||
                self.rank_idx.append(i)
 | 
					 | 
				
			||||||
                self.rank_all.append(a)
 | 
					 | 
				
			||||||
                self.seq_meaning.append(m)
 | 
					                self.seq_meaning.append(m)
 | 
				
			||||||
                seq_len.append(len(d))
 | 
					                seq_len.append(len(d))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                dm = np.ones(i.shape, dtype=np.uint32)
 | 
				
			||||||
 | 
					                dm = ((dm * 0xFFFFFFFF) << (l * 4)).astype(np.uint32)
 | 
				
			||||||
 | 
					                shift = (8 - l) * 4
 | 
				
			||||||
 | 
					                rank_idx = (i & 0xF) << 28
 | 
				
			||||||
 | 
					                rank_idx = rank_idx + ((i & 0xF0) << 20)
 | 
				
			||||||
 | 
					                rank_idx = rank_idx + ((i & 0xF00) << 12)
 | 
				
			||||||
 | 
					                rank_idx = rank_idx + ((i & 0xF000) << 4)
 | 
				
			||||||
 | 
					                rank_idx = rank_idx + ((i & 0xF0000) >> 4)
 | 
				
			||||||
 | 
					                rank_idx = rank_idx + ((i & 0xF00000) >> 12)
 | 
				
			||||||
 | 
					                rank_idx = rank_idx + ((i & 0xF000000) >> 20)
 | 
				
			||||||
 | 
					                rank_idx = rank_idx + ((i & 0xF0000000) >> 28)
 | 
				
			||||||
 | 
					                rank_idx = ((rank_idx >> shift) + dm).astype(np.uint32)
 | 
				
			||||||
 | 
					                rank_all = (a & 0xF) << 28
 | 
				
			||||||
 | 
					                rank_all = rank_all + ((a & 0xF0) << 20)
 | 
				
			||||||
 | 
					                rank_all = rank_all + ((a & 0xF00) << 12)
 | 
				
			||||||
 | 
					                rank_all = rank_all + ((a & 0xF000) << 4)
 | 
				
			||||||
 | 
					                rank_all = rank_all + ((a & 0xF0000) >> 4)
 | 
				
			||||||
 | 
					                rank_all = rank_all + ((a & 0xF00000) >> 12)
 | 
				
			||||||
 | 
					                rank_all = rank_all + ((a & 0xF000000) >> 20)
 | 
				
			||||||
 | 
					                rank_all = rank_all + ((a & 0xF0000000) >> 28)
 | 
				
			||||||
 | 
					                rank_all = ((rank_all >> shift) + dm).astype(np.uint32)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                self.rank_idx.append(rank_idx)
 | 
				
			||||||
 | 
					                self.rank_all.append(rank_all)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        unique, counts = np.unique(seq_len, return_counts=True)
 | 
					        unique, counts = np.unique(seq_len, return_counts=True)
 | 
				
			||||||
        print("----------------------------------------------------------------")
 | 
					        print("----------------------------------------------------------------")
 | 
				
			||||||
        print("MeaningDataset start:" + str(start) + " end:" + str(end) + " space:" + str(end - start))
 | 
					        print("MeaningDataset start:" + str(start) + " end:" + str(end) + " space:" + str(end - start))
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue