Reduce memory cost when building the meaning dataset by using preallocated np arrays.
parent a524d01ac3
commit b062cc9c94
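The builder in MeaningMap.__init__ previously accumulated one small np array per meaning in Python lists and flattened everything with chain() at the end, so peak memory held all the fragments, their per-object overhead, and the final flattened copy at once. This commit writes into preallocated flat arrays instead, indexed by ms_start/ms_end, and grows them by a fixed 268435456-element (2^28) chunk whenever they fill up. A minimal sketch of the two patterns, with illustrative names and a small chunk size so the sketch stays cheap (the commit itself uses 2^28 elements, i.e. 1 GiB per int32 chunk):

import numpy as np

CHUNK = 1 << 20  # the commit uses 1 << 28

def build_append(sequences):
    # Old pattern: collect per-item arrays, flatten once at the end.
    # Every fragment stays alive until the final concatenate.
    pieces = [np.asarray(s, dtype=np.int32) for s in sequences]
    return np.concatenate(pieces)

def build_prealloc(sequences):
    # New pattern: one flat buffer, grown by a fixed chunk when it fills.
    # Assumes no single piece is longer than one chunk, as in the commit.
    data = np.zeros(CHUNK, dtype=np.int32)
    index = 0
    for s in sequences:
        piece = np.asarray(s, dtype=np.int32)
        end = index + piece.size
        if data.size < end:
            data = np.concatenate([data, np.zeros(CHUNK, dtype=np.int32)])
        data[index:end] = piece
        index = end
    return data[:index]

Growing by a fixed chunk leaves at most one chunk of slack but pays an O(n) copy on each growth; a doubling strategy would copy less often at the price of more slack.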
@@ -7,6 +7,9 @@ from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
 import numpy as np
 from torch.utils.data import BatchSampler
 
+# import warnings
+# warnings.filterwarnings("ignore", ".*does not have many workers.*")
+
 
 class MeaningMap:
     def __init__(self, size=1048576, vocab_size=4096, max_subitem=10, min_subitem=1, use_cache=True):

@@ -63,63 +66,76 @@ class MeaningMap:
         map[:vocab_size, 0] = np.arange(0, vocab_size)
         map[:vocab_size, 1:] = -1
 
-        ms_data = []  # meaning sequence
         ms_level = []  # meaning level, vocab's level is 0
         ms_rank_idx = []  # meaning index of all level
         ms_rank_all = []  # meaning all of all level
-        ms_start = []  # meaning sequence start
-        ms_len = []  # meaning sequence length
-        ms_height = []  # meaning tree height
-        ms_weight = []  # meaning tree weight
+        ms_start = np.zeros((size), dtype=np.int32)  # meaning sequence start
+        ms_end = np.zeros((size), dtype=np.int32)  # meaning sequence end
+        ms_len = np.zeros((size), dtype=np.int32)  # meaning sequence len
+        ms_height = np.zeros((size), dtype=np.int32)  # meaning tree height
+        ms_weight = np.zeros((size), dtype=np.int32)  # meaning tree weight
+        ms_data = np.zeros((268435456), dtype=np.int32)  # meaning sequence
+        ms_level = np.zeros((268435456), dtype=np.int32)  # meaning level, vocab's level is 0
+        ms_rank_idx = np.zeros((268435456), dtype=np.uint32)  # meaning index of all level
+        ms_rank_all = np.zeros((268435456), dtype=np.uint32)  # meaning all of all level
 
         index = 0
         for i in range(self.vocab_size):
-            ms_data.append(np.array([i], dtype=np.int32))
-            ms_level.append(np.array([0], dtype=np.int32))
-            ms_rank_idx.append(np.array([0], dtype=np.uint32))
-            ms_rank_all.append(np.array([0], dtype=np.uint32))
-            ms_start.append(index)
-            ms_len.append(1)
-            ms_height.append(0)
-            ms_weight.append(1)
+            ms_data[i] = i
+            ms_start[i] = index
+            ms_end[i] = index + 1
+            ms_len[i] = 1
+            ms_height[i] = 0
+            ms_weight[i] = 1
             index = index + 1
 
         for i in range(self.vocab_size, size):
             m = map[i]
             m = m[m >= 0]  # donot cut off the map such as [0]
+            m_len = len(m)
             m_list = m.tolist()
-            m_len = len(m_list)
             assert m_list, "map list can not be empty list"
+            ma = np.concatenate([ms_data[ms_start[newm] : ms_end[newm]] for newm in m_list])
+            len_ma = len(ma)
+            end = index + len_ma
+            if ms_data.size < end:
+                ms_data = np.concatenate([ms_data, np.zeros((268435456), dtype=np.int32)])
+                ms_level = np.concatenate([ms_level, np.zeros((268435456), dtype=np.int32)])
+                ms_rank_idx = np.concatenate([ms_rank_idx, np.zeros((268435456), dtype=np.uint32)])
+                ms_rank_all = np.concatenate([ms_rank_all, np.zeros((268435456), dtype=np.uint32)])
 
-            ma = np.concatenate([ms_data[newm] for newm in m_list])
-            ml = np.concatenate([ms_level[newm] + 1 for newm in m_list])
-            mr = np.concatenate(
+            ms_data[index:end] = ma
+            ms_level[index:end] = np.concatenate([ms_level[ms_start[newm] : ms_end[newm]] + 1 for newm in m_list])
+            ms_rank_idx[index:end] = np.concatenate(
                 [
-                    ([0xFFFFFFF0 + i] if newm < self.vocab_size else ms_rank_idx[newm] * 16 + i)
+                    (
+                        [0xFFFFFFF0 + i]
+                        if newm < self.vocab_size
+                        else ms_rank_idx[ms_start[newm] : ms_end[newm]] * 16 + i
+                    )
                     for i, newm in enumerate(m_list)
                 ]
             ).astype(np.uint32)
-            mrl = np.concatenate(
+            ms_rank_all[index:end] = np.concatenate(
                 [
-                    ([0xFFFFFFF0 + m_len] if newm < self.vocab_size else ms_rank_all[newm] * 16 + m_len)
+                    (
+                        [0xFFFFFFF0 + m_len]
+                        if newm < self.vocab_size
+                        else ms_rank_all[ms_start[newm] : ms_end[newm]] * 16 + m_len
+                    )
                     for i, newm in enumerate(m_list)
                 ]
             ).astype(np.uint32)
-            ms_data.append(ma)
-            ms_level.append(ml)
-            ms_rank_idx.append(mr)
-            ms_rank_all.append(mrl)
-            ms_start.append(index)
-            ms_len.append(len(ma))
-            ms_height.append(max([ms_height[sub_m] for sub_m in m_list]) + 1)
-            ms_weight.append(sum(ms_weight[sub_m] for sub_m in m_list))
-            index = index + len(ma)
+            ms_start[i] = index
+            ms_end[i] = end
+            ms_len[i] = len_ma
+            ms_height[i] = max([ms_height[sub_m] for sub_m in m_list]) + 1
+            ms_weight[i] = sum(ms_weight[sub_m] for sub_m in m_list)
+            index = index + len_ma
+            if i % 10000 == 0:
+                print(i)
 
         print("Mapping end, elapsed:" + str(time.time() - start_time) + "s")
-        ms_data = np.array(list(chain(*ms_data))).astype(np.int32)
-        ms_level = np.array(list(chain(*ms_level))).astype(np.int32)
-        ms_rank_idx = np.array(list(chain(*ms_rank_idx))).astype(np.uint32)
-        ms_rank_all = np.array(list(chain(*ms_rank_all))).astype(np.uint32)
 
         d = np.ones(ms_rank_idx.shape, dtype=np.uint32)
         d = ((d * 0xFFFFFFFF) << (ms_level * 4)).astype(np.uint32)
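A note on the packed rank fields in the hunk above: for every token, ms_rank_idx records its position at each level of the meaning tree, one 4-bit nibble per level (each descent computes parent * 16 + position, and positions stay below max_subitem=10, which fits in a nibble); ms_rank_all packs the sibling count m_len per level the same way. A token composed directly from the vocab gets the sentinel 0xFFFFFFF0 + position, so all nibbles above a token's true level are 0xF filler, and the mask d built at the end of the hunk, (0xFFFFFFFF << (ms_level * 4)), marks exactly those filler bits (its downstream use is outside this diff). A hypothetical decoder, only to make the layout concrete:

def unpack_rank(packed: int, level: int) -> list[int]:
    # Read `level` nibbles, deepest level first, then reverse so the
    # outermost position comes first. Filler nibbles are never read.
    digits = []
    for _ in range(level):
        digits.append(packed & 0xF)
        packed >>= 4
    return digits[::-1]

# A token that is child 3 of a meaning which is child 5 of its parent:
#   level 1: 0xFFFFFFF0 + 5       -> 0xFFFFFFF5
#   level 2: 0xFFFFFFF5 * 16 + 3  -> 0xFFFFFF53 (wraps mod 2**32 in uint32)
assert unpack_rank(0xFFFFFF53, 2) == [5, 3]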
@@ -458,7 +474,7 @@ class BatchGroupMeaningDataloader(Dataset):
 
 if __name__ == "__main__":
 
-    md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, use_cache=True)
+    md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, use_cache=False)
     md.set_mask([1], [-1])
     train, val = md.split(0.95)
     fdaf = md.__getitem__(920)
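The last hunk flips the __main__ smoke test to use_cache=False, presumably so it exercises the new array-based build path instead of loading a cached map. Anyone running it should note the upfront footprint: each 268435456-element int32/uint32 buffer is 1 GiB, so the four flat arrays start at 4 GiB before any growth. With flat storage, reading one meaning's sequence becomes a plain slice; a hypothetical helper (names as in the diff):

import numpy as np

def sequence_of(meaning: int, ms_data: np.ndarray, ms_start: np.ndarray, ms_end: np.ndarray) -> np.ndarray:
    # Replaces the old per-meaning list entry ms_data[meaning]:
    # with flat storage the same lookup is a zero-copy slice.
    return ms_data[ms_start[meaning] : ms_end[meaning]]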