Fix dataset.
parent 6791264987
commit 9926b893a4
@@ -12,13 +12,17 @@ import copy


 class MeaningMap:
-    def __init__(self, size=1048576, vocab_size=4096, max_subitem=10, use_cache=True):
+    def __init__(self, size=1048576, vocab_size=4096, max_subitem=10, min_subitem=1, use_cache=True):
+        assert size > 0 and vocab_size > 0 and max_subitem > 0 and min_subitem > 0, "Invalid input"
+        assert min_subitem <= max_subitem, "Invalid input"
         self.size = size
         self.vocab_size = vocab_size
         self.max_subitem = max_subitem
+        self.min_subitem = min_subitem

         path = "./data/"
-        file = "structured_language_" + str(size) + "_" + str(vocab_size) + "_" + str(max_subitem)
+        file = "structured_language_" + str(size) + "_" + str(vocab_size)
+        file += "_" + str(max_subitem) + "_" + str(min_subitem)
         file = path + file + ".npz"

         if not os.path.exists(path):
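Note on the hunk above: min_subitem is validated alongside the other arguments and now participates in the cache filename, so .npz caches built with different minimums cannot collide. A minimal sketch of the resulting key, using a hypothetical helper that mirrors the filename logic in the diff:

    # Hypothetical helper mirroring the cache-key construction above.
    def cache_file(size, vocab_size, max_subitem, min_subitem, path="./data/"):
        file = "structured_language_" + str(size) + "_" + str(vocab_size)
        file += "_" + str(max_subitem) + "_" + str(min_subitem)
        return path + file + ".npz"

    # e.g. the default MeaningMap settings map to:
    assert cache_file(1048576, 4096, 10, 1) == "./data/structured_language_1048576_4096_10_1.npz"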
@@ -47,7 +51,7 @@ class MeaningMap:
         map = np.random.random((size, max_subitem))

         mask_zero = map.copy()
-        mask_zero[:, 0] = 0.0
+        mask_zero[:, 0:min_subitem] = 0.0
         mask_zero.sort(axis=1)
         thre = np.random.random((size)).reshape(-1, 1).repeat(max_subitem, axis=1)
         mask_zero = mask_zero > thre
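Why this one-line change enforces the minimum: the pinned zeros sort to the front of each row and can never exceed the per-row threshold, so at least min_subitem slots survive per row (the old code pinned only column 0, guaranteeing just one). A standalone sketch, assuming, as the variable name suggests, that True in mask_zero marks a slot to drop:

    import numpy as np

    size, max_subitem, min_subitem = 5, 10, 3
    map = np.random.random((size, max_subitem))

    mask_zero = map.copy()
    mask_zero[:, 0:min_subitem] = 0.0   # pin min_subitem entries at the minimum
    mask_zero.sort(axis=1)              # the zeros land in the first columns
    thre = np.random.random((size)).reshape(-1, 1).repeat(max_subitem, axis=1)
    mask_zero = mask_zero > thre        # 0.0 > thre is always False for thre in [0, 1)

    # At most max_subitem - min_subitem slots can be marked for removal per row.
    assert (mask_zero.sum(axis=1) <= max_subitem - min_subitem).all()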
@@ -261,12 +265,13 @@ class MeaningDataset(Dataset):
         size,
         vocab_size,
         max_subitem=10,
+        min_subitem=1,
         min_seq_len=2,
         seed=42,
         use_cache=True,
     ):
-        np.random.seed(seed)
-        map = MeaningMap(size=end, vocab_size=vocab_size, max_subitem=max_subitem, use_cache=use_cache)
+        map = MeaningMap(end, vocab_size, max_subitem, min_subitem, use_cache=use_cache)
+        np.random.seed(seed)

         self.mask_level = None
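Besides threading min_subitem through (the MeaningMap arguments are now passed positionally), this hunk moves np.random.seed(seed) to after the map construction. Building the map can consume a data-dependent amount of RNG state (presumably it regenerates only on a cache miss), so seeding afterwards keeps the dataset's own sampling reproducible either way. A toy illustration, with build_map as a hypothetical stand-in for the MeaningMap constructor:

    import numpy as np

    def build_map(cache_hit):
        if not cache_hit:
            np.random.random(1000)  # generation consumes RNG state only on a miss

    def make_dataset(cache_hit, seed=42):
        build_map(cache_hit)
        np.random.seed(seed)        # seed *after* the map, as the diff now does
        return np.random.random(3)

    # Identical samples whether or not the map came from cache.
    assert (make_dataset(cache_hit=True) == make_dataset(cache_hit=False)).all()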
@@ -381,7 +386,7 @@ class MeaningDataset(Dataset):
            )
            for i, l in enumerate(self.mask_level[1:]):
                mask = mask & torch.tensor(
-                    np.stack([self.get_seq_mask(idx, l, self.mask_idx[i]) for idx in idx_list], axis=0)
+                    np.stack([self.get_seq_mask(idx, l, self.mask_idx[i + 1]) for idx in idx_list], axis=0)
                )
            return mask
        else:
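The indexing fix above is an off-by-one: the loop enumerates self.mask_level[1:], so i restarts at 0 while mask_idx stays aligned with the full mask_level list (index 0 is presumably consumed by the initial mask built just before the loop). mask_idx[i + 1] restores the pairing:

    # With the set_mask([0, 1], [0, -1]) call from __main__ below:
    mask_level = [0, 1]
    mask_idx = [0, -1]
    pairs = [(l, mask_idx[i + 1]) for i, l in enumerate(mask_level[1:])]
    assert pairs == [(1, -1)]   # level 1 pairs with -1, not with mask_idx[0] again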
@@ -529,7 +534,9 @@ if __name__ == "__main__":
    ), "False"
    freq = md.token_frequency()

-    dl = BatchGroupMeaningDataloader(val, 1)
+    md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, min_subitem=2, use_cache=False)
+    md.set_mask([0, 1], [0, -1])
+    dl = BatchGroupMeaningDataloader(md, 1)
    length = len(dl)
    it = iter(dl)
    ne1 = next(it)
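The updated smoke test no longer reuses the earlier val split: it builds a fresh MeaningDataset with min_subitem=2 and use_cache=False, sets a two-level mask, and wraps it in BatchGroupMeaningDataloader, exercising both the new constructor argument and the corrected mask_idx pairing before iterating.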