Reduce the memory cost of building the meaning dataset.

This commit is contained in:
Colin 2024-04-14 23:35:55 +08:00
parent c907210fc1
commit ef08359a94
1 changed file with 12 additions and 13 deletions

View File

@ -1,14 +1,11 @@
import os
import datasets
import torch
import math
import random
import torch, datasets
import math, gc, time, random, copy
from itertools import chain
from typing import Dict, Tuple
from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
import numpy as np
from torch.utils.data import BatchSampler
import copy
class MeaningMap:
@ -25,6 +22,7 @@ class MeaningMap:
file += "_" + str(max_subitem) + "_" + str(min_subitem)
file = path + file + ".npz"
start_time = time.time()
if not os.path.exists(path):
os.mkdir(path)
if os.path.exists(file) and use_cache:
@ -41,7 +39,7 @@ class MeaningMap:
self.ms_rank_all = dlra[:, 3].astype(np.uint32)
self.ms_height = slhwm[:, 2]
self.ms_weight = slhwm[:, 3]
print("Load end")
print("Load end, elapsed:" + str(time.time() - start_time) + "s")
else:
print("Disk cache miss, build new one.")
@ -75,10 +73,10 @@ class MeaningMap:
ms_weight = [] # meaning tree weight
index = 0
for i in range(self.vocab_size):
ms_data.append(np.array([i]))
ms_level.append(np.array([0]))
ms_rank_idx.append(np.array([0]))
ms_rank_all.append(np.array([0]))
ms_data.append(np.array([i], dtype=np.int32))
ms_level.append(np.array([0], dtype=np.int32))
ms_rank_idx.append(np.array([0], dtype=np.uint32))
ms_rank_all.append(np.array([0], dtype=np.uint32))
ms_start.append(index)
ms_len.append(1)
ms_height.append(0)
@ -100,13 +98,13 @@ class MeaningMap:
([0xFFFFFFF0 + i] if newm < self.vocab_size else ms_rank_idx[newm] * 16 + i)
for i, newm in enumerate(m_list)
]
)
).astype(np.uint32)
mrl = np.concatenate(
[
([0xFFFFFFF0 + m_len] if newm < self.vocab_size else ms_rank_all[newm] * 16 + m_len)
for i, newm in enumerate(m_list)
]
)
).astype(np.uint32)
ms_data.append(ma)
ms_level.append(ml)
ms_rank_idx.append(mr)
@ -117,6 +115,7 @@ class MeaningMap:
ms_weight.append(sum(ms_weight[sub_m] for sub_m in m_list))
index = index + len(ma)
print("Mapping end, elapsed:" + str(time.time() - start_time) + "s")
ms_data = np.array(list(chain(*ms_data))).astype(np.int32)
ms_level = np.array(list(chain(*ms_level))).astype(np.int32)
ms_rank_idx = np.array(list(chain(*ms_rank_idx))).astype(np.uint32)
@ -175,7 +174,7 @@ class MeaningMap:
self.ms_len = ms_len
self.ms_height = ms_height
self.ms_weight = ms_weight
print("Disk cache build end.")
print("Disk cache build end, elapsed:" + str(time.time() - start_time) + "s")
def get_sequence(self, meaning): # return sequence[meaning]
start = self.ms_start[meaning]