Update inference and val dataset.
This commit is contained in:
parent
383c40afd7
commit
81f9e54ca3
|
@ -60,3 +60,44 @@ def InitDataset(config):
|
|||
)
|
||||
val_dataloader = BatchGroupMeaningDataloader(val_dataset, val_batch_size).dataloader(config.dataloader_works)
|
||||
return train_dataloader, val_dataloader
|
||||
|
||||
|
||||
def InitValDataset(config):
|
||||
|
||||
val_batch_size = config.val_batch_size
|
||||
num_proc = config.num_proc
|
||||
if config.dataset.name == "special":
|
||||
raw_dataset = SpecialDataset()
|
||||
train_dataset, val_dataset = random_split(raw_dataset, [0.95, 0.05])
|
||||
val_dataloader = DataLoader(
|
||||
val_dataset,
|
||||
batch_size=val_batch_size,
|
||||
num_workers=num_proc,
|
||||
persistent_workers=True,
|
||||
)
|
||||
return val_dataloader
|
||||
|
||||
if config.dataset.name == "meaning":
|
||||
c = config.dataset.meaning
|
||||
vocab = config.model_config.vocab_size
|
||||
start = vocab * (c.level_ratio**c.level)
|
||||
size = vocab * int((c.level_ratio**c.dataset_level))
|
||||
|
||||
path = "./data/"
|
||||
valfile = path + f"MeaningDataset_val_v{size}_s{start}_s{size}_lr{c.level_ratio}_ms{c.min_subitem}.pt"
|
||||
if not os.path.exists(path):
|
||||
os.mkdir(path)
|
||||
if os.path.exists(valfile):
|
||||
print(f"INFO: Load dataset from {valfile}")
|
||||
val_dataset = torch.load(valfile, weights_only=False)
|
||||
val_dataset.set_mask(c.val_mask_level, c.val_mask_idx)
|
||||
print(f"INFO: Load dataset end")
|
||||
else:
|
||||
raw_dataset = MeaningDataset(start, start + size, vocab, None, c.level_ratio, c.min_subitem)
|
||||
raw_dataset.set_mask(c.val_mask_level, c.val_mask_idx)
|
||||
train_dataset, val_dataset = raw_dataset.split(0.9)
|
||||
torch.save(val_dataset, valfile)
|
||||
print(f"INFO: Build and save dataset end")
|
||||
|
||||
val_dataloader = BatchGroupMeaningDataloader(val_dataset, val_batch_size).dataloader(config.dataloader_works)
|
||||
return val_dataloader
|
||||
|
|
|
@ -393,8 +393,8 @@ class MeaningDataset(Dataset):
|
|||
return freq
|
||||
|
||||
def get_seq_mask(self, idx, level, index):
|
||||
assert index < 15, "index must < 15"
|
||||
assert level < 8, "level must < 8"
|
||||
# assert index < 15, "index must < 15"
|
||||
# assert level < 8, "level must < 8"
|
||||
rank_idx = (self.rank_idx[idx] >> (4 * level)).astype(np.int32) & 0xF
|
||||
rank_all = (self.rank_all[idx] >> (4 * level)).astype(np.int32) & 0xF
|
||||
return rank_idx == (rank_all + index if index < 0 else index)
|
||||
|
|
|
@ -4,6 +4,7 @@ import torch
|
|||
from model.qwen_module import QwenModule
|
||||
from model.modeling_wit import QwenRunner
|
||||
from model.tokenization_qwen import QWenTokenizer
|
||||
import numpy as np
|
||||
|
||||
import configuration
|
||||
import dataset.dataset as ds
|
||||
|
@ -37,20 +38,19 @@ if __name__ == "__main__":
|
|||
|
||||
torch.manual_seed(conf.seed)
|
||||
|
||||
qwen = QwenModule.load_from_checkpoint(checkpoint_path = "log/bigger/version_1/checkpoints/epoch=26-step=27891.ckpt")
|
||||
checkpoint_path = "log/bigger/version_1/checkpoints/epoch=23-step=24792.ckpt"
|
||||
qwen = QwenModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
|
||||
qwen.eval()
|
||||
|
||||
runner = QwenRunner(qwen.llm)
|
||||
|
||||
train_dataloader, val_dataloader = ds.InitDataset(conf)
|
||||
it = iter(val_dataloader)
|
||||
batch = next(it)
|
||||
val = ds.InitValDataset(conf).dataset
|
||||
data = val.dataset
|
||||
item = data.get_token(0)
|
||||
print(data.print_tree(0))
|
||||
|
||||
fdsafd = batch["input_ids"].numpy()
|
||||
|
||||
|
||||
print(batch["input_ids"].numpy())
|
||||
print(batch["input_ids"][0:1,:-1].numpy())
|
||||
next_token = runner.ChatToken(batch["input_ids"][0:1,:-1].cuda())
|
||||
batch = torch.tensor([item[:-1]], dtype=torch.int64)
|
||||
batch = batch.cuda()
|
||||
print(item)
|
||||
next_token = runner.ChatToken(batch)
|
||||
print(next_token.detach().cpu().numpy())
|
||||
|
||||
|
|
Loading…
Reference in New Issue