Refine meaning dataset memory cost when building.

commit ef08359a94
parent c907210fc1
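Most of the changes below pin explicit dtypes on the small per-meaning arrays created while the map is built. A minimal sketch (not from the repository; variable names are illustrative) of why that lowers peak memory, assuming a typical 64-bit platform where NumPy's default integer type is int64:

    import numpy as np

    a = np.array([7])                  # dtype usually defaults to int64 on 64-bit platforms
    b = np.array([7], dtype=np.int32)  # same value, half the bytes per element
    print(a.dtype, a.nbytes)           # int64 8
    print(b.dtype, b.nbytes)           # int32 4

One such array per field (data, level, rank_idx, rank_all) stays alive for every meaning until the final concatenation, so shrinking the element size shrinks the build-time footprint of those buffers roughly in proportion.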
@@ -1,14 +1,11 @@
 import os
-import datasets
-import torch
-import math
-import random
+import torch, datasets
+import math, gc, time, random, copy
 from itertools import chain
 from typing import Dict, Tuple
 from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
 import numpy as np
 from torch.utils.data import BatchSampler
-import copy


 class MeaningMap:
@@ -25,6 +22,7 @@ class MeaningMap:
         file += "_" + str(max_subitem) + "_" + str(min_subitem)
         file = path + file + ".npz"

+        start_time = time.time()
         if not os.path.exists(path):
             os.mkdir(path)
         if os.path.exists(file) and use_cache:
@@ -41,7 +39,7 @@
             self.ms_rank_all = dlra[:, 3].astype(np.uint32)
             self.ms_height = slhwm[:, 2]
             self.ms_weight = slhwm[:, 3]
-            print("Load end")
+            print("Load end, elapsed:" + str(time.time() - start_time) + "s")
         else:
             print("Disk cache miss, build new one.")

@@ -75,10 +73,10 @@
             ms_weight = []  # meaning tree weight
             index = 0
             for i in range(self.vocab_size):
-                ms_data.append(np.array([i]))
-                ms_level.append(np.array([0]))
-                ms_rank_idx.append(np.array([0]))
-                ms_rank_all.append(np.array([0]))
+                ms_data.append(np.array([i], dtype=np.int32))
+                ms_level.append(np.array([0], dtype=np.int32))
+                ms_rank_idx.append(np.array([0], dtype=np.uint32))
+                ms_rank_all.append(np.array([0], dtype=np.uint32))
                 ms_start.append(index)
                 ms_len.append(1)
                 ms_height.append(0)
@@ -100,13 +98,13 @@
                         ([0xFFFFFFF0 + i] if newm < self.vocab_size else ms_rank_idx[newm] * 16 + i)
                         for i, newm in enumerate(m_list)
                     ]
-                )
+                ).astype(np.uint32)
                 mrl = np.concatenate(
                     [
                         ([0xFFFFFFF0 + m_len] if newm < self.vocab_size else ms_rank_all[newm] * 16 + m_len)
                         for i, newm in enumerate(m_list)
                     ]
-                )
+                ).astype(np.uint32)
                 ms_data.append(ma)
                 ms_level.append(ml)
                 ms_rank_idx.append(mr)
@@ -117,6 +115,7 @@
                 ms_weight.append(sum(ms_weight[sub_m] for sub_m in m_list))
                 index = index + len(ma)

+            print("Mapping end, elapsed:" + str(time.time() - start_time) + "s")
             ms_data = np.array(list(chain(*ms_data))).astype(np.int32)
             ms_level = np.array(list(chain(*ms_level))).astype(np.int32)
             ms_rank_idx = np.array(list(chain(*ms_rank_idx))).astype(np.uint32)
@@ -175,7 +174,7 @@
             self.ms_len = ms_len
             self.ms_height = ms_height
             self.ms_weight = ms_weight
-            print("Disk cache build end.")
+            print("Disk cache build end, elapsed:" + str(time.time() - start_time) + "s")

     def get_sequence(self, meaning):  # return sequence[meaning]
         start = self.ms_start[meaning]
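The rank buffers (mr, mrl) are built by np.concatenate over lists of plain Python ints, which normally yields int64; the added .astype(np.uint32) halves those buffers, and any value beyond 32 bits wraps modulo 2^32. A small sketch under those assumptions, with made-up values mirroring the 0xFFFFFFF0 leaf marker used in the diff:

    import numpy as np

    mr = np.concatenate([[0xFFFFFFF0 + i] for i in range(3)])
    print(mr.dtype, mr.nbytes)      # int64, 24 bytes
    mr32 = mr.astype(np.uint32)
    print(mr32.dtype, mr32.nbytes)  # uint32, 12 bytes
    print(hex(int(mr32[0])))        # 0xfffffff0 -- in-range values are preserved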