Save memory cost when meaning dataset build by np array.
This commit is contained in:
parent
a524d01ac3
commit
b062cc9c94
|
@ -7,6 +7,9 @@ from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
|
|||
import numpy as np
|
||||
from torch.utils.data import BatchSampler
|
||||
|
||||
# import warnings
|
||||
# warnings.filterwarnings("ignore", ".*does not have many workers.*")
|
||||
|
||||
|
||||
class MeaningMap:
|
||||
def __init__(self, size=1048576, vocab_size=4096, max_subitem=10, min_subitem=1, use_cache=True):
|
||||
|
@ -63,63 +66,76 @@ class MeaningMap:
|
|||
map[:vocab_size, 0] = np.arange(0, vocab_size)
|
||||
map[:vocab_size, 1:] = -1
|
||||
|
||||
ms_data = [] # meaning sequence
|
||||
ms_level = [] # meaning level, vocab's level is 0
|
||||
ms_rank_idx = [] # meaning index of all level
|
||||
ms_rank_all = [] # meaning all of all level
|
||||
ms_start = [] # meaning sequence start
|
||||
ms_len = [] # meaning sequence length
|
||||
ms_height = [] # meaning tree height
|
||||
ms_weight = [] # meaning tree weight
|
||||
ms_start = np.zeros((size), dtype=np.int32) # meaning sequence start
|
||||
ms_end = np.zeros((size), dtype=np.int32) # meaning sequence end
|
||||
ms_len = np.zeros((size), dtype=np.int32) # meaning sequence len
|
||||
ms_height = np.zeros((size), dtype=np.int32) # meaning tree height
|
||||
ms_weight = np.zeros((size), dtype=np.int32) # meaning tree weight
|
||||
ms_data = np.zeros((268435456), dtype=np.int32) # meaning sequence
|
||||
ms_level = np.zeros((268435456), dtype=np.int32) # meaning level, vocab's level is 0
|
||||
ms_rank_idx = np.zeros((268435456), dtype=np.uint32) # meaning index of all level
|
||||
ms_rank_all = np.zeros((268435456), dtype=np.uint32) # meaning all of all level
|
||||
|
||||
index = 0
|
||||
for i in range(self.vocab_size):
|
||||
ms_data.append(np.array([i], dtype=np.int32))
|
||||
ms_level.append(np.array([0], dtype=np.int32))
|
||||
ms_rank_idx.append(np.array([0], dtype=np.uint32))
|
||||
ms_rank_all.append(np.array([0], dtype=np.uint32))
|
||||
ms_start.append(index)
|
||||
ms_len.append(1)
|
||||
ms_height.append(0)
|
||||
ms_weight.append(1)
|
||||
ms_data[i] = i
|
||||
ms_start[i] = index
|
||||
ms_end[i] = index + 1
|
||||
ms_len[i] = 1
|
||||
ms_height[i] = 0
|
||||
ms_weight[i] = 1
|
||||
index = index + 1
|
||||
|
||||
for i in range(self.vocab_size, size):
|
||||
m = map[i]
|
||||
m = m[m >= 0] # donot cut off the map such as [0]
|
||||
|
||||
m_len = len(m)
|
||||
m_list = m.tolist()
|
||||
m_len = len(m_list)
|
||||
assert m_list, "map list can not be empty list"
|
||||
ma = np.concatenate([ms_data[ms_start[newm] : ms_end[newm]] for newm in m_list])
|
||||
len_ma = len(ma)
|
||||
end = index + len_ma
|
||||
if ms_data.size < end:
|
||||
ms_data = np.concatenate([ms_data, np.zeros((268435456), dtype=np.int32)])
|
||||
ms_level = np.concatenate([ms_level, np.zeros((268435456), dtype=np.int32)])
|
||||
ms_rank_idx = np.concatenate([ms_rank_idx, np.zeros((268435456), dtype=np.uint32)])
|
||||
ms_rank_all = np.concatenate([ms_rank_all, np.zeros((268435456), dtype=np.uint32)])
|
||||
|
||||
ma = np.concatenate([ms_data[newm] for newm in m_list])
|
||||
ml = np.concatenate([ms_level[newm] + 1 for newm in m_list])
|
||||
mr = np.concatenate(
|
||||
ms_data[index:end] = ma
|
||||
ms_level[index:end] = np.concatenate([ms_level[ms_start[newm] : ms_end[newm]] + 1 for newm in m_list])
|
||||
ms_rank_idx[index:end] = np.concatenate(
|
||||
[
|
||||
([0xFFFFFFF0 + i] if newm < self.vocab_size else ms_rank_idx[newm] * 16 + i)
|
||||
(
|
||||
[0xFFFFFFF0 + i]
|
||||
if newm < self.vocab_size
|
||||
else ms_rank_idx[ms_start[newm] : ms_end[newm]] * 16 + i
|
||||
)
|
||||
for i, newm in enumerate(m_list)
|
||||
]
|
||||
).astype(np.uint32)
|
||||
mrl = np.concatenate(
|
||||
ms_rank_all[index:end] = np.concatenate(
|
||||
[
|
||||
([0xFFFFFFF0 + m_len] if newm < self.vocab_size else ms_rank_all[newm] * 16 + m_len)
|
||||
(
|
||||
[0xFFFFFFF0 + m_len]
|
||||
if newm < self.vocab_size
|
||||
else ms_rank_all[ms_start[newm] : ms_end[newm]] * 16 + m_len
|
||||
)
|
||||
for i, newm in enumerate(m_list)
|
||||
]
|
||||
).astype(np.uint32)
|
||||
ms_data.append(ma)
|
||||
ms_level.append(ml)
|
||||
ms_rank_idx.append(mr)
|
||||
ms_rank_all.append(mrl)
|
||||
ms_start.append(index)
|
||||
ms_len.append(len(ma))
|
||||
ms_height.append(max([ms_height[sub_m] for sub_m in m_list]) + 1)
|
||||
ms_weight.append(sum(ms_weight[sub_m] for sub_m in m_list))
|
||||
index = index + len(ma)
|
||||
ms_start[i] = index
|
||||
ms_end[i] = end
|
||||
ms_len[i] = len_ma
|
||||
ms_height[i] = max([ms_height[sub_m] for sub_m in m_list]) + 1
|
||||
ms_weight[i] = sum(ms_weight[sub_m] for sub_m in m_list)
|
||||
index = index + len_ma
|
||||
if i % 10000 == 0:
|
||||
print(i)
|
||||
|
||||
print("Mapping end, elapsed:" + str(time.time() - start_time) + "s")
|
||||
ms_data = np.array(list(chain(*ms_data))).astype(np.int32)
|
||||
ms_level = np.array(list(chain(*ms_level))).astype(np.int32)
|
||||
ms_rank_idx = np.array(list(chain(*ms_rank_idx))).astype(np.uint32)
|
||||
ms_rank_all = np.array(list(chain(*ms_rank_all))).astype(np.uint32)
|
||||
|
||||
d = np.ones(ms_rank_idx.shape, dtype=np.uint32)
|
||||
d = ((d * 0xFFFFFFFF) << (ms_level * 4)).astype(np.uint32)
|
||||
|
@ -458,7 +474,7 @@ class BatchGroupMeaningDataloader(Dataset):
|
|||
|
||||
if __name__ == "__main__":
|
||||
|
||||
md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, use_cache=True)
|
||||
md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, use_cache=False)
|
||||
md.set_mask([1], [-1])
|
||||
train, val = md.split(0.95)
|
||||
fdaf = md.__getitem__(920)
|
||||
|
|
Loading…
Reference in New Issue