Add mapping output.
This commit is contained in:
parent
e2b48c0ab4
commit
a15e55bead
|
@ -10,7 +10,7 @@ import numpy as np
|
||||||
from torch.utils.data import BatchSampler
|
from torch.utils.data import BatchSampler
|
||||||
|
|
||||||
|
|
||||||
class MeaningMap: # 16777216 1048576 8192
|
class MeaningMap:
|
||||||
|
|
||||||
def __init__(self, size=1048576, vocab_size=4096, max_subitem=10):
|
def __init__(self, size=1048576, vocab_size=4096, max_subitem=10):
|
||||||
self.size = size
|
self.size = size
|
||||||
|
@ -20,14 +20,21 @@ class MeaningMap: # 16777216 1048576 8192
|
||||||
path = "./data/"
|
path = "./data/"
|
||||||
file = "structured_language_" + str(size) + "_" + str(vocab_size) + "_" + str(max_subitem)
|
file = "structured_language_" + str(size) + "_" + str(vocab_size) + "_" + str(max_subitem)
|
||||||
file = path + file
|
file = path + file
|
||||||
|
file_map = file + "_map" + ".npy"
|
||||||
file_start = file + "_start" + ".npy"
|
file_start = file + "_start" + ".npy"
|
||||||
file_len = file + "_len" + ".npy"
|
file_len = file + "_len" + ".npy"
|
||||||
file_data = file + "_data" + ".npy"
|
file_data = file + "_data" + ".npy"
|
||||||
|
|
||||||
if not os.path.exists(path):
|
if not os.path.exists(path):
|
||||||
os.mkdir(path)
|
os.mkdir(path)
|
||||||
if os.path.exists(file_start) and os.path.exists(file_len) and os.path.exists(file_data):
|
if (
|
||||||
|
os.path.exists(file_start)
|
||||||
|
and os.path.exists(file_len)
|
||||||
|
and os.path.exists(file_data)
|
||||||
|
and os.path.exists(file_map)
|
||||||
|
):
|
||||||
print("Load from disk cache: " + file)
|
print("Load from disk cache: " + file)
|
||||||
|
self.ms_map = np.load(file_map)
|
||||||
self.ms_data = np.load(file_data)
|
self.ms_data = np.load(file_data)
|
||||||
self.ms_start = np.load(file_start)
|
self.ms_start = np.load(file_start)
|
||||||
self.ms_len = np.load(file_len)
|
self.ms_len = np.load(file_len)
|
||||||
|
@ -77,21 +84,30 @@ class MeaningMap: # 16777216 1048576 8192
|
||||||
index = index + len(ma)
|
index = index + len(ma)
|
||||||
|
|
||||||
ms_data = list(chain(*ms))
|
ms_data = list(chain(*ms))
|
||||||
|
np.save(file_map, np.array(mm).astype(np.int32))
|
||||||
np.save(file_data, np.array(ms_data).astype(np.int32))
|
np.save(file_data, np.array(ms_data).astype(np.int32))
|
||||||
np.save(file_start, np.array(ms_start).astype(np.int32))
|
np.save(file_start, np.array(ms_start).astype(np.int32))
|
||||||
np.save(file_len, np.array(ms_len).astype(np.int32))
|
np.save(file_len, np.array(ms_len).astype(np.int32))
|
||||||
|
|
||||||
|
self.ms_map = mm
|
||||||
self.ms_data = ms_data
|
self.ms_data = ms_data
|
||||||
self.ms_start = ms_start
|
self.ms_start = ms_start
|
||||||
self.ms_len = ms_len
|
self.ms_len = ms_len
|
||||||
print("Disk cache build end.")
|
print("Disk cache build end.")
|
||||||
|
|
||||||
def GetSequence(self, meaning):
|
def get_sequence(self, meaning):
|
||||||
start = self.ms_start[meaning]
|
start = self.ms_start[meaning]
|
||||||
len = self.ms_len[meaning]
|
len = self.ms_len[meaning]
|
||||||
return self.ms_data[start : start + len]
|
return self.ms_data[start : start + len]
|
||||||
|
|
||||||
def MaxLength(self):
|
def get_mapping(self, meaning):
|
||||||
|
mapping = {}
|
||||||
|
ms = self.ms_map[meaning]
|
||||||
|
for m in ms[ms > 0].tolist():
|
||||||
|
mapping[m] = self.get_mapping(m) if m >= self.vocab_size else m
|
||||||
|
return mapping
|
||||||
|
|
||||||
|
def max_length(self):
|
||||||
return max(self.ms_len)
|
return max(self.ms_len)
|
||||||
|
|
||||||
|
|
||||||
|
@ -108,19 +124,23 @@ class MeaningDataset(Dataset):
|
||||||
seed=42,
|
seed=42,
|
||||||
data=None,
|
data=None,
|
||||||
length=None,
|
length=None,
|
||||||
|
mapping=None,
|
||||||
):
|
):
|
||||||
if data != None and length != None:
|
if data != None and length != None and mapping != None:
|
||||||
self.data = data
|
self.data = data
|
||||||
self.length = length
|
self.length = length
|
||||||
|
self.mapping = mapping
|
||||||
return
|
return
|
||||||
np.random.seed(seed)
|
np.random.seed(seed)
|
||||||
mm = MeaningMap(size=end, vocab_size=vocab_size, max_subitem=max_subitem) # 1048576
|
mm = MeaningMap(size=end, vocab_size=vocab_size, max_subitem=max_subitem) # 1048576
|
||||||
|
self.mapping = []
|
||||||
self.data = []
|
self.data = []
|
||||||
self.length = []
|
self.length = []
|
||||||
meanings = np.random.randint(start, end, size=(size))
|
meanings = np.random.randint(start, end, size=(size))
|
||||||
for m in meanings:
|
for m in meanings:
|
||||||
sq = mm.GetSequence(m)
|
sq = mm.get_sequence(m)
|
||||||
if len(sq) >= min_seq_len:
|
if len(sq) >= min_seq_len:
|
||||||
|
self.mapping.append(mm.get_mapping(m))
|
||||||
self.data.append(sq)
|
self.data.append(sq)
|
||||||
self.length.append(len(sq))
|
self.length.append(len(sq))
|
||||||
|
|
||||||
|
@ -146,7 +166,7 @@ class MeaningDataset(Dataset):
|
||||||
output["token_type_ids"] = torch.zeros(data.shape)
|
output["token_type_ids"] = torch.zeros(data.shape)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def GetBatch(self, index_list): # must equal sequence length
|
def get_batch(self, index_list): # must equal sequence length
|
||||||
data = [self.data[i] for i in index_list]
|
data = [self.data[i] for i in index_list]
|
||||||
output = {}
|
output = {}
|
||||||
data = torch.tensor(np.stack(data, axis=0)).long()
|
data = torch.tensor(np.stack(data, axis=0)).long()
|
||||||
|
@ -155,13 +175,17 @@ class MeaningDataset(Dataset):
|
||||||
output["token_type_ids"] = torch.zeros(data.shape)
|
output["token_type_ids"] = torch.zeros(data.shape)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def Split(self, ratio):
|
def get_mapping_batch(self, index_list):
|
||||||
|
return [self.mapping[i] for i in index_list]
|
||||||
|
|
||||||
|
def split(self, ratio):
|
||||||
l = len(self.data)
|
l = len(self.data)
|
||||||
middle = int(l * ratio)
|
middle = int(l * ratio)
|
||||||
d_shuffle = self.data.copy()
|
d_shuffle = self.data.copy()
|
||||||
l_shuffle = self.length.copy()
|
l_shuffle = self.length.copy()
|
||||||
md1 = MeaningDataset(data=d_shuffle[:middle], length=l_shuffle[:middle])
|
m_shuffle = self.mapping.copy()
|
||||||
md2 = MeaningDataset(data=d_shuffle[middle:], length=l_shuffle[middle:])
|
md1 = MeaningDataset(data=d_shuffle[:middle], length=l_shuffle[:middle], mapping=m_shuffle[:middle])
|
||||||
|
md2 = MeaningDataset(data=d_shuffle[middle:], length=l_shuffle[middle:], mapping=m_shuffle[middle:])
|
||||||
return md1, md2
|
return md1, md2
|
||||||
|
|
||||||
|
|
||||||
|
@ -195,6 +219,7 @@ class BatchGroupMeaningDataloader(Dataset):
|
||||||
batch = len(gs[l]) // batch_size
|
batch = len(gs[l]) // batch_size
|
||||||
new = gs[l][0 : batch * batch_size].reshape(batch, batch_size)
|
new = gs[l][0 : batch * batch_size].reshape(batch, batch_size)
|
||||||
index = np.concatenate((index, new), axis=0)
|
index = np.concatenate((index, new), axis=0)
|
||||||
|
|
||||||
if shuffle:
|
if shuffle:
|
||||||
index_shuffle = np.arange(0, index.shape[0])
|
index_shuffle = np.arange(0, index.shape[0])
|
||||||
np.random.shuffle(index_shuffle)
|
np.random.shuffle(index_shuffle)
|
||||||
|
@ -207,22 +232,26 @@ class BatchGroupMeaningDataloader(Dataset):
|
||||||
return len(self.indexBatch)
|
return len(self.indexBatch)
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
def __getitem__(self, idx):
|
||||||
# print("get idx" + str(idx))
|
return self.dataset.get_batch(self.indexBatch[idx])
|
||||||
return self.dataset.GetBatch(self.indexBatch[idx])
|
|
||||||
|
def mapping(self, idx):
|
||||||
|
return self.dataset.get_mapping_batch(self.indexBatch[idx])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
||||||
md = MeaningDataset(4096, 8100, size=1024)
|
md = MeaningDataset(1024, 115200, vocab_size=1024, size=1024)
|
||||||
train, val = md.Split(0.95)
|
train, val = md.split(0.95)
|
||||||
|
|
||||||
dl = BatchGroupMeaningDataloader(train, 32)
|
dl = BatchGroupMeaningDataloader(train, 2)
|
||||||
length = len(dl)
|
length = len(dl)
|
||||||
it = iter(dl)
|
it = iter(dl)
|
||||||
ne1 = next(it)
|
ne1 = next(it)
|
||||||
ne2 = next(it)
|
ne2 = next(it)
|
||||||
ne3 = next(it)
|
ne3 = next(it)
|
||||||
|
|
||||||
|
map = dl.mapping(0)
|
||||||
|
|
||||||
dl = DataLoader(
|
dl = DataLoader(
|
||||||
train,
|
train,
|
||||||
num_workers=1,
|
num_workers=1,
|
||||||
|
|
|
@ -54,7 +54,7 @@ if __name__ == "__main__":
|
||||||
end = start * level_ratio
|
end = start * level_ratio
|
||||||
size = int(vocab_size * (level_ratio**dataset_level))
|
size = int(vocab_size * (level_ratio**dataset_level))
|
||||||
raw_dataset = MeaningDataset(start, end, size, vocab_size, level_ratio)
|
raw_dataset = MeaningDataset(start, end, size, vocab_size, level_ratio)
|
||||||
train_dataset, val_dataset = raw_dataset.Split(0.9)
|
train_dataset, val_dataset = raw_dataset.split(0.9)
|
||||||
train_dataloader = BatchGroupMeaningDataloader(train_dataset, train_batch_size)
|
train_dataloader = BatchGroupMeaningDataloader(train_dataset, train_batch_size)
|
||||||
val_dataloader = BatchGroupMeaningDataloader(val_dataset, val_batch_size)
|
val_dataloader = BatchGroupMeaningDataloader(val_dataset, val_batch_size)
|
||||||
# it = iter(train_dataloader)
|
# it = iter(train_dataloader)
|
||||||
|
|
Loading…
Reference in New Issue