2024-03-13 19:41:02 +08:00
|
|
|
import os
|
|
|
|
import datasets
|
|
|
|
import torch
|
|
|
|
import math
|
|
|
|
import random
|
|
|
|
from itertools import chain
|
|
|
|
from typing import Dict, Tuple
|
|
|
|
from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
|
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
|
|
|
|
class MeaningMap:  # 16777216 1048576 8192
    """Randomly generated, hierarchical "meaning -> token sequence" table.

    Meanings ``0 .. vocab_size - 1`` are leaves: the sequence of leaf ``i`` is
    the single token ``[i]``. Every meaning ``i`` in ``vocab_size .. size - 1``
    is composed of up to ``max_subitem`` smaller meanings (parts that truncate
    to 0 are dropped, so a composed meaning may even end up empty). Its
    sequence is the concatenation of its parts' sequences, so every token in
    every sequence is a leaf id ``< vocab_size``.

    All sequences are stored back to back in the flat int32 ndarray
    ``ms_data``; meaning ``i`` occupies
    ``ms_data[ms_start[i] : ms_start[i] + ms_len[i]]``.

    The three arrays are cached in the current working directory as
    ``structured_language_<size>_<vocab_size>_<max_subitem>_{start,len,data}.npy``
    and reloaded by later constructions with the same parameters. Building a
    new table draws from NumPy's global random state (``np.random``); the
    cache file name does not include any seed.
    """

    def __init__(self, size=1048576, vocab_size=4096, max_subitem=10):
        """Load the table from the disk cache, or build and cache it.

        Args:
            size: total number of meanings (leaves plus composed meanings).
            vocab_size: number of leaf meanings / distinct output tokens.
            max_subitem: maximum number of parts per composed meaning.
        """
        self.size = size
        self.vocab_size = vocab_size
        self.max_subitem = max_subitem

        file = f"structured_language_{size}_{vocab_size}_{max_subitem}"
        file_start = file + "_start.npy"
        file_len = file + "_len.npy"
        file_data = file + "_data.npy"

        if os.path.exists(file_start) and os.path.exists(file_len) and os.path.exists(file_data):
            print("Load from disk cache: " + file)
            self.ms_data = np.load(file_data)
            self.ms_start = np.load(file_start)
            self.ms_len = np.load(file_len)
            return

        print("Disk cache miss, build new one.")
        parts = self._random_parts(size, max_subitem)
        ms_data, ms_start, ms_len = self._expand(parts, size, vocab_size)

        np.save(file_data, ms_data)
        np.save(file_start, ms_start)
        np.save(file_len, ms_len)

        # Store the same ndarray types the cache-hit branch loads, so
        # GetSequence behaves identically on the first and on later runs.
        self.ms_data = ms_data
        self.ms_start = ms_start
        self.ms_len = ms_len
        print("Disk cache build end.")

    @staticmethod
    def _random_parts(size, max_subitem):
        """Draw, for every meaning, the smaller meanings it is built from.

        Returns an int32 array of shape ``(size, max_subitem)``; row ``i``
        holds the part meanings of ``i``, with 0 marking an unused slot.
        """
        index = np.arange(0, size)
        weights = np.random.random((size, max_subitem))

        # Decide how many leading slots each row keeps: sort a copy whose first
        # column is forced to 0 (so slot 0 always survives) and drop every slot
        # whose sorted value exceeds a per-row random threshold. Because the
        # copy is sorted, the dropped slots are always a trailing run.
        mask_zero = weights.copy()
        mask_zero[:, 0] = 0.0
        mask_zero.sort(axis=1)
        thre = np.random.random((size)).reshape(-1, 1).repeat(max_subitem, axis=1)
        mask_zero = mask_zero > thre

        # Scale each row so its weights sum to the row index i. The parts of
        # meaning i therefore sum to at most i, which (after truncation) makes
        # them refer to earlier meanings that are already expanded.
        scale = (index / weights.sum(axis=1)).reshape(-1, 1).repeat(max_subitem, axis=1)
        parts = weights * scale
        parts[mask_zero] = 0
        return parts.astype(np.int32)

    @staticmethod
    def _expand(parts, size, vocab_size):
        """Flatten every meaning into its token sequence.

        Returns ``(ms_data, ms_start, ms_len)`` as int32 ndarrays.
        """
        ms_data = list(range(vocab_size))  # leaf i is the single token i
        ms_start = list(range(vocab_size))
        ms_len = [1] * vocab_size
        for i in range(vocab_size, size):
            row = parts[i]
            begin = len(ms_data)
            for part in row[row > 0].tolist():
                # Parts are smaller meanings, so their sequences are already in
                # ms_data; append a copy instead of rebuilding a new list.
                s = ms_start[part]
                ms_data.extend(ms_data[s : s + ms_len[part]])
            ms_start.append(begin)
            ms_len.append(len(ms_data) - begin)
        return (
            np.array(ms_data).astype(np.int32),
            np.array(ms_start).astype(np.int32),
            np.array(ms_len).astype(np.int32),
        )

    def GetSequence(self, meaning):
        """Return the token sequence of ``meaning`` as a slice of ``ms_data``."""
        start = self.ms_start[meaning]
        length = self.ms_len[meaning]
        return self.ms_data[start : start + length]
|
|
|
|
|
|
|
|
|
|
|
|
class MeaningDataset(Dataset):
    """Map-style dataset of token sequences sampled from a ``MeaningMap``.

    ``size`` meanings are drawn uniformly from ``[start, end)``. Only meanings
    whose token sequence has more than one token are kept, and further draws
    top the collection up until exactly ``size`` sequences are held. Each item
    is a dict with ``input_ids``, ``labels`` (a copy of ``input_ids``) and
    all-zero ``token_type_ids`` of the same integer dtype and shape.
    """

    def __init__(self, start=131072, end=1048576, size=32768, vocab_size=4096, max_subitem=10, seed=42):
        """Build the meaning table and sample the dataset.

        Args:
            start: first meaning id eligible for sampling (inclusive).
            end: last meaning id eligible for sampling (exclusive); also the
                number of meanings in the underlying ``MeaningMap``.
            size: number of sequences in the dataset.
            vocab_size: forwarded to ``MeaningMap``.
            max_subitem: forwarded to ``MeaningMap``.
            seed: seed for NumPy's global random state.

        Raises:
            ValueError: if more sequences are needed but no meaning in
                ``[start, end)`` has a sequence longer than one token.
        """
        self.seed = seed
        np.random.seed(seed)
        self.size = size
        self.mm = MeaningMap(size=end, vocab_size=vocab_size, max_subitem=max_subitem)  # 1048576

        # Building the map on a cache miss consumes the global NumPy RNG while
        # loading it from disk does not. Re-seed so the sampled meanings are
        # the same whether or not the cache already existed.
        np.random.seed(seed)

        self.data = []
        for m in np.random.randint(start, end, size=(size)):
            sq = self.mm.GetSequence(m)
            if len(sq) > 1:
                self.data.append(sq)

        left = size - len(self.data)
        # Fail fast instead of looping forever when no candidate is usable.
        if left > 0 and not (np.asarray(self.mm.ms_len)[start:end] > 1).any():
            raise ValueError(
                f"no meaning in [{start}, {end}) has a sequence longer than one token"
            )
        while left > 0:
            sq = self.mm.GetSequence(np.random.randint(start, end))
            if len(sq) > 1:
                self.data.append(sq)
                left -= 1

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        data = torch.tensor(self.data[idx]).long()
        return {
            "input_ids": data,
            "labels": data.clone(),
            # Token-type ids index an embedding table, so keep them integral
            # (same dtype/shape as input_ids) rather than torch's default float.
            "token_type_ids": torch.zeros_like(data),
        }
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke check: build a small dataset over meanings 4096..4099 and
    # print the token ids of its first ten samples.
    dataset = MeaningDataset(4096, 4100, size=32768)
    for sample_index in range(10):
        token_ids = dataset[sample_index]["input_ids"].numpy().tolist()
        print(token_ids)
|