Dump dataset to single file.

This commit is contained in:
Colin 2024-04-07 17:32:21 +08:00
parent 9d3b9a210a
commit 7560166b76
1 changed file with 202 additions and 14 deletions

View File

@ -19,15 +19,14 @@ class MeaningMap:
path = "./data/"
file = "structured_language_" + str(size) + "_" + str(vocab_size) + "_" + str(max_subitem)
file = path + file
file_slhwm = file + "_slhwm" + ".npy"
file_dli = file + "_dli" + ".npy"
if not os.path.exists(path):
os.mkdir(path)
if os.path.exists(file_slhwm) and os.path.exists(file_dli) and use_cache:
if os.path.exists(file) and use_cache:
print("Load from disk cache: " + file)
slhwm = np.load(file_slhwm)
dli = np.load(file_dli)
loaded = np.load(file)
slhwm = loaded["slhwm"]
dli = loaded["dli"]
self.ms_map = slhwm[:, 4:]
self.ms_data = dli[:, 0]
self.ms_start = slhwm[:, 0]
@ -150,9 +149,7 @@ class MeaningMap:
axis=1,
)
dli = np.stack((ms_data, ms_level, ms_idx.astype(np.int32)), axis=1)
np.save(file_slhwm, slhwm)
np.save(file_dli, dli)
np.savez(file, slhwm=slhwm, dli=dli)
self.ms_map = map # ms_map[i] = [sub(i),sub(i),sub(i),sub(i)...sub(i)]
self.ms_data = ms_data # map[i]=ms_data[ms_start[i]:ms_start[i]+ms_len[i]]
@ -390,13 +387,206 @@ class BatchGroupMeaningDataloader(Dataset):
if __name__ == "__main__":
md = MeaningDataset(100000, 115200, vocab_size=128, size=1024, use_cache=False)
md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, use_cache=True)
train, val = md.split(0.95)
fdaf = md.__getitem__(920)
print(md.print_tree(920))
print(md.idx[920])
fdasfe = MeaningDataset.get_seq_mask(md.idx[920], 1, 1)
print(fdasfe)
mask = MeaningDataset.get_seq_mask(md.idx[920], 1, 1)
print(mask)
assert mask == [
False,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
False,
False,
False,
False,
False,
False,
False,
False,
True,
True,
True,
True,
True,
True,
True,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
True,
True,
False,
False,
False,
False,
False,
False,
True,
True,
True,
True,
True,
True,
True,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
True,
True,
True,
True,
True,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
True,
True,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
True,
True,
True,
True,
True,
True,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
], "False"
freq = md.token_frequency()
dl = BatchGroupMeaningDataloader(train, 2)
@ -422,6 +612,4 @@ if __name__ == "__main__":
ne3 = next(it)
for i in range(10):
daf = next(it)["input_ids"].numpy().tolist()
print(daf)
print(next(it)["input_ids"].numpy().tolist())