Dump dataset to single file.
This commit is contained in:
parent
9d3b9a210a
commit
7560166b76
|
@ -19,15 +19,14 @@ class MeaningMap:
|
||||||
path = "./data/"
|
path = "./data/"
|
||||||
file = "structured_language_" + str(size) + "_" + str(vocab_size) + "_" + str(max_subitem)
|
file = "structured_language_" + str(size) + "_" + str(vocab_size) + "_" + str(max_subitem)
|
||||||
file = path + file
|
file = path + file
|
||||||
file_slhwm = file + "_slhwm" + ".npy"
|
|
||||||
file_dli = file + "_dli" + ".npy"
|
|
||||||
|
|
||||||
if not os.path.exists(path):
|
if not os.path.exists(path):
|
||||||
os.mkdir(path)
|
os.mkdir(path)
|
||||||
if os.path.exists(file_slhwm) and os.path.exists(file_dli) and use_cache:
|
if os.path.exists(file) and use_cache:
|
||||||
print("Load from disk cache: " + file)
|
print("Load from disk cache: " + file)
|
||||||
slhwm = np.load(file_slhwm)
|
loaded = np.load(file)
|
||||||
dli = np.load(file_dli)
|
slhwm = loaded["slhwm"]
|
||||||
|
dli = loaded["dli"]
|
||||||
self.ms_map = slhwm[:, 4:]
|
self.ms_map = slhwm[:, 4:]
|
||||||
self.ms_data = dli[:, 0]
|
self.ms_data = dli[:, 0]
|
||||||
self.ms_start = slhwm[:, 0]
|
self.ms_start = slhwm[:, 0]
|
||||||
|
@ -150,9 +149,7 @@ class MeaningMap:
|
||||||
axis=1,
|
axis=1,
|
||||||
)
|
)
|
||||||
dli = np.stack((ms_data, ms_level, ms_idx.astype(np.int32)), axis=1)
|
dli = np.stack((ms_data, ms_level, ms_idx.astype(np.int32)), axis=1)
|
||||||
|
np.savez(file, slhwm=slhwm, dli=dli)
|
||||||
np.save(file_slhwm, slhwm)
|
|
||||||
np.save(file_dli, dli)
|
|
||||||
|
|
||||||
self.ms_map = map # ms_map[i] = [sub(i),sub(i),sub(i),sub(i)...sub(i)]
|
self.ms_map = map # ms_map[i] = [sub(i),sub(i),sub(i),sub(i)...sub(i)]
|
||||||
self.ms_data = ms_data # map[i]=ms_data[ms_start[i]:ms_start[i]+ms_len[i]]
|
self.ms_data = ms_data # map[i]=ms_data[ms_start[i]:ms_start[i]+ms_len[i]]
|
||||||
|
@ -390,13 +387,206 @@ class BatchGroupMeaningDataloader(Dataset):
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
||||||
md = MeaningDataset(100000, 115200, vocab_size=128, size=1024, use_cache=False)
|
md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, use_cache=True)
|
||||||
train, val = md.split(0.95)
|
train, val = md.split(0.95)
|
||||||
fdaf = md.__getitem__(920)
|
fdaf = md.__getitem__(920)
|
||||||
print(md.print_tree(920))
|
print(md.print_tree(920))
|
||||||
print(md.idx[920])
|
print(md.idx[920])
|
||||||
fdasfe = MeaningDataset.get_seq_mask(md.idx[920], 1, 1)
|
mask = MeaningDataset.get_seq_mask(md.idx[920], 1, 1)
|
||||||
print(fdasfe)
|
print(mask)
|
||||||
|
assert mask == [
|
||||||
|
False,
|
||||||
|
True,
|
||||||
|
True,
|
||||||
|
True,
|
||||||
|
True,
|
||||||
|
True,
|
||||||
|
True,
|
||||||
|
True,
|
||||||
|
True,
|
||||||
|
True,
|
||||||
|
True,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
True,
|
||||||
|
True,
|
||||||
|
True,
|
||||||
|
True,
|
||||||
|
True,
|
||||||
|
True,
|
||||||
|
True,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
True,
|
||||||
|
True,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
True,
|
||||||
|
True,
|
||||||
|
True,
|
||||||
|
True,
|
||||||
|
True,
|
||||||
|
True,
|
||||||
|
True,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
True,
|
||||||
|
True,
|
||||||
|
True,
|
||||||
|
True,
|
||||||
|
True,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
True,
|
||||||
|
True,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
True,
|
||||||
|
True,
|
||||||
|
True,
|
||||||
|
True,
|
||||||
|
True,
|
||||||
|
True,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
], "False"
|
||||||
freq = md.token_frequency()
|
freq = md.token_frequency()
|
||||||
|
|
||||||
dl = BatchGroupMeaningDataloader(train, 2)
|
dl = BatchGroupMeaningDataloader(train, 2)
|
||||||
|
@ -422,6 +612,4 @@ if __name__ == "__main__":
|
||||||
ne3 = next(it)
|
ne3 = next(it)
|
||||||
|
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
daf = next(it)["input_ids"].numpy().tolist()
|
print(next(it)["input_ids"].numpy().tolist())
|
||||||
|
|
||||||
print(daf)
|
|
||||||
|
|
Loading…
Reference in New Issue