From 01e5f86e94ec7de28b3f7353a575f0bef16d289e Mon Sep 17 00:00:00 2001 From: Colin Date: Wed, 20 Mar 2024 22:27:28 +0800 Subject: [PATCH] Add inference. --- test/loss.py | 25 +++++++-------- wit/inference.py | 69 ++++++++++++++++++++++++++++++++++++++++++ wit/meaning_dataset.py | 6 ++-- wit/special_dataset.py | 2 +- 4 files changed, 86 insertions(+), 16 deletions(-) create mode 100644 wit/inference.py diff --git a/test/loss.py b/test/loss.py index 8df1d6a..74659a5 100644 --- a/test/loss.py +++ b/test/loss.py @@ -20,27 +20,28 @@ import torchmetrics # print(output) -target = torch.tensor([0, 1, 2]) -preds = torch.tensor([[0.1, 0.9, 0], [0.3, 10.1, 0.6], [0.2, 0.3, 0.9]]) -accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=3) -accur = accuracy(preds, target) +# target = torch.tensor([0, 1, 2]) +# preds = torch.tensor([[0.1, 0.9, 0], [0.3, 10.1, 0.6], [0.2, 0.3, 0.9]]) +# accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=3) +# accur = accuracy(preds, target) metric_accuracy = torchmetrics.Accuracy( task="multiclass", - num_classes=4096, + num_classes=4, ) -shift_logits = torch.zeros((16, 2, 4096)) -shift_logits[:8, :, 2] = 10.0 -shift_labels = (torch.ones((16, 2)) * 2).long() -label_mask = shift_labels != 4096 -shift_logits = shift_logits[label_mask] -shift_labels = shift_labels[label_mask] +shift_logits = torch.rand((128, 4)) +shift_labels = torch.randint(0, 4, size=(128,)) accur = metric_accuracy(shift_logits, shift_labels) -metric_accuracy.update(shift_logits, shift_labels) +print(accur.numpy()) +shift_logits = torch.cat((shift_logits, shift_logits), dim=0) +shift_labels = torch.cat((shift_labels, shift_labels), dim=0) +accur = metric_accuracy(shift_logits, shift_labels) +print(accur.numpy()) +print(accur.numpy()) # torch.manual_seed(32) # criterion = nn.CrossEntropyLoss() diff --git a/wit/inference.py b/wit/inference.py new file mode 100644 index 0000000..986f284 --- /dev/null +++ b/wit/inference.py @@ -0,0 +1,69 @@ +import argparse +from functools import partial +from itertools import chain +from typing import Dict, Tuple + +import datasets +import pytorch_lightning as pl +import torch +from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split, Subset + +from lit_module import LitModule +from tokenization_qwen import QWenTokenizer +from logger import TBLogger + +from special_dataset import SpecialDataset +from meaning_dataset import MeaningDataset, BatchGroupMeaningDataloader +from wit.configuration import ModelConfig + +pretrain_model_name = None # "qwen/Qwen-1_8B-Chat" +learning_rate = 0.0001 +use_tril_attention_mask = None +precision = "32-true" # "precision:bf16-mixed,16-mixed,32-true" +train_batch_size = 1 +val_batch_size = 2 +num_proc = 8 +max_epochs = 10 +strategy = "auto" +resume_from_ckpt_path = None +seed = 42 +vocab_size = 16 + + +if __name__ == "__main__": + torch.manual_seed(seed) + + config = ModelConfig() + config.vocab_size = vocab_size + config.hidden_size = 1024 # 128 1024 2048 32 + config.num_hidden_layers = 1 # 6 12 24 3 + config.num_attention_heads = 16 # 8 8 16 + + lit_module = LitModule(pretrain_model_name, learning_rate, config, use_tril_attention_mask) + tokenizer = QWenTokenizer("./wit_b64.tiktoken", "./wit_char.tiktoken") + + level_ratio = 2 + start = vocab_size * level_ratio * level_ratio + end = start * level_ratio + size = end * level_ratio + size = 1024 + raw_dataset = MeaningDataset(start, end, size, vocab_size, level_ratio) + train_dataset, val_dataset = raw_dataset.Split(0.95) + + train_dataloader = BatchGroupMeaningDataloader(train_dataset, train_batch_size) + val_dataloader = BatchGroupMeaningDataloader(val_dataset, val_batch_size) + + it = iter(val_dataloader) + batch = next(it) + b, l = lit_module.llm(**batch, return_dict=True) + print("b ") + print(b.detach().cpu().numpy()) + + # batch["input_ids"] = batch["input_ids"][0:1, :] + batch["input_ids"] = batch["input_ids"][1:2, :] + batch["labels"] = batch["labels"][1:2, :] + s, l = lit_module.llm(**batch, return_dict=True) + print("s ") + print(s.detach().cpu().numpy()) + + print("data samples:") diff --git a/wit/meaning_dataset.py b/wit/meaning_dataset.py index 7b8098c..5f8f032 100644 --- a/wit/meaning_dataset.py +++ b/wit/meaning_dataset.py @@ -188,14 +188,14 @@ class BatchGroupMeaningDataloader(Dataset): index_shuffle = np.arange(0, index.shape[0]) np.random.shuffle(index_shuffle) index = index[index_shuffle] - self.index = index + self.indexBatch = index def __len__(self): - return len(self.index) + return len(self.indexBatch) def __getitem__(self, idx): # print("get idx" + str(idx)) - return self.dataset.GetBatch(self.index[idx]) + return self.dataset.GetBatch(self.indexBatch[idx]) if __name__ == "__main__": diff --git a/wit/special_dataset.py b/wit/special_dataset.py index 243f333..d69699b 100644 --- a/wit/special_dataset.py +++ b/wit/special_dataset.py @@ -20,7 +20,7 @@ class SpecialDataset(Dataset): z = torch.zeros([size]).long() # self.data = torch.stack([a, b, a + b, a + b, a + b * 2]).permute(1, 0) # self.data = torch.stack([a, b, a, a + b / 4]).permute(1, 0).long() - self.data = torch.stack([a, a + a, a + a]).permute(1, 0).long() + self.data = torch.stack([a, a, a + a]).permute(1, 0).long() # self.data = torch.stack([a, b, a]).permute(1, 0).long() # self.data = torch.stack([a, b, a, a + a / 8, a + a / 4, a + a / 2, a + a]).permute(1, 0).long()