Fix dataset.

Colin 2024-04-14 01:27:58 +08:00
parent 6791264987
commit 9926b893a4
1 changed file with 13 additions and 6 deletions

@@ -12,13 +12,17 @@ import copy
 class MeaningMap:
-    def __init__(self, size=1048576, vocab_size=4096, max_subitem=10, use_cache=True):
+    def __init__(self, size=1048576, vocab_size=4096, max_subitem=10, min_subitem=1, use_cache=True):
+        assert size > 0 and vocab_size > 0 and max_subitem > 0 and min_subitem > 0, "Invalid input"
+        assert min_subitem <= max_subitem, "Invalid input"
         self.size = size
         self.vocab_size = vocab_size
         self.max_subitem = max_subitem
+        self.min_subitem = min_subitem
         path = "./data/"
-        file = "structured_language_" + str(size) + "_" + str(vocab_size) + "_" + str(max_subitem)
+        file = "structured_language_" + str(size) + "_" + str(vocab_size)
+        file += "_" + str(max_subitem) + "_" + str(min_subitem)
         file = path + file + ".npz"
         if not os.path.exists(path):
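The cache file name now encodes min_subitem alongside max_subitem, so maps built with different minimum branching factors no longer collide on disk. A quick standalone illustration of the name the new code produces (example values assumed, not from the repo):

size, vocab_size, max_subitem, min_subitem = 1048576, 4096, 10, 2
file = "structured_language_" + str(size) + "_" + str(vocab_size)
file += "_" + str(max_subitem) + "_" + str(min_subitem)
print("./data/" + file + ".npz")  # ./data/structured_language_1048576_4096_10_2.npz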
@@ -47,7 +51,7 @@ class MeaningMap:
         map = np.random.random((size, max_subitem))
         mask_zero = map.copy()
-        mask_zero[:, 0] = 0.0
+        mask_zero[:, 0:min_subitem] = 0.0
         mask_zero.sort(axis=1)
         thre = np.random.random((size)).reshape(-1, 1).repeat(max_subitem, axis=1)
         mask_zero = mask_zero > thre
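Pinning the first min_subitem columns to 0.0 before the sort is what enforces the new lower bound: a 0.0 can never exceed a threshold drawn from [0, 1), so at least min_subitem positions per row always survive the comparison. A minimal standalone sketch of that invariant (toy sizes, not from the repo):

import numpy as np

size, max_subitem, min_subitem = 5, 10, 3
mask_zero = np.random.random((size, max_subitem))
mask_zero[:, 0:min_subitem] = 0.0              # pin min_subitem slots per row
mask_zero.sort(axis=1)                         # zeros sort to the front
thre = np.random.random((size,)).reshape(-1, 1).repeat(max_subitem, axis=1)
masked_out = mask_zero > thre                  # True = sub-item dropped
assert ((~masked_out).sum(axis=1) >= min_subitem).all()  # lower bound holds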
@@ -261,12 +265,13 @@ class MeaningDataset(Dataset):
         size,
         vocab_size,
         max_subitem=10,
+        min_subitem=1,
         min_seq_len=2,
         seed=42,
         use_cache=True,
     ):
         np.random.seed(seed)
-        map = MeaningMap(size=end, vocab_size=vocab_size, max_subitem=max_subitem, use_cache=use_cache)
+        map = MeaningMap(end, vocab_size, max_subitem, min_subitem, use_cache=use_cache)
         np.random.seed(seed)
         self.mask_level = None
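The MeaningMap call also switches from keyword to positional arguments; this stays correct because min_subitem was inserted before use_cache in the signature, so the positions line up. A toy check (stub function mirroring the diff's signature, not the real class):

def meaning_map(size=1048576, vocab_size=4096, max_subitem=10, min_subitem=1, use_cache=True):
    return (size, vocab_size, max_subitem, min_subitem, use_cache)

end = 115200  # assumed value
assert meaning_map(end, 1024, 10, 2, use_cache=False) == meaning_map(
    size=end, vocab_size=1024, max_subitem=10, min_subitem=2, use_cache=False
)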
@@ -381,7 +386,7 @@ class MeaningDataset(Dataset):
             )
             for i, l in enumerate(self.mask_level[1:]):
                 mask = mask & torch.tensor(
-                    np.stack([self.get_seq_mask(idx, l, self.mask_idx[i]) for idx in idx_list], axis=0)
+                    np.stack([self.get_seq_mask(idx, l, self.mask_idx[i + 1]) for idx in idx_list], axis=0)
                 )
             return mask
         else:
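This hunk fixes an off-by-one: enumerate() restarts from 0 even over the sliced list, so mask_idx[i] would re-read the entry already consumed by the first level before the loop. A minimal sketch of the pairing, using the same values as the smoke test below (toy lists, not from the repo):

mask_level = [0, 1]  # levels to mask; level 0 is handled before the loop
mask_idx = [0, -1]   # per-level mask index, parallel to mask_level
for i, l in enumerate(mask_level[1:]):   # i == 0 while l == mask_level[1]
    assert mask_idx[i + 1] == -1         # fixed: level 1 gets its own index
    assert mask_idx[i] == 0              # old code reused level 0's index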
@@ -529,7 +534,9 @@ if __name__ == "__main__":
     ), "False"
     freq = md.token_frequency()
-    dl = BatchGroupMeaningDataloader(val, 1)
+    md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, min_subitem=2, use_cache=False)
+    md.set_mask([0, 1], [0, -1])
+    dl = BatchGroupMeaningDataloader(md, 1)
     length = len(dl)
     it = iter(dl)
     ne1 = next(it)