Dump dataset to single file.
This commit is contained in:
parent
9d3b9a210a
commit
7560166b76
|
@ -19,15 +19,14 @@ class MeaningMap:
|
|||
path = "./data/"
|
||||
file = "structured_language_" + str(size) + "_" + str(vocab_size) + "_" + str(max_subitem)
|
||||
file = path + file
|
||||
file_slhwm = file + "_slhwm" + ".npy"
|
||||
file_dli = file + "_dli" + ".npy"
|
||||
|
||||
if not os.path.exists(path):
|
||||
os.mkdir(path)
|
||||
if os.path.exists(file_slhwm) and os.path.exists(file_dli) and use_cache:
|
||||
if os.path.exists(file) and use_cache:
|
||||
print("Load from disk cache: " + file)
|
||||
slhwm = np.load(file_slhwm)
|
||||
dli = np.load(file_dli)
|
||||
loaded = np.load(file)
|
||||
slhwm = loaded["slhwm"]
|
||||
dli = loaded["dli"]
|
||||
self.ms_map = slhwm[:, 4:]
|
||||
self.ms_data = dli[:, 0]
|
||||
self.ms_start = slhwm[:, 0]
|
||||
|
@ -150,9 +149,7 @@ class MeaningMap:
|
|||
axis=1,
|
||||
)
|
||||
dli = np.stack((ms_data, ms_level, ms_idx.astype(np.int32)), axis=1)
|
||||
|
||||
np.save(file_slhwm, slhwm)
|
||||
np.save(file_dli, dli)
|
||||
np.savez(file, slhwm=slhwm, dli=dli)
|
||||
|
||||
self.ms_map = map # ms_map[i] = [sub(i),sub(i),sub(i),sub(i)...sub(i)]
|
||||
self.ms_data = ms_data # map[i]=ms_data[ms_start[i]:ms_start[i]+ms_len[i]]
|
||||
|
@ -390,13 +387,206 @@ class BatchGroupMeaningDataloader(Dataset):
|
|||
|
||||
if __name__ == "__main__":
|
||||
|
||||
md = MeaningDataset(100000, 115200, vocab_size=128, size=1024, use_cache=False)
|
||||
md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, use_cache=True)
|
||||
train, val = md.split(0.95)
|
||||
fdaf = md.__getitem__(920)
|
||||
print(md.print_tree(920))
|
||||
print(md.idx[920])
|
||||
fdasfe = MeaningDataset.get_seq_mask(md.idx[920], 1, 1)
|
||||
print(fdasfe)
|
||||
mask = MeaningDataset.get_seq_mask(md.idx[920], 1, 1)
|
||||
print(mask)
|
||||
assert mask == [
|
||||
False,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
True,
|
||||
True,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
True,
|
||||
True,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
], "False"
|
||||
freq = md.token_frequency()
|
||||
|
||||
dl = BatchGroupMeaningDataloader(train, 2)
|
||||
|
@ -422,6 +612,4 @@ if __name__ == "__main__":
|
|||
ne3 = next(it)
|
||||
|
||||
for i in range(10):
|
||||
daf = next(it)["input_ids"].numpy().tolist()
|
||||
|
||||
print(daf)
|
||||
print(next(it)["input_ids"].numpy().tolist())
|
||||
|
|
Loading…
Reference in New Issue