Reduce memory cost when build dataset.
This commit is contained in:
parent
43883be692
commit
2d415d9e44
|
@ -19,32 +19,42 @@ class MeaningMap:
|
|||
self.vocab_size = vocab_size
|
||||
self.max_subitem = max_subitem
|
||||
self.min_subitem = min_subitem
|
||||
datastep = 0x8000000
|
||||
|
||||
path = "./data/"
|
||||
file = "structured_language_" + str(size) + "_" + str(vocab_size)
|
||||
file += "_" + str(max_subitem) + "_" + str(min_subitem)
|
||||
file = path + file + ".npz"
|
||||
file_prop = path + file + "_prop.npy"
|
||||
file_data = path + file + "_data.npy"
|
||||
file_level = path + file + "_level.npy"
|
||||
file_rank_idx = path + file + "_rank_idx.npy"
|
||||
file_rank_all = path + file + "_rank_all.npy"
|
||||
|
||||
start_time = time.time()
|
||||
if not os.path.exists(path):
|
||||
os.mkdir(path)
|
||||
if os.path.exists(file) and use_cache:
|
||||
if (
|
||||
os.path.exists(file_prop)
|
||||
and os.path.exists(file_data)
|
||||
and os.path.exists(file_level)
|
||||
and os.path.exists(file_rank_idx)
|
||||
and os.path.exists(file_rank_all)
|
||||
and use_cache
|
||||
):
|
||||
print("Load from disk cache: " + file)
|
||||
loaded = np.load(file)
|
||||
slhwm = loaded["slhwm"]
|
||||
dlra = loaded["dlra"]
|
||||
slhwm = np.load(file_prop)
|
||||
self.ms_map = slhwm[:, 4:]
|
||||
self.ms_data = dlra[:, 0]
|
||||
self.ms_data = np.load(file_data)
|
||||
self.ms_start = slhwm[:, 0]
|
||||
self.ms_len = slhwm[:, 1]
|
||||
self.ms_level = dlra[:, 1]
|
||||
self.ms_rank_idx = dlra[:, 2].astype(np.uint32)
|
||||
self.ms_rank_all = dlra[:, 3].astype(np.uint32)
|
||||
self.ms_level = np.load(file_level)
|
||||
self.ms_rank_idx = np.load(file_rank_idx)
|
||||
self.ms_rank_all = np.load(file_rank_all)
|
||||
self.ms_height = slhwm[:, 2]
|
||||
self.ms_weight = slhwm[:, 3]
|
||||
print("Load end, elapsed:" + str(time.time() - start_time) + "s")
|
||||
else:
|
||||
print("Disk cache miss, build new one.")
|
||||
print("Disk cache miss, build new one. size:" + str(size))
|
||||
|
||||
map = np.empty((size, max_subitem), dtype=np.int32)
|
||||
|
||||
|
@ -74,10 +84,10 @@ class MeaningMap:
|
|||
ms_len = np.zeros((size), dtype=np.int32) # meaning sequence len
|
||||
ms_height = np.zeros((size), dtype=np.int32) # meaning tree height
|
||||
ms_weight = np.zeros((size), dtype=np.int32) # meaning tree weight
|
||||
ms_data = np.zeros((268435456), dtype=np.int32) # meaning sequence
|
||||
ms_level = np.zeros((268435456), dtype=np.int32) # meaning level, vocab's level is 0
|
||||
ms_rank_idx = np.zeros((268435456), dtype=np.uint32) # meaning index of all level
|
||||
ms_rank_all = np.zeros((268435456), dtype=np.uint32) # meaning all of all level
|
||||
ms_data = np.zeros((datastep), dtype=np.int32) # meaning sequence
|
||||
ms_level = np.zeros((datastep), dtype=np.uint32) # meaning level, vocab's level is 0
|
||||
ms_rank_idx = np.zeros((datastep), dtype=np.uint32) # meaning index of all level
|
||||
ms_rank_all = np.zeros((datastep), dtype=np.uint32) # meaning all of all level
|
||||
|
||||
index = 0
|
||||
for i in range(self.vocab_size):
|
||||
|
@ -107,10 +117,10 @@ class MeaningMap:
|
|||
|
||||
end = index + len_ma
|
||||
if ms_data.size < end:
|
||||
ms_data = np.concatenate([ms_data, np.zeros((268435456), dtype=np.int32)])
|
||||
ms_level = np.concatenate([ms_level, np.zeros((268435456), dtype=np.int32)])
|
||||
ms_rank_idx = np.concatenate([ms_rank_idx, np.zeros((268435456), dtype=np.uint32)])
|
||||
ms_rank_all = np.concatenate([ms_rank_all, np.zeros((268435456), dtype=np.uint32)])
|
||||
ms_data = np.concatenate([ms_data, np.zeros((datastep), dtype=np.int32)])
|
||||
ms_level = np.concatenate([ms_level, np.zeros((datastep), dtype=np.uint32)])
|
||||
ms_rank_idx = np.concatenate([ms_rank_idx, np.zeros((datastep), dtype=np.uint32)])
|
||||
ms_rank_all = np.concatenate([ms_rank_all, np.zeros((datastep), dtype=np.uint32)])
|
||||
|
||||
ms_data[index:end] = ms_data[idx]
|
||||
ms_level[index:end] = ms_level[idx] + 1
|
||||
|
@ -128,37 +138,15 @@ class MeaningMap:
|
|||
|
||||
print("Mapping end, elapsed:" + str(time.time() - start_time) + "s")
|
||||
|
||||
d = np.ones(ms_rank_idx.shape, dtype=np.uint32)
|
||||
d = ((d * 0xFFFFFFFF) << (ms_level * 4)).astype(np.uint32)
|
||||
shift = (8 - ms_level) * 4
|
||||
ms_rank_idx = (
|
||||
((ms_rank_idx & 0xF) << 28)
|
||||
+ ((ms_rank_idx & 0xF0) << 20)
|
||||
+ ((ms_rank_idx & 0xF00) << 12)
|
||||
+ ((ms_rank_idx & 0xF000) << 4)
|
||||
+ ((ms_rank_idx & 0xF0000) >> 4)
|
||||
+ ((ms_rank_idx & 0xF00000) >> 12)
|
||||
+ ((ms_rank_idx & 0xF000000) >> 20)
|
||||
+ ((ms_rank_idx & 0xF0000000) >> 28)
|
||||
)
|
||||
ms_rank_idx = ((ms_rank_idx >> shift) + d).astype(np.uint32)
|
||||
ms_rank_all = (
|
||||
((ms_rank_all & 0xF) << 28)
|
||||
+ ((ms_rank_all & 0xF0) << 20)
|
||||
+ ((ms_rank_all & 0xF00) << 12)
|
||||
+ ((ms_rank_all & 0xF000) << 4)
|
||||
+ ((ms_rank_all & 0xF0000) >> 4)
|
||||
+ ((ms_rank_all & 0xF00000) >> 12)
|
||||
+ ((ms_rank_all & 0xF000000) >> 20)
|
||||
+ ((ms_rank_all & 0xF0000000) >> 28)
|
||||
)
|
||||
ms_rank_all = ((ms_rank_all >> shift) + d).astype(np.uint32)
|
||||
np.save(file_data, ms_data)
|
||||
np.save(file_level, ms_level)
|
||||
np.save(file_rank_idx, ms_rank_idx)
|
||||
np.save(file_rank_all, ms_rank_all)
|
||||
|
||||
ms_start = np.array(ms_start).astype(np.int32)
|
||||
ms_height = np.array(ms_height).astype(np.int32)
|
||||
ms_weight = np.array(ms_weight).astype(np.int32)
|
||||
ms_len = np.array(ms_len).astype(np.int32)
|
||||
|
||||
slhwm = np.concatenate(
|
||||
(
|
||||
ms_start.reshape((-1, 1)),
|
||||
|
@ -169,8 +157,7 @@ class MeaningMap:
|
|||
),
|
||||
axis=1,
|
||||
)
|
||||
dlra = np.stack((ms_data, ms_level, ms_rank_idx.astype(np.int32), ms_rank_all.astype(np.int32)), axis=1)
|
||||
np.savez(file, slhwm=slhwm, dlra=dlra)
|
||||
np.save(file_prop, slhwm)
|
||||
|
||||
self.ms_data = ms_data # map[i]=ms_data[ms_start[i]:ms_start[i]+ms_len[i]]
|
||||
self.ms_level = ms_level
|
||||
|
@ -300,11 +287,34 @@ class MeaningDataset(Dataset):
|
|||
self.tree.append({m: map.get_tree(m)})
|
||||
self.seq.append(d)
|
||||
self.level.append(l)
|
||||
self.rank_idx.append(i)
|
||||
self.rank_all.append(a)
|
||||
self.seq_meaning.append(m)
|
||||
seq_len.append(len(d))
|
||||
|
||||
dm = np.ones(i.shape, dtype=np.uint32)
|
||||
dm = ((dm * 0xFFFFFFFF) << (l * 4)).astype(np.uint32)
|
||||
shift = (8 - l) * 4
|
||||
rank_idx = (i & 0xF) << 28
|
||||
rank_idx = rank_idx + ((i & 0xF0) << 20)
|
||||
rank_idx = rank_idx + ((i & 0xF00) << 12)
|
||||
rank_idx = rank_idx + ((i & 0xF000) << 4)
|
||||
rank_idx = rank_idx + ((i & 0xF0000) >> 4)
|
||||
rank_idx = rank_idx + ((i & 0xF00000) >> 12)
|
||||
rank_idx = rank_idx + ((i & 0xF000000) >> 20)
|
||||
rank_idx = rank_idx + ((i & 0xF0000000) >> 28)
|
||||
rank_idx = ((rank_idx >> shift) + dm).astype(np.uint32)
|
||||
rank_all = (a & 0xF) << 28
|
||||
rank_all = rank_all + ((a & 0xF0) << 20)
|
||||
rank_all = rank_all + ((a & 0xF00) << 12)
|
||||
rank_all = rank_all + ((a & 0xF000) << 4)
|
||||
rank_all = rank_all + ((a & 0xF0000) >> 4)
|
||||
rank_all = rank_all + ((a & 0xF00000) >> 12)
|
||||
rank_all = rank_all + ((a & 0xF000000) >> 20)
|
||||
rank_all = rank_all + ((a & 0xF0000000) >> 28)
|
||||
rank_all = ((rank_all >> shift) + dm).astype(np.uint32)
|
||||
|
||||
self.rank_idx.append(rank_idx)
|
||||
self.rank_all.append(rank_all)
|
||||
|
||||
unique, counts = np.unique(seq_len, return_counts=True)
|
||||
print("----------------------------------------------------------------")
|
||||
print("MeaningDataset start:" + str(start) + " end:" + str(end) + " space:" + str(end - start))
|
||||
|
|
Loading…
Reference in New Issue