Reduce memory cost when building the dataset.

This commit is contained in:
Colin 2024-04-19 16:57:59 +08:00
parent 43883be692
commit 2d415d9e44
1 changed file with 58 additions and 48 deletions

View File

@ -19,32 +19,42 @@ class MeaningMap:
self.vocab_size = vocab_size
self.max_subitem = max_subitem
self.min_subitem = min_subitem
datastep = 0x8000000
path = "./data/"
file = "structured_language_" + str(size) + "_" + str(vocab_size)
file += "_" + str(max_subitem) + "_" + str(min_subitem)
file = path + file + ".npz"
file_prop = path + file + "_prop.npy"
file_data = path + file + "_data.npy"
file_level = path + file + "_level.npy"
file_rank_idx = path + file + "_rank_idx.npy"
file_rank_all = path + file + "_rank_all.npy"
start_time = time.time()
if not os.path.exists(path):
os.mkdir(path)
if os.path.exists(file) and use_cache:
if (
os.path.exists(file_prop)
and os.path.exists(file_data)
and os.path.exists(file_level)
and os.path.exists(file_rank_idx)
and os.path.exists(file_rank_all)
and use_cache
):
print("Load from disk cache: " + file)
loaded = np.load(file)
slhwm = loaded["slhwm"]
dlra = loaded["dlra"]
slhwm = np.load(file_prop)
self.ms_map = slhwm[:, 4:]
self.ms_data = dlra[:, 0]
self.ms_data = np.load(file_data)
self.ms_start = slhwm[:, 0]
self.ms_len = slhwm[:, 1]
self.ms_level = dlra[:, 1]
self.ms_rank_idx = dlra[:, 2].astype(np.uint32)
self.ms_rank_all = dlra[:, 3].astype(np.uint32)
self.ms_level = np.load(file_level)
self.ms_rank_idx = np.load(file_rank_idx)
self.ms_rank_all = np.load(file_rank_all)
self.ms_height = slhwm[:, 2]
self.ms_weight = slhwm[:, 3]
print("Load end, elapsed:" + str(time.time() - start_time) + "s")
else:
print("Disk cache miss, build new one.")
print("Disk cache miss, build new one. size:" + str(size))
map = np.empty((size, max_subitem), dtype=np.int32)
@ -74,10 +84,10 @@ class MeaningMap:
ms_len = np.zeros((size), dtype=np.int32) # meaning sequence len
ms_height = np.zeros((size), dtype=np.int32) # meaning tree height
ms_weight = np.zeros((size), dtype=np.int32) # meaning tree weight
ms_data = np.zeros((268435456), dtype=np.int32) # meaning sequence
ms_level = np.zeros((268435456), dtype=np.int32) # meaning level, vocab's level is 0
ms_rank_idx = np.zeros((268435456), dtype=np.uint32) # meaning index of all level
ms_rank_all = np.zeros((268435456), dtype=np.uint32) # meaning all of all level
ms_data = np.zeros((datastep), dtype=np.int32) # meaning sequence
ms_level = np.zeros((datastep), dtype=np.uint32) # meaning level, vocab's level is 0
ms_rank_idx = np.zeros((datastep), dtype=np.uint32) # meaning index of all level
ms_rank_all = np.zeros((datastep), dtype=np.uint32) # meaning all of all level
index = 0
for i in range(self.vocab_size):
@ -107,10 +117,10 @@ class MeaningMap:
end = index + len_ma
if ms_data.size < end:
ms_data = np.concatenate([ms_data, np.zeros((268435456), dtype=np.int32)])
ms_level = np.concatenate([ms_level, np.zeros((268435456), dtype=np.int32)])
ms_rank_idx = np.concatenate([ms_rank_idx, np.zeros((268435456), dtype=np.uint32)])
ms_rank_all = np.concatenate([ms_rank_all, np.zeros((268435456), dtype=np.uint32)])
ms_data = np.concatenate([ms_data, np.zeros((datastep), dtype=np.int32)])
ms_level = np.concatenate([ms_level, np.zeros((datastep), dtype=np.uint32)])
ms_rank_idx = np.concatenate([ms_rank_idx, np.zeros((datastep), dtype=np.uint32)])
ms_rank_all = np.concatenate([ms_rank_all, np.zeros((datastep), dtype=np.uint32)])
ms_data[index:end] = ms_data[idx]
ms_level[index:end] = ms_level[idx] + 1
@ -128,37 +138,15 @@ class MeaningMap:
print("Mapping end, elapsed:" + str(time.time() - start_time) + "s")
d = np.ones(ms_rank_idx.shape, dtype=np.uint32)
d = ((d * 0xFFFFFFFF) << (ms_level * 4)).astype(np.uint32)
shift = (8 - ms_level) * 4
ms_rank_idx = (
((ms_rank_idx & 0xF) << 28)
+ ((ms_rank_idx & 0xF0) << 20)
+ ((ms_rank_idx & 0xF00) << 12)
+ ((ms_rank_idx & 0xF000) << 4)
+ ((ms_rank_idx & 0xF0000) >> 4)
+ ((ms_rank_idx & 0xF00000) >> 12)
+ ((ms_rank_idx & 0xF000000) >> 20)
+ ((ms_rank_idx & 0xF0000000) >> 28)
)
ms_rank_idx = ((ms_rank_idx >> shift) + d).astype(np.uint32)
ms_rank_all = (
((ms_rank_all & 0xF) << 28)
+ ((ms_rank_all & 0xF0) << 20)
+ ((ms_rank_all & 0xF00) << 12)
+ ((ms_rank_all & 0xF000) << 4)
+ ((ms_rank_all & 0xF0000) >> 4)
+ ((ms_rank_all & 0xF00000) >> 12)
+ ((ms_rank_all & 0xF000000) >> 20)
+ ((ms_rank_all & 0xF0000000) >> 28)
)
ms_rank_all = ((ms_rank_all >> shift) + d).astype(np.uint32)
np.save(file_data, ms_data)
np.save(file_level, ms_level)
np.save(file_rank_idx, ms_rank_idx)
np.save(file_rank_all, ms_rank_all)
ms_start = np.array(ms_start).astype(np.int32)
ms_height = np.array(ms_height).astype(np.int32)
ms_weight = np.array(ms_weight).astype(np.int32)
ms_len = np.array(ms_len).astype(np.int32)
slhwm = np.concatenate(
(
ms_start.reshape((-1, 1)),
@ -169,8 +157,7 @@ class MeaningMap:
),
axis=1,
)
dlra = np.stack((ms_data, ms_level, ms_rank_idx.astype(np.int32), ms_rank_all.astype(np.int32)), axis=1)
np.savez(file, slhwm=slhwm, dlra=dlra)
np.save(file_prop, slhwm)
self.ms_data = ms_data # map[i]=ms_data[ms_start[i]:ms_start[i]+ms_len[i]]
self.ms_level = ms_level
@ -300,11 +287,34 @@ class MeaningDataset(Dataset):
self.tree.append({m: map.get_tree(m)})
self.seq.append(d)
self.level.append(l)
self.rank_idx.append(i)
self.rank_all.append(a)
self.seq_meaning.append(m)
seq_len.append(len(d))
dm = np.ones(i.shape, dtype=np.uint32)
dm = ((dm * 0xFFFFFFFF) << (l * 4)).astype(np.uint32)
shift = (8 - l) * 4
rank_idx = (i & 0xF) << 28
rank_idx = rank_idx + ((i & 0xF0) << 20)
rank_idx = rank_idx + ((i & 0xF00) << 12)
rank_idx = rank_idx + ((i & 0xF000) << 4)
rank_idx = rank_idx + ((i & 0xF0000) >> 4)
rank_idx = rank_idx + ((i & 0xF00000) >> 12)
rank_idx = rank_idx + ((i & 0xF000000) >> 20)
rank_idx = rank_idx + ((i & 0xF0000000) >> 28)
rank_idx = ((rank_idx >> shift) + dm).astype(np.uint32)
rank_all = (a & 0xF) << 28
rank_all = rank_all + ((a & 0xF0) << 20)
rank_all = rank_all + ((a & 0xF00) << 12)
rank_all = rank_all + ((a & 0xF000) << 4)
rank_all = rank_all + ((a & 0xF0000) >> 4)
rank_all = rank_all + ((a & 0xF00000) >> 12)
rank_all = rank_all + ((a & 0xF000000) >> 20)
rank_all = rank_all + ((a & 0xF0000000) >> 28)
rank_all = ((rank_all >> shift) + dm).astype(np.uint32)
self.rank_idx.append(rank_idx)
self.rank_all.append(rank_all)
unique, counts = np.unique(seq_len, return_counts=True)
print("----------------------------------------------------------------")
print("MeaningDataset start:" + str(start) + " end:" + str(end) + " space:" + str(end - start))