diff --git a/wit/meaning_dataset.py b/wit/meaning_dataset.py index 7c7c3c1..b3c89fc 100644 --- a/wit/meaning_dataset.py +++ b/wit/meaning_dataset.py @@ -1,14 +1,11 @@ import os -import datasets -import torch -import math -import random +import torch, datasets +import math, gc, time, random, copy from itertools import chain from typing import Dict, Tuple from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split import numpy as np from torch.utils.data import BatchSampler -import copy class MeaningMap: @@ -25,6 +22,7 @@ class MeaningMap: file += "_" + str(max_subitem) + "_" + str(min_subitem) file = path + file + ".npz" + start_time = time.time() if not os.path.exists(path): os.mkdir(path) if os.path.exists(file) and use_cache: @@ -41,7 +39,7 @@ class MeaningMap: self.ms_rank_all = dlra[:, 3].astype(np.uint32) self.ms_height = slhwm[:, 2] self.ms_weight = slhwm[:, 3] - print("Load end") + print("Load end, elapsed:" + str(time.time() - start_time) + "s") else: print("Disk cache miss, build new one.") @@ -75,10 +73,10 @@ class MeaningMap: ms_weight = [] # meaning tree weight index = 0 for i in range(self.vocab_size): - ms_data.append(np.array([i])) - ms_level.append(np.array([0])) - ms_rank_idx.append(np.array([0])) - ms_rank_all.append(np.array([0])) + ms_data.append(np.array([i], dtype=np.int32)) + ms_level.append(np.array([0], dtype=np.int32)) + ms_rank_idx.append(np.array([0], dtype=np.uint32)) + ms_rank_all.append(np.array([0], dtype=np.uint32)) ms_start.append(index) ms_len.append(1) ms_height.append(0) @@ -100,13 +98,13 @@ class MeaningMap: ([0xFFFFFFF0 + i] if newm < self.vocab_size else ms_rank_idx[newm] * 16 + i) for i, newm in enumerate(m_list) ] - ) + ).astype(np.uint32) mrl = np.concatenate( [ ([0xFFFFFFF0 + m_len] if newm < self.vocab_size else ms_rank_all[newm] * 16 + m_len) for i, newm in enumerate(m_list) ] - ) + ).astype(np.uint32) ms_data.append(ma) ms_level.append(ml) ms_rank_idx.append(mr) @@ -117,6 +115,7 @@ class MeaningMap: ms_weight.append(sum(ms_weight[sub_m] for sub_m in m_list)) index = index + len(ma) + print("Mapping end, elapsed:" + str(time.time() - start_time) + "s") ms_data = np.array(list(chain(*ms_data))).astype(np.int32) ms_level = np.array(list(chain(*ms_level))).astype(np.int32) ms_rank_idx = np.array(list(chain(*ms_rank_idx))).astype(np.uint32) @@ -175,7 +174,7 @@ class MeaningMap: self.ms_len = ms_len self.ms_height = ms_height self.ms_weight = ms_weight - print("Disk cache build end.") + print("Disk cache build end, elapsed:" + str(time.time() - start_time) + "s") def get_sequence(self, meaning): # return sequence[meaning] start = self.ms_start[meaning]