diff --git a/wit/meaning_dataset.py b/wit/meaning_dataset.py index 4eb658a..d8d18d7 100644 --- a/wit/meaning_dataset.py +++ b/wit/meaning_dataset.py @@ -19,15 +19,14 @@ class MeaningMap: path = "./data/" file = "structured_language_" + str(size) + "_" + str(vocab_size) + "_" + str(max_subitem) - file = path + file - file_slhwm = file + "_slhwm" + ".npy" - file_dli = file + "_dli" + ".npy" + file = path + file + ".npz"  # np.savez appends ".npz" when the name lacks it, so the cache check and np.load must use the suffixed name. if not os.path.exists(path): os.mkdir(path) - if os.path.exists(file_slhwm) and os.path.exists(file_dli) and use_cache: + if os.path.exists(file) and use_cache: print("Load from disk cache: " + file) - slhwm = np.load(file_slhwm) - dli = np.load(file_dli) + loaded = np.load(file) + slhwm = loaded["slhwm"] + dli = loaded["dli"] self.ms_map = slhwm[:, 4:] self.ms_data = dli[:, 0] self.ms_start = slhwm[:, 0] @@ -150,9 +149,7 @@ class MeaningMap: axis=1, ) dli = np.stack((ms_data, ms_level, ms_idx.astype(np.int32)), axis=1) - - np.save(file_slhwm, slhwm) - np.save(file_dli, dli) + np.savez(file, slhwm=slhwm, dli=dli) self.ms_map = map # ms_map[i] = [sub(i),sub(i),sub(i),sub(i)...sub(i)] self.ms_data = ms_data # map[i]=ms_data[ms_start[i]:ms_start[i]+ms_len[i]] @@ -390,13 +387,206 @@ class BatchGroupMeaningDataloader(Dataset): if __name__ == "__main__": - md = MeaningDataset(100000, 115200, vocab_size=128, size=1024, use_cache=False) + md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, use_cache=True) train, val = md.split(0.95) fdaf = md.__getitem__(920) print(md.print_tree(920)) print(md.idx[920]) - fdasfe = MeaningDataset.get_seq_mask(md.idx[920], 1, 1) - print(fdasfe) + mask = MeaningDataset.get_seq_mask(md.idx[920], 1, 1) + print(mask) + assert mask == [ + False, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + True, 
True, + False, + False, + False, + False, + False, + False, + True, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + True, + True, + True, + True, + True, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], "False" freq = md.token_frequency() dl = BatchGroupMeaningDataloader(train, 2) @@ -422,6 +612,4 @@ if __name__ == "__main__": ne3 = next(it) for i in range(10): - daf = next(it)["input_ids"].numpy().tolist() - - print(daf) + print(next(it)["input_ids"].numpy().tolist())