diff --git a/wit/dataset/meaning_dataset.py b/wit/dataset/meaning_dataset.py
index 4f220dc..5c33928 100644
--- a/wit/dataset/meaning_dataset.py
+++ b/wit/dataset/meaning_dataset.py
@@ -183,7 +183,7 @@ class MeaningMap:
             self.ms_rank_all[start : start + len],
         )

-    def get_tree(self, meaning):  # return meaning all sub items
+    def get_nodetree(self, meaning):  # return the meaning and all of its sub-items as a node tree
         def get_tree_node(ms_map, meaning, vocab_size, parent, seqlist):
             ms = ms_map[meaning]
             for m in ms[ms >= 0].tolist():
@@ -200,6 +200,22 @@ class MeaningMap:
         root.seq_node = seqlist
         return root

+    def get_tree(self, meaning):  # return the tree of the meaning as a flat depth-first list
+        def get_tree_list(ms_map, meaning, mlist):
+            ms = ms_map[meaning]
+            mlist.append(meaning)
+            mlist.append(-255)  # level down marker
+            for m in ms[ms >= 0].tolist():
+                if m >= self.vocab_size:
+                    get_tree_list(ms_map, m, mlist)
+                else:
+                    mlist.append(m)
+            mlist.append(-1)  # level up marker
+
+        meaninglist = []
+        get_tree_list(self.ms_map, meaning, meaninglist)
+        return meaninglist
+
     def max_length(self):
         return max(self.ms_len)

@@ -314,6 +330,7 @@ class MeaningDataset(Dataset):
         output["labels"] = data.clone()
         output["token_type_ids"] = torch.zeros(data.shape)
         output["val_mask"] = self.get_seq_mask_tensor(idx_list)
+        output["meaning"] = [self.seq_meaning[i] for i in idx_list]
         return output

     def get_token(self, idx):  # must equal sequence length
@@ -405,17 +422,6 @@ class BatchGroupMeaningDataloader(Dataset):
     def __getitem__(self, idx):
         return self.meaning_dataset.get_batch(self.indexBatch[idx])

-    def get_tree(self, idx):
-        return [self.meaning_dataset.get_tree(i) for i in self.indexBatch[idx]]
-
-    def print_tree(self, idx):
-        idx_list = self.indexBatch[idx]
-        s = "--------------------------------------------------------\n"
-        for i in idx_list:
-            s += self.meaning_dataset.print_tree(i)
-        s += "--------------------------------------------------------\n"
-        return s
-
     def detection_collate(batch):
         return batch[0]

@@ -424,7 +430,6 @@
             self, batch_size=1, num_workers=num_workers, collate_fn=BatchGroupMeaningDataloader.detection_collate
         )

-
 if __name__ == "__main__":

     md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, use_cache=False)
diff --git a/wit/inference.py b/wit/inference.py
index 5d2a946..c231396 100644
--- a/wit/inference.py
+++ b/wit/inference.py
@@ -12,7 +12,8 @@ import dataset.node_tree as nt

 if __name__ == "__main__":

-    checkpoint_path = "log/bigger/version_0/checkpoints/epoch=19-step=98720.ckpt"
+    # checkpoint_path = "log/bigger/version_0/checkpoints/epoch=19-step=98720.ckpt"
+    checkpoint_path = "log/bigger/version_1/checkpoints/epoch=14-step=74040.ckpt"

     qwen = QwenModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
     qwen.eval()
@@ -21,23 +22,24 @@ if __name__ == "__main__":
     np.random.seed(conf.seed)

     runner = QwenRunner(qwen.llm)
-    # batch = torch.tensor([[41]], dtype=torch.int64)
-    # print(runner.ChatTokens(batch).detach().cpu().numpy()[0])
+    # batch = torch.tensor([[11, 0, 3, 7, 15, 8, 10, 7, 14, 13, 1, 12, 13]], dtype=torch.int64)
+    # sorted_logits, sorted_indices = runner.ChatTokens(batch, sample=False)
+    # print(sorted_logits.detach().cpu().numpy())
+    # print(sorted_indices.detach().cpu().numpy())

     val = ds.InitValDataset(conf).dataset
     md = val.meaning_dataset
-
     map = md.get_meaning_map()
-
     item = md.get_token(0)
-    node = map.get_tree(md.get_meaning(0))
+    node = map.get_nodetree(md.get_meaning(0))

     # node.print()

     for i in range(1, len(item)):
         itemm = [item[:i]]
         batch = torch.tensor([item[:i]], dtype=torch.int64)
-        next_token = runner.ChatTokens(batch, sample=False).detach().cpu().numpy()[0]
+        sorted_logits, sorted_indices = runner.ChatTokens(batch, sample=False)
+        next_token = sorted_indices.detach().cpu().numpy()[0][0]
         if item[i] != next_token:
             node.set_seq_prop(i, "ERR_" + str(next_token))
             print(str(item[i]) + " " + str(next_token) + " ERROR")
diff --git a/wit/model/modeling_wit.py b/wit/model/modeling_wit.py
index 0fc20fd..4a609ee 100644
--- a/wit/model/modeling_wit.py
+++ b/wit/model/modeling_wit.py
@@ -202,12 +202,11 @@ class QwenRunner:
         outputs, loss = self.forwardQWen(input_ids)
         next_token_scores = outputs[:, -1, :]
         next_token_scores = self.repetition_penalty(input_ids, next_token_scores)
-        next_token_scores = self.top_p(next_token_scores)
         if sample:
+            next_token_scores = self.top_p(next_token_scores)
             return self.sample(next_token_scores)
         else:
-            sorted_logits, sorted_indices = torch.sort(next_token_scores, descending=True)
-            return sorted_indices[0]
+            return torch.sort(next_token_scores, descending=True)

     @torch.no_grad()
     def Chat(
diff --git a/wit/query_meaning_freq.py b/wit/query_meaning_freq.py
new file mode 100644
index 0000000..ef4725f
--- /dev/null
+++ b/wit/query_meaning_freq.py
@@ -0,0 +1,40 @@
+# Query how often a given meaning occurs across the meaning trees of the training dataset.
+import pytorch_lightning as pl
+import torch
+
+from model.qwen_module import QwenModule
+from model.modeling_wit import QwenRunner
+from model.tokenization_qwen import QWenTokenizer
+import numpy as np
+
+import configuration
+import dataset.dataset as ds
+import dataset.node_tree as nt
+
+if __name__ == "__main__":
+
+    checkpoint_path = "log/bigger/version_1/checkpoints/epoch=14-step=74040.ckpt"
+
+    qwen = QwenModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
+    qwen.eval()
+    conf = qwen.config
+    torch.manual_seed(conf.seed)
+    np.random.seed(conf.seed)
+
+    train_dataloader, val_dataloader = ds.InitDataset(conf)
+
+    loader = train_dataloader.dataset
+
+    map = loader.meaning_dataset.get_meaning_map()
+    trees = {}
+    # Flatten every meaning seen in the training batches into its tree list (MeaningMap.get_tree).
+    for batch in loader:
+        for m in batch["meaning"]:
+            trees[m] = map.get_tree(m)
+    # Interactively count how often a queried meaning occurs across the collected trees.
+    while True:
+        m = int(input("input meaning: "))
+        total = 0
+        for tree in trees.values():
+            total = total + tree.count(m)
+        print(f"meaning {m} occurs {total} times")
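Reviewer note on the new flat tree encoding: MeaningMap.get_tree emits meaning
ids in depth-first order, writing -255 after an id to descend into its children
and -1 to ascend again. Both markers are negative while meanings and tokens are
>= 0, which is why query_meaning_freq.py can count occurrences with a plain
list.count(m). Below is a minimal sketch of decoding such a flat list back into
nested lists, assuming only the marker scheme above; decode_tree is a
hypothetical helper, not part of this patch:

    def decode_tree(mlist):
        # Rebuild nesting from the -255 (level down) / -1 (level up) markers.
        root = []
        stack = [root]
        for v in mlist:
            if v == -255:
                sub = [stack[-1].pop()]  # the id written just before -255 heads the new sub-level
                stack[-1].append(sub)
                stack.append(sub)
            elif v == -1:
                stack.pop()
            else:
                stack[-1].append(v)
        return root

    # e.g. decode_tree([1100, -255, 3, 1050, -255, 7, 9, -1, -1])
    #      -> [[1100, 3, [1050, 7, 9]]]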