Fix dataset.
This commit is contained in:
parent
6791264987
commit
9926b893a4
|
@ -12,13 +12,17 @@ import copy
|
||||||
|
|
||||||
|
|
||||||
class MeaningMap:
|
class MeaningMap:
|
||||||
def __init__(self, size=1048576, vocab_size=4096, max_subitem=10, use_cache=True):
|
def __init__(self, size=1048576, vocab_size=4096, max_subitem=10, min_subitem=1, use_cache=True):
|
||||||
|
assert size > 0 and vocab_size > 0 and max_subitem > 0 and min_subitem > 0, "Invalid input"
|
||||||
|
assert min_subitem <= max_subitem, "Invalid input"
|
||||||
self.size = size
|
self.size = size
|
||||||
self.vocab_size = vocab_size
|
self.vocab_size = vocab_size
|
||||||
self.max_subitem = max_subitem
|
self.max_subitem = max_subitem
|
||||||
|
self.min_subitem = min_subitem
|
||||||
|
|
||||||
path = "./data/"
|
path = "./data/"
|
||||||
file = "structured_language_" + str(size) + "_" + str(vocab_size) + "_" + str(max_subitem)
|
file = "structured_language_" + str(size) + "_" + str(vocab_size)
|
||||||
|
file += "_" + str(max_subitem) + "_" + str(min_subitem)
|
||||||
file = path + file + ".npz"
|
file = path + file + ".npz"
|
||||||
|
|
||||||
if not os.path.exists(path):
|
if not os.path.exists(path):
|
||||||
|
@ -47,7 +51,7 @@ class MeaningMap:
|
||||||
map = np.random.random((size, max_subitem))
|
map = np.random.random((size, max_subitem))
|
||||||
|
|
||||||
mask_zero = map.copy()
|
mask_zero = map.copy()
|
||||||
mask_zero[:, 0] = 0.0
|
mask_zero[:, 0:min_subitem] = 0.0
|
||||||
mask_zero.sort(axis=1)
|
mask_zero.sort(axis=1)
|
||||||
thre = np.random.random((size)).reshape(-1, 1).repeat(max_subitem, axis=1)
|
thre = np.random.random((size)).reshape(-1, 1).repeat(max_subitem, axis=1)
|
||||||
mask_zero = mask_zero > thre
|
mask_zero = mask_zero > thre
|
||||||
|
@ -261,12 +265,13 @@ class MeaningDataset(Dataset):
|
||||||
size,
|
size,
|
||||||
vocab_size,
|
vocab_size,
|
||||||
max_subitem=10,
|
max_subitem=10,
|
||||||
|
min_subitem=1,
|
||||||
min_seq_len=2,
|
min_seq_len=2,
|
||||||
seed=42,
|
seed=42,
|
||||||
use_cache=True,
|
use_cache=True,
|
||||||
):
|
):
|
||||||
np.random.seed(seed)
|
np.random.seed(seed)
|
||||||
map = MeaningMap(size=end, vocab_size=vocab_size, max_subitem=max_subitem, use_cache=use_cache)
|
map = MeaningMap(end, vocab_size, max_subitem, min_subitem, use_cache=use_cache)
|
||||||
np.random.seed(seed)
|
np.random.seed(seed)
|
||||||
|
|
||||||
self.mask_level = None
|
self.mask_level = None
|
||||||
|
@ -381,7 +386,7 @@ class MeaningDataset(Dataset):
|
||||||
)
|
)
|
||||||
for i, l in enumerate(self.mask_level[1:]):
|
for i, l in enumerate(self.mask_level[1:]):
|
||||||
mask = mask & torch.tensor(
|
mask = mask & torch.tensor(
|
||||||
np.stack([self.get_seq_mask(idx, l, self.mask_idx[i]) for idx in idx_list], axis=0)
|
np.stack([self.get_seq_mask(idx, l, self.mask_idx[i + 1]) for idx in idx_list], axis=0)
|
||||||
)
|
)
|
||||||
return mask
|
return mask
|
||||||
else:
|
else:
|
||||||
|
@ -529,7 +534,9 @@ if __name__ == "__main__":
|
||||||
), "False"
|
), "False"
|
||||||
freq = md.token_frequency()
|
freq = md.token_frequency()
|
||||||
|
|
||||||
dl = BatchGroupMeaningDataloader(val, 1)
|
md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, min_subitem=2, use_cache=False)
|
||||||
|
md.set_mask([0, 1], [0, -1])
|
||||||
|
dl = BatchGroupMeaningDataloader(md, 1)
|
||||||
length = len(dl)
|
length = len(dl)
|
||||||
it = iter(dl)
|
it = iter(dl)
|
||||||
ne1 = next(it)
|
ne1 = next(it)
|
||||||
|
|
Loading…
Reference in New Issue