Reduce the memory cost of building the meaning dataset.

This commit is contained in:
Colin 2024-04-14 23:35:55 +08:00
parent c907210fc1
commit ef08359a94
1 changed file with 12 additions and 13 deletions

View File

@ -1,14 +1,11 @@
import os import os
import datasets import torch, datasets
import torch import math, gc, time, random, copy
import math
import random
from itertools import chain from itertools import chain
from typing import Dict, Tuple from typing import Dict, Tuple
from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
import numpy as np import numpy as np
from torch.utils.data import BatchSampler from torch.utils.data import BatchSampler
import copy
class MeaningMap: class MeaningMap:
@ -25,6 +22,7 @@ class MeaningMap:
file += "_" + str(max_subitem) + "_" + str(min_subitem) file += "_" + str(max_subitem) + "_" + str(min_subitem)
file = path + file + ".npz" file = path + file + ".npz"
start_time = time.time()
if not os.path.exists(path): if not os.path.exists(path):
os.mkdir(path) os.mkdir(path)
if os.path.exists(file) and use_cache: if os.path.exists(file) and use_cache:
@ -41,7 +39,7 @@ class MeaningMap:
self.ms_rank_all = dlra[:, 3].astype(np.uint32) self.ms_rank_all = dlra[:, 3].astype(np.uint32)
self.ms_height = slhwm[:, 2] self.ms_height = slhwm[:, 2]
self.ms_weight = slhwm[:, 3] self.ms_weight = slhwm[:, 3]
print("Load end") print("Load end, elapsed:" + str(time.time() - start_time) + "s")
else: else:
print("Disk cache miss, build new one.") print("Disk cache miss, build new one.")
@ -75,10 +73,10 @@ class MeaningMap:
ms_weight = [] # meaning tree weight ms_weight = [] # meaning tree weight
index = 0 index = 0
for i in range(self.vocab_size): for i in range(self.vocab_size):
ms_data.append(np.array([i])) ms_data.append(np.array([i], dtype=np.int32))
ms_level.append(np.array([0])) ms_level.append(np.array([0], dtype=np.int32))
ms_rank_idx.append(np.array([0])) ms_rank_idx.append(np.array([0], dtype=np.uint32))
ms_rank_all.append(np.array([0])) ms_rank_all.append(np.array([0], dtype=np.uint32))
ms_start.append(index) ms_start.append(index)
ms_len.append(1) ms_len.append(1)
ms_height.append(0) ms_height.append(0)
@ -100,13 +98,13 @@ class MeaningMap:
([0xFFFFFFF0 + i] if newm < self.vocab_size else ms_rank_idx[newm] * 16 + i) ([0xFFFFFFF0 + i] if newm < self.vocab_size else ms_rank_idx[newm] * 16 + i)
for i, newm in enumerate(m_list) for i, newm in enumerate(m_list)
] ]
) ).astype(np.uint32)
mrl = np.concatenate( mrl = np.concatenate(
[ [
([0xFFFFFFF0 + m_len] if newm < self.vocab_size else ms_rank_all[newm] * 16 + m_len) ([0xFFFFFFF0 + m_len] if newm < self.vocab_size else ms_rank_all[newm] * 16 + m_len)
for i, newm in enumerate(m_list) for i, newm in enumerate(m_list)
] ]
) ).astype(np.uint32)
ms_data.append(ma) ms_data.append(ma)
ms_level.append(ml) ms_level.append(ml)
ms_rank_idx.append(mr) ms_rank_idx.append(mr)
@ -117,6 +115,7 @@ class MeaningMap:
ms_weight.append(sum(ms_weight[sub_m] for sub_m in m_list)) ms_weight.append(sum(ms_weight[sub_m] for sub_m in m_list))
index = index + len(ma) index = index + len(ma)
print("Mapping end, elapsed:" + str(time.time() - start_time) + "s")
ms_data = np.array(list(chain(*ms_data))).astype(np.int32) ms_data = np.array(list(chain(*ms_data))).astype(np.int32)
ms_level = np.array(list(chain(*ms_level))).astype(np.int32) ms_level = np.array(list(chain(*ms_level))).astype(np.int32)
ms_rank_idx = np.array(list(chain(*ms_rank_idx))).astype(np.uint32) ms_rank_idx = np.array(list(chain(*ms_rank_idx))).astype(np.uint32)
@ -175,7 +174,7 @@ class MeaningMap:
self.ms_len = ms_len self.ms_len = ms_len
self.ms_height = ms_height self.ms_height = ms_height
self.ms_weight = ms_weight self.ms_weight = ms_weight
print("Disk cache build end.") print("Disk cache build end, elapsed:" + str(time.time() - start_time) + "s")
def get_sequence(self, meaning): # return sequence[meaning] def get_sequence(self, meaning): # return sequence[meaning]
start = self.ms_start[meaning] start = self.ms_start[meaning]