Update inference and val dataset.

Colin 2025-02-21 17:28:21 +08:00
parent 383c40afd7
commit 81f9e54ca3
3 changed files with 54 additions and 13 deletions


@@ -60,3 +60,44 @@ def InitDataset(config):
         )
         val_dataloader = BatchGroupMeaningDataloader(val_dataset, val_batch_size).dataloader(config.dataloader_works)
         return train_dataloader, val_dataloader
+
+
+def InitValDataset(config):
+    val_batch_size = config.val_batch_size
+    num_proc = config.num_proc
+    if config.dataset.name == "special":
+        raw_dataset = SpecialDataset()
+        train_dataset, val_dataset = random_split(raw_dataset, [0.95, 0.05])
+        val_dataloader = DataLoader(
+            val_dataset,
+            batch_size=val_batch_size,
+            num_workers=num_proc,
+            persistent_workers=True,
+        )
+        return val_dataloader
+
+    if config.dataset.name == "meaning":
+        c = config.dataset.meaning
+        vocab = config.model_config.vocab_size
+        start = vocab * (c.level_ratio**c.level)
+        size = vocab * int((c.level_ratio**c.dataset_level))
+        path = "./data/"
+        valfile = path + f"MeaningDataset_val_v{size}_s{start}_s{size}_lr{c.level_ratio}_ms{c.min_subitem}.pt"
+        if not os.path.exists(path):
+            os.mkdir(path)
+        if os.path.exists(valfile):
+            print(f"INFO: Load dataset from {valfile}")
+            val_dataset = torch.load(valfile, weights_only=False)
+            val_dataset.set_mask(c.val_mask_level, c.val_mask_idx)
+            print(f"INFO: Load dataset end")
+        else:
+            raw_dataset = MeaningDataset(start, start + size, vocab, None, c.level_ratio, c.min_subitem)
+            raw_dataset.set_mask(c.val_mask_level, c.val_mask_idx)
+            train_dataset, val_dataset = raw_dataset.split(0.9)
+            torch.save(val_dataset, valfile)
+            print(f"INFO: Build and save dataset end")
+        val_dataloader = BatchGroupMeaningDataloader(val_dataset, val_batch_size).dataloader(config.dataloader_works)
+        return val_dataloader
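
Note on the new InitValDataset: it builds the validation split once, serializes it under ./data/, and reloads the cached .pt on later runs, so inference no longer pays for MeaningDataset construction. A minimal sketch of that cache-or-build pattern, with load_or_build as an illustrative helper name (not part of this repo):

import os
import torch

def load_or_build(valfile, build_fn):
    # Reuse the serialized dataset when present; otherwise build and persist it.
    if os.path.exists(valfile):
        return torch.load(valfile, weights_only=False)
    dataset = build_fn()
    torch.save(dataset, valfile)
    return dataset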


@@ -393,8 +393,8 @@ class MeaningDataset(Dataset):
         return freq
 
     def get_seq_mask(self, idx, level, index):
-        assert index < 15, "index must < 15"
-        assert level < 8, "level must < 8"
+        # assert index < 15, "index must < 15"
+        # assert level < 8, "level must < 8"
         rank_idx = (self.rank_idx[idx] >> (4 * level)).astype(np.int32) & 0xF
         rank_all = (self.rank_all[idx] >> (4 * level)).astype(np.int32) & 0xF
         return rank_idx == (rank_all + index if index < 0 else index)
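
The commented-out asserts guarded the nibble packing that get_seq_mask decodes: each tree level stores a 4-bit rank inside rank_idx/rank_all, so level L is recovered by shifting 4 * L bits and masking with 0xF, and eight nibbles fill a 32-bit word, hence the old bounds. An illustrative decode with made-up values (not repo data):

import numpy as np

# Pack ranks 1, 2, 3 into the nibbles for levels 0, 1, 2.
packed = np.uint32((3 << 8) | (2 << 4) | 1)
level = 1
rank = (packed >> (4 * level)) & 0xF  # extract the level-1 nibble
assert rank == 2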


@@ -4,6 +4,7 @@ import torch
 from model.qwen_module import QwenModule
 from model.modeling_wit import QwenRunner
 from model.tokenization_qwen import QWenTokenizer
+import numpy as np
 
 import configuration
 import dataset.dataset as ds
@@ -37,20 +38,19 @@ if __name__ == "__main__":
     torch.manual_seed(conf.seed)
 
-    qwen = QwenModule.load_from_checkpoint(checkpoint_path = "log/bigger/version_1/checkpoints/epoch=26-step=27891.ckpt")
+    checkpoint_path = "log/bigger/version_1/checkpoints/epoch=23-step=24792.ckpt"
+    qwen = QwenModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
     qwen.eval()
 
     runner = QwenRunner(qwen.llm)
 
-    train_dataloader, val_dataloader = ds.InitDataset(conf)
-    it = iter(val_dataloader)
-    batch = next(it)
-
-    fdsafd = batch["input_ids"].numpy()
-
-    print(batch["input_ids"].numpy())
-    print(batch["input_ids"][0:1,:-1].numpy())
-    next_token = runner.ChatToken(batch["input_ids"][0:1,:-1].cuda())
+    val = ds.InitValDataset(conf).dataset
+    data = val.dataset
+    item = data.get_token(0)
+    print(data.print_tree(0))
+    batch = torch.tensor([item[:-1]], dtype=torch.int64)
+    batch = batch.cuda()
+    print(item)
+    next_token = runner.ChatToken(batch)
     print(next_token.detach().cpu().numpy())
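
Since the input batch is item[:-1], the trimmed final token is the ground truth for the predicted next token. A natural follow-up check, sketched under the assumption that ChatToken returns a single-token tensor (the diff does not confirm its output shape):

expected = item[-1]
predicted = next_token.detach().cpu().item()  # assumes a one-element tensor
print(f"expected={expected} predicted={predicted} match={predicted == expected}")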