From 81f9e54ca3b3fe55fac48aaf54cfbb85f185a3fe Mon Sep 17 00:00:00 2001
From: Colin
Date: Fri, 21 Feb 2025 17:28:21 +0800
Subject: [PATCH] Update inference and val dataset.

---
 wit/dataset/dataset.py         | 41 ++++++++++++++++++++++++++++++++++
 wit/dataset/meaning_dataset.py |  4 ++--
 wit/inference.py               | 22 +++++++++----------
 3 files changed, 54 insertions(+), 13 deletions(-)

diff --git a/wit/dataset/dataset.py b/wit/dataset/dataset.py
index 3ed5496..bdf6e85 100644
--- a/wit/dataset/dataset.py
+++ b/wit/dataset/dataset.py
@@ -60,3 +60,44 @@ def InitDataset(config):
     )
     val_dataloader = BatchGroupMeaningDataloader(val_dataset, val_batch_size).dataloader(config.dataloader_works)
     return train_dataloader, val_dataloader
+
+
+def InitValDataset(config):
+
+    val_batch_size = config.val_batch_size
+    num_proc = config.num_proc
+    if config.dataset.name == "special":
+        raw_dataset = SpecialDataset()
+        _, val_dataset = random_split(raw_dataset, [0.95, 0.05])
+        val_dataloader = DataLoader(
+            val_dataset,
+            batch_size=val_batch_size,
+            num_workers=num_proc,
+            persistent_workers=True,
+        )
+        return val_dataloader
+
+    if config.dataset.name == "meaning":
+        c = config.dataset.meaning
+        vocab = config.model_config.vocab_size
+        start = vocab * (c.level_ratio**c.level)
+        size = vocab * int(c.level_ratio**c.dataset_level)
+
+        path = "./data/"
+        valfile = path + f"MeaningDataset_val_v{vocab}_s{start}_s{size}_lr{c.level_ratio}_ms{c.min_subitem}.pt"
+        if not os.path.exists(path):
+            os.makedirs(path, exist_ok=True)
+        if os.path.exists(valfile):
+            print(f"INFO: Load dataset from {valfile}")
+            val_dataset = torch.load(valfile, weights_only=False)
+            val_dataset.set_mask(c.val_mask_level, c.val_mask_idx)
+            print("INFO: Load dataset end")
+        else:
+            raw_dataset = MeaningDataset(start, start + size, vocab, None, c.level_ratio, c.min_subitem)
+            raw_dataset.set_mask(c.val_mask_level, c.val_mask_idx)
+            _, val_dataset = raw_dataset.split(0.9)
+            torch.save(val_dataset, valfile)
+            print("INFO: Build and save dataset end")
+
+        val_dataloader = BatchGroupMeaningDataloader(val_dataset, val_batch_size).dataloader(config.dataloader_works)
+        return val_dataloader
diff --git a/wit/dataset/meaning_dataset.py b/wit/dataset/meaning_dataset.py
index 2289bdb..56c5400 100644
--- a/wit/dataset/meaning_dataset.py
+++ b/wit/dataset/meaning_dataset.py
@@ -393,8 +393,8 @@ class MeaningDataset(Dataset):
         return freq
 
     def get_seq_mask(self, idx, level, index):
-        assert index < 15, "index must < 15"
-        assert level < 8, "level must < 8"
+        # assert index < 15, "index must < 15"
+        # assert level < 8, "level must < 8"
         rank_idx = (self.rank_idx[idx] >> (4 * level)).astype(np.int32) & 0xF
         rank_all = (self.rank_all[idx] >> (4 * level)).astype(np.int32) & 0xF
         return rank_idx == (rank_all + index if index < 0 else index)
diff --git a/wit/inference.py b/wit/inference.py
index fcd94d5..2f2b0b7 100644
--- a/wit/inference.py
+++ b/wit/inference.py
@@ -4,6 +4,7 @@ import torch
 from model.qwen_module import QwenModule
 from model.modeling_wit import QwenRunner
 from model.tokenization_qwen import QWenTokenizer
+import numpy as np
 
 import configuration
 import dataset.dataset as ds
@@ -37,20 +38,19 @@
 
     torch.manual_seed(conf.seed)
 
-    qwen = QwenModule.load_from_checkpoint(checkpoint_path = "log/bigger/version_1/checkpoints/epoch=26-step=27891.ckpt")
+    checkpoint_path = "log/bigger/version_1/checkpoints/epoch=23-step=24792.ckpt"
+    qwen = QwenModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
     qwen.eval()
     runner = QwenRunner(qwen.llm)
 
-    train_dataloader, val_dataloader = ds.InitDataset(conf)
-    it = iter(val_dataloader)
-    batch = next(it)
+    val = ds.InitValDataset(conf).dataset
+    data = val.dataset
+    item = data.get_token(0)
+    print(data.print_tree(0))
 
-    fdsafd = batch["input_ids"].numpy()
-
-
-    print(batch["input_ids"].numpy())
-    print(batch["input_ids"][0:1,:-1].numpy())
-    next_token = runner.ChatToken(batch["input_ids"][0:1,:-1].cuda())
+    batch = torch.tensor([item[:-1]], dtype=torch.int64)
+    batch = batch.cuda()
+    print(item)
+    next_token = runner.ChatToken(batch)
     print(next_token.detach().cpu().numpy())
-
 
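Note on the get_seq_mask change: the commented-out asserts guarded the 4-bit rank
packing that the shift/mask lines decode. Each tree level occupies one nibble of
rank_idx/rank_all, so level L sits at bits [4*L, 4*L + 4); a nibble only holds
values 0..15, and eight levels exhaust the packed word, which is presumably why
the old bounds were index < 15 and level < 8. Below is a standalone sketch of
that decoding with hypothetical sample values. seq_mask here is an illustration,
not the dataset code; rank_all appears to hold each level's size, since negative
indices resolve as rank_all + index.

import numpy as np

def seq_mask(rank_idx, rank_all, level, index):
    """Select tokens whose rank at `level` equals `index`
    (negative `index` counts back from that level's size)."""
    # rank of every token at `level`, and the level's size, one nibble each
    idx = (rank_idx >> (4 * level)).astype(np.int32) & 0xF
    size = (rank_all >> (4 * level)).astype(np.int32) & 0xF
    return idx == (size + index if index < 0 else index)

# Two tokens; level 0 in the low nibble, level 1 in the next nibble.
rank_idx = np.array([0x10, 0x21], dtype=np.int64)  # token0: L0=0, L1=1; token1: L0=1, L1=2
rank_all = np.array([0x32, 0x32], dtype=np.int64)  # level sizes: L0=2, L1=3
print(seq_mask(rank_idx, rank_all, 0, 1))   # [False  True] -> L0 rank equals 1
print(seq_mask(rank_idx, rank_all, 1, -1))  # [False  True] -> last L1 position (3 - 1 = 2)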