Reduce the memory cost of building the meaning dataset.
This commit is contained in:
parent
c907210fc1
commit
ef08359a94
|
@ -1,14 +1,11 @@
|
|||
import os
|
||||
import datasets
|
||||
import torch
|
||||
import math
|
||||
import random
|
||||
import torch, datasets
|
||||
import math, gc, time, random, copy
|
||||
from itertools import chain
|
||||
from typing import Dict, Tuple
|
||||
from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
|
||||
import numpy as np
|
||||
from torch.utils.data import BatchSampler
|
||||
import copy
|
||||
|
||||
|
||||
class MeaningMap:
|
||||
|
@ -25,6 +22,7 @@ class MeaningMap:
|
|||
file += "_" + str(max_subitem) + "_" + str(min_subitem)
|
||||
file = path + file + ".npz"
|
||||
|
||||
start_time = time.time()
|
||||
if not os.path.exists(path):
|
||||
os.mkdir(path)
|
||||
if os.path.exists(file) and use_cache:
|
||||
|
@ -41,7 +39,7 @@ class MeaningMap:
|
|||
self.ms_rank_all = dlra[:, 3].astype(np.uint32)
|
||||
self.ms_height = slhwm[:, 2]
|
||||
self.ms_weight = slhwm[:, 3]
|
||||
print("Load end")
|
||||
print("Load end, elapsed:" + str(time.time() - start_time) + "s")
|
||||
else:
|
||||
print("Disk cache miss, build new one.")
|
||||
|
||||
|
@ -75,10 +73,10 @@ class MeaningMap:
|
|||
ms_weight = [] # meaning tree weight
|
||||
index = 0
|
||||
for i in range(self.vocab_size):
|
||||
ms_data.append(np.array([i]))
|
||||
ms_level.append(np.array([0]))
|
||||
ms_rank_idx.append(np.array([0]))
|
||||
ms_rank_all.append(np.array([0]))
|
||||
ms_data.append(np.array([i], dtype=np.int32))
|
||||
ms_level.append(np.array([0], dtype=np.int32))
|
||||
ms_rank_idx.append(np.array([0], dtype=np.uint32))
|
||||
ms_rank_all.append(np.array([0], dtype=np.uint32))
|
||||
ms_start.append(index)
|
||||
ms_len.append(1)
|
||||
ms_height.append(0)
|
||||
|
@ -100,13 +98,13 @@ class MeaningMap:
|
|||
([0xFFFFFFF0 + i] if newm < self.vocab_size else ms_rank_idx[newm] * 16 + i)
|
||||
for i, newm in enumerate(m_list)
|
||||
]
|
||||
)
|
||||
).astype(np.uint32)
|
||||
mrl = np.concatenate(
|
||||
[
|
||||
([0xFFFFFFF0 + m_len] if newm < self.vocab_size else ms_rank_all[newm] * 16 + m_len)
|
||||
for i, newm in enumerate(m_list)
|
||||
]
|
||||
)
|
||||
).astype(np.uint32)
|
||||
ms_data.append(ma)
|
||||
ms_level.append(ml)
|
||||
ms_rank_idx.append(mr)
|
||||
|
@ -117,6 +115,7 @@ class MeaningMap:
|
|||
ms_weight.append(sum(ms_weight[sub_m] for sub_m in m_list))
|
||||
index = index + len(ma)
|
||||
|
||||
print("Mapping end, elapsed:" + str(time.time() - start_time) + "s")
|
||||
ms_data = np.array(list(chain(*ms_data))).astype(np.int32)
|
||||
ms_level = np.array(list(chain(*ms_level))).astype(np.int32)
|
||||
ms_rank_idx = np.array(list(chain(*ms_rank_idx))).astype(np.uint32)
|
||||
|
@ -175,7 +174,7 @@ class MeaningMap:
|
|||
self.ms_len = ms_len
|
||||
self.ms_height = ms_height
|
||||
self.ms_weight = ms_weight
|
||||
print("Disk cache build end.")
|
||||
print("Disk cache build end, elapsed:" + str(time.time() - start_time) + "s")
|
||||
|
||||
def get_sequence(self, meaning): # return sequence[meaning]
|
||||
start = self.ms_start[meaning]
|
||||
|
|
Loading…
Reference in New Issue