Update inference and val dataset.
This commit is contained in:
parent
383c40afd7
commit
81f9e54ca3
|
@ -60,3 +60,44 @@ def InitDataset(config):
|
||||||
)
|
)
|
||||||
val_dataloader = BatchGroupMeaningDataloader(val_dataset, val_batch_size).dataloader(config.dataloader_works)
|
val_dataloader = BatchGroupMeaningDataloader(val_dataset, val_batch_size).dataloader(config.dataloader_works)
|
||||||
return train_dataloader, val_dataloader
|
return train_dataloader, val_dataloader
|
||||||
|
|
||||||
|
|
||||||
|
def InitValDataset(config):
    """Build only the validation dataloader described by ``config``.

    Reads ``config.val_batch_size``, ``config.num_proc`` and
    ``config.dataset.name``. Two dataset names are handled:

    * ``"special"``: a ``SpecialDataset`` is created and split 95% / 5% with
      ``random_split``; the 5% part is wrapped in a ``DataLoader``.
    * ``"meaning"``: a validation split is loaded from a cache file under
      ``./data/`` when that file exists; otherwise a ``MeaningDataset`` is
      built, split, and the validation part is saved to the cache file. The
      result is wrapped by ``BatchGroupMeaningDataloader``.

    Returns:
        The validation dataloader, or ``None`` when ``config.dataset.name``
        is neither ``"special"`` nor ``"meaning"``.
    """
    val_batch_size = config.val_batch_size
    num_proc = config.num_proc

    if config.dataset.name == "special":
        raw_dataset = SpecialDataset()
        # Only the validation part of the split is needed here.
        # NOTE(review): random_split draws from the global torch RNG, so the
        # split presumably matches the training-time split only when the same
        # seed is set beforehand; confirm against the caller.
        _, val_dataset = random_split(raw_dataset, [0.95, 0.05])
        return DataLoader(
            val_dataset,
            batch_size=val_batch_size,
            num_workers=num_proc,
            # DataLoader raises ValueError for persistent_workers=True when
            # num_workers == 0, so only enable it when workers exist.
            persistent_workers=num_proc > 0,
        )

    if config.dataset.name == "meaning":
        c = config.dataset.meaning
        vocab = config.model_config.vocab_size
        start = vocab * (c.level_ratio**c.level)
        size = vocab * int((c.level_ratio**c.dataset_level))

        path = "./data/"
        # The cache file name is kept exactly as-is so previously saved
        # validation splits are still found.
        valfile = path + f"MeaningDataset_val_v{size}_s{start}_s{size}_lr{c.level_ratio}_ms{c.min_subitem}.pt"
        # exist_ok avoids the check-then-create race of exists() + mkdir().
        os.makedirs(path, exist_ok=True)

        if os.path.exists(valfile):
            print(f"INFO: Load dataset from {valfile}")
            # NOTE(review): weights_only=False unpickles arbitrary objects;
            # only load cache files this project wrote itself.
            val_dataset = torch.load(valfile, weights_only=False)
            val_dataset.set_mask(c.val_mask_level, c.val_mask_idx)
            print("INFO: Load dataset end")
        else:
            raw_dataset = MeaningDataset(start, start + size, vocab, None, c.level_ratio, c.min_subitem)
            raw_dataset.set_mask(c.val_mask_level, c.val_mask_idx)
            # split() returns two parts; only the second (validation) part is
            # kept and cached.
            _, val_dataset = raw_dataset.split(0.9)
            torch.save(val_dataset, valfile)
            print("INFO: Build and save dataset end")

        return BatchGroupMeaningDataloader(val_dataset, val_batch_size).dataloader(config.dataloader_works)

    # Unknown dataset name: keep the original fall-through result explicit.
    return None
|
||||||
|
|
|
@ -393,8 +393,8 @@ class MeaningDataset(Dataset):
|
||||||
return freq
|
return freq
|
||||||
|
|
||||||
def get_seq_mask(self, idx, level, index):
|
def get_seq_mask(self, idx, level, index):
|
||||||
assert index < 15, "index must < 15"
|
# assert index < 15, "index must < 15"
|
||||||
assert level < 8, "level must < 8"
|
# assert level < 8, "level must < 8"
|
||||||
rank_idx = (self.rank_idx[idx] >> (4 * level)).astype(np.int32) & 0xF
|
rank_idx = (self.rank_idx[idx] >> (4 * level)).astype(np.int32) & 0xF
|
||||||
rank_all = (self.rank_all[idx] >> (4 * level)).astype(np.int32) & 0xF
|
rank_all = (self.rank_all[idx] >> (4 * level)).astype(np.int32) & 0xF
|
||||||
return rank_idx == (rank_all + index if index < 0 else index)
|
return rank_idx == (rank_all + index if index < 0 else index)
|
||||||
|
|
|
@ -4,6 +4,7 @@ import torch
|
||||||
from model.qwen_module import QwenModule
|
from model.qwen_module import QwenModule
|
||||||
from model.modeling_wit import QwenRunner
|
from model.modeling_wit import QwenRunner
|
||||||
from model.tokenization_qwen import QWenTokenizer
|
from model.tokenization_qwen import QWenTokenizer
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
import configuration
|
import configuration
|
||||||
import dataset.dataset as ds
|
import dataset.dataset as ds
|
||||||
|
@ -37,20 +38,19 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
torch.manual_seed(conf.seed)
|
torch.manual_seed(conf.seed)
|
||||||
|
|
||||||
qwen = QwenModule.load_from_checkpoint(checkpoint_path = "log/bigger/version_1/checkpoints/epoch=26-step=27891.ckpt")
|
checkpoint_path = "log/bigger/version_1/checkpoints/epoch=23-step=24792.ckpt"
|
||||||
|
qwen = QwenModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
|
||||||
qwen.eval()
|
qwen.eval()
|
||||||
|
|
||||||
runner = QwenRunner(qwen.llm)
|
runner = QwenRunner(qwen.llm)
|
||||||
|
|
||||||
train_dataloader, val_dataloader = ds.InitDataset(conf)
|
val = ds.InitValDataset(conf).dataset
|
||||||
it = iter(val_dataloader)
|
data = val.dataset
|
||||||
batch = next(it)
|
item = data.get_token(0)
|
||||||
|
print(data.print_tree(0))
|
||||||
|
|
||||||
fdsafd = batch["input_ids"].numpy()
|
batch = torch.tensor([item[:-1]], dtype=torch.int64)
|
||||||
|
batch = batch.cuda()
|
||||||
|
print(item)
|
||||||
print(batch["input_ids"].numpy())
|
next_token = runner.ChatToken(batch)
|
||||||
print(batch["input_ids"][0:1,:-1].numpy())
|
|
||||||
next_token = runner.ChatToken(batch["input_ids"][0:1,:-1].cuda())
|
|
||||||
print(next_token.detach().cpu().numpy())
|
print(next_token.detach().cpu().numpy())
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue