Save memory cost when meaning dataset build by np array.

This commit is contained in:
Colin 2024-04-19 00:50:40 +08:00
parent a524d01ac3
commit b062cc9c94
1 changed files with 51 additions and 35 deletions

View File

@ -7,6 +7,9 @@ from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
import numpy as np
from torch.utils.data import BatchSampler
# import warnings
# warnings.filterwarnings("ignore", ".*does not have many workers.*")
class MeaningMap:
def __init__(self, size=1048576, vocab_size=4096, max_subitem=10, min_subitem=1, use_cache=True):
@ -63,63 +66,76 @@ class MeaningMap:
map[:vocab_size, 0] = np.arange(0, vocab_size)
map[:vocab_size, 1:] = -1
ms_data = [] # meaning sequence
ms_level = [] # meaning level, vocab's level is 0
ms_rank_idx = [] # meaning index of all level
ms_rank_all = [] # meaning all of all level
ms_start = [] # meaning sequence start
ms_len = [] # meaning sequence length
ms_height = [] # meaning tree height
ms_weight = [] # meaning tree weight
ms_start = np.zeros((size), dtype=np.int32) # meaning sequence start
ms_end = np.zeros((size), dtype=np.int32) # meaning sequence end
ms_len = np.zeros((size), dtype=np.int32) # meaning sequence len
ms_height = np.zeros((size), dtype=np.int32) # meaning tree height
ms_weight = np.zeros((size), dtype=np.int32) # meaning tree weight
ms_data = np.zeros((268435456), dtype=np.int32) # meaning sequence
ms_level = np.zeros((268435456), dtype=np.int32) # meaning level, vocab's level is 0
ms_rank_idx = np.zeros((268435456), dtype=np.uint32) # meaning index of all level
ms_rank_all = np.zeros((268435456), dtype=np.uint32) # meaning all of all level
index = 0
for i in range(self.vocab_size):
ms_data.append(np.array([i], dtype=np.int32))
ms_level.append(np.array([0], dtype=np.int32))
ms_rank_idx.append(np.array([0], dtype=np.uint32))
ms_rank_all.append(np.array([0], dtype=np.uint32))
ms_start.append(index)
ms_len.append(1)
ms_height.append(0)
ms_weight.append(1)
ms_data[i] = i
ms_start[i] = index
ms_end[i] = index + 1
ms_len[i] = 1
ms_height[i] = 0
ms_weight[i] = 1
index = index + 1
for i in range(self.vocab_size, size):
m = map[i]
m = m[m >= 0] # donot cut off the map such as [0]
m_len = len(m)
m_list = m.tolist()
m_len = len(m_list)
assert m_list, "map list can not be empty list"
ma = np.concatenate([ms_data[ms_start[newm] : ms_end[newm]] for newm in m_list])
len_ma = len(ma)
end = index + len_ma
if ms_data.size < end:
ms_data = np.concatenate([ms_data, np.zeros((268435456), dtype=np.int32)])
ms_level = np.concatenate([ms_level, np.zeros((268435456), dtype=np.int32)])
ms_rank_idx = np.concatenate([ms_rank_idx, np.zeros((268435456), dtype=np.uint32)])
ms_rank_all = np.concatenate([ms_rank_all, np.zeros((268435456), dtype=np.uint32)])
ma = np.concatenate([ms_data[newm] for newm in m_list])
ml = np.concatenate([ms_level[newm] + 1 for newm in m_list])
mr = np.concatenate(
ms_data[index:end] = ma
ms_level[index:end] = np.concatenate([ms_level[ms_start[newm] : ms_end[newm]] + 1 for newm in m_list])
ms_rank_idx[index:end] = np.concatenate(
[
([0xFFFFFFF0 + i] if newm < self.vocab_size else ms_rank_idx[newm] * 16 + i)
(
[0xFFFFFFF0 + i]
if newm < self.vocab_size
else ms_rank_idx[ms_start[newm] : ms_end[newm]] * 16 + i
)
for i, newm in enumerate(m_list)
]
).astype(np.uint32)
mrl = np.concatenate(
ms_rank_all[index:end] = np.concatenate(
[
([0xFFFFFFF0 + m_len] if newm < self.vocab_size else ms_rank_all[newm] * 16 + m_len)
(
[0xFFFFFFF0 + m_len]
if newm < self.vocab_size
else ms_rank_all[ms_start[newm] : ms_end[newm]] * 16 + m_len
)
for i, newm in enumerate(m_list)
]
).astype(np.uint32)
ms_data.append(ma)
ms_level.append(ml)
ms_rank_idx.append(mr)
ms_rank_all.append(mrl)
ms_start.append(index)
ms_len.append(len(ma))
ms_height.append(max([ms_height[sub_m] for sub_m in m_list]) + 1)
ms_weight.append(sum(ms_weight[sub_m] for sub_m in m_list))
index = index + len(ma)
ms_start[i] = index
ms_end[i] = end
ms_len[i] = len_ma
ms_height[i] = max([ms_height[sub_m] for sub_m in m_list]) + 1
ms_weight[i] = sum(ms_weight[sub_m] for sub_m in m_list)
index = index + len_ma
if i % 10000 == 0:
print(i)
print("Mapping end, elapsed:" + str(time.time() - start_time) + "s")
ms_data = np.array(list(chain(*ms_data))).astype(np.int32)
ms_level = np.array(list(chain(*ms_level))).astype(np.int32)
ms_rank_idx = np.array(list(chain(*ms_rank_idx))).astype(np.uint32)
ms_rank_all = np.array(list(chain(*ms_rank_all))).astype(np.uint32)
d = np.ones(ms_rank_idx.shape, dtype=np.uint32)
d = ((d * 0xFFFFFFFF) << (ms_level * 4)).astype(np.uint32)
@ -458,7 +474,7 @@ class BatchGroupMeaningDataloader(Dataset):
if __name__ == "__main__":
md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, use_cache=True)
md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, use_cache=False)
md.set_mask([1], [-1])
train, val = md.split(0.95)
fdaf = md.__getitem__(920)