Add inference.

parent b248d1d890
commit 01e5f86e94

test/loss.py (25 changes)

@@ -20,27 +20,28 @@ import torchmetrics
# print(output)

target = torch.tensor([0, 1, 2])
preds = torch.tensor([[0.1, 0.9, 0], [0.3, 10.1, 0.6], [0.2, 0.3, 0.9]])
accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=3)
accur = accuracy(preds, target)
# target = torch.tensor([0, 1, 2])
# preds = torch.tensor([[0.1, 0.9, 0], [0.3, 10.1, 0.6], [0.2, 0.3, 0.9]])
# accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=3)
# accur = accuracy(preds, target)

metric_accuracy = torchmetrics.Accuracy(
    task="multiclass",
    num_classes=4096,
    num_classes=4,
)
shift_logits = torch.zeros((16, 2, 4096))
shift_logits[:8, :, 2] = 10.0
shift_labels = (torch.ones((16, 2)) * 2).long()

label_mask = shift_labels != 4096
shift_logits = shift_logits[label_mask]
shift_labels = shift_labels[label_mask]

shift_logits = torch.rand((128, 4))
shift_labels = torch.randint(0, 4, size=(128,))
accur = metric_accuracy(shift_logits, shift_labels)
metric_accuracy.update(shift_logits, shift_labels)
print(accur.numpy())

shift_logits = torch.cat((shift_logits, shift_logits), dim=0)
shift_labels = torch.cat((shift_labels, shift_labels), dim=0)
accur = metric_accuracy(shift_logits, shift_labels)
print(accur.numpy())
print(accur.numpy())

# torch.manual_seed(32)
# criterion = nn.CrossEntropyLoss()
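The test above drives torchmetrics both functionally, where calling the metric returns a value, and statefully via update(). A minimal self-contained sketch of the pattern it exercises, reusing the ignore value 4096 and the tensor shapes from the test (everything else is illustrative): mask out ignored label positions, flatten logits and labels, accumulate with update(), and read the result once with compute().

import torch
import torchmetrics

IGNORE_INDEX = 4096  # assumed ignore/padding label, taken from the mask in the test above

# average="micro" makes compute() return the plain fraction of correct predictions
metric = torchmetrics.Accuracy(task="multiclass", num_classes=4096, average="micro")

shift_logits = torch.zeros((16, 2, 4096))
shift_logits[:8, :, 2] = 10.0                    # first half of the batch strongly predicts class 2
shift_labels = (torch.ones((16, 2)) * 2).long()  # every target is class 2

mask = shift_labels != IGNORE_INDEX              # keep only real (non-ignored) positions
metric.update(shift_logits[mask], shift_labels[mask])

print(metric.compute())                          # tensor(0.5000): only the first half is correct
metric.reset()                                   # clear accumulated state between evaluation runs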

@@ -0,0 +1,69 @@
import argparse
from functools import partial
from itertools import chain
from typing import Dict, Tuple

import datasets
import pytorch_lightning as pl
import torch
from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split, Subset

from lit_module import LitModule
from tokenization_qwen import QWenTokenizer
from logger import TBLogger

from special_dataset import SpecialDataset
from meaning_dataset import MeaningDataset, BatchGroupMeaningDataloader
from wit.configuration import ModelConfig

pretrain_model_name = None  # "qwen/Qwen-1_8B-Chat"
learning_rate = 0.0001
use_tril_attention_mask = None
precision = "32-true"  # "precision:bf16-mixed,16-mixed,32-true"
train_batch_size = 1
val_batch_size = 2
num_proc = 8
max_epochs = 10
strategy = "auto"
resume_from_ckpt_path = None
seed = 42
vocab_size = 16


if __name__ == "__main__":
    torch.manual_seed(seed)

    config = ModelConfig()
    config.vocab_size = vocab_size
    config.hidden_size = 1024  # 128 1024 2048  32
    config.num_hidden_layers = 1  # 6 12 24  3
    config.num_attention_heads = 16  # 8 8 16

    lit_module = LitModule(pretrain_model_name, learning_rate, config, use_tril_attention_mask)
    tokenizer = QWenTokenizer("./wit_b64.tiktoken", "./wit_char.tiktoken")

    level_ratio = 2
    start = vocab_size * level_ratio * level_ratio
    end = start * level_ratio
    size = end * level_ratio
    size = 1024
    raw_dataset = MeaningDataset(start, end, size, vocab_size, level_ratio)
    train_dataset, val_dataset = raw_dataset.Split(0.95)

    train_dataloader = BatchGroupMeaningDataloader(train_dataset, train_batch_size)
    val_dataloader = BatchGroupMeaningDataloader(val_dataset, val_batch_size)

    it = iter(val_dataloader)
    batch = next(it)
    b, l = lit_module.llm(**batch, return_dict=True)
    print("b ")
    print(b.detach().cpu().numpy())

    # batch["input_ids"] = batch["input_ids"][0:1, :]
    batch["input_ids"] = batch["input_ids"][1:2, :]
    batch["labels"] = batch["labels"][1:2, :]
    s, l = lit_module.llm(**batch, return_dict=True)
    print("s ")
    print(s.detach().cpu().numpy())

    print("data samples:")
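The script slices the batch down to a single sample and runs the model again, comparing the per-sample output with the batched output. A hedged sketch of the same idea as a loop; the helper name is hypothetical, and the only assumptions about the model API are the ones visible above, namely that lit_module.llm(**batch, return_dict=True) returns a two-element tuple and that the batch carries input_ids and labels.

import torch

def per_sample_outputs(llm, batch):
    # Run the model one sample at a time, mirroring the [1:2, :] slicing above,
    # and collect the first returned value for each sample.
    outputs = []
    with torch.no_grad():
        for i in range(batch["input_ids"].shape[0]):
            single = {
                "input_ids": batch["input_ids"][i : i + 1, :],
                "labels": batch["labels"][i : i + 1, :],
            }
            out, _ = llm(**single, return_dict=True)
            outputs.append(out.detach().cpu())
    return outputs

Called as per_sample_outputs(lit_module.llm, batch), this would yield one tensor per sample that can be compared against the batched output printed above.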

@@ -188,14 +188,14 @@ class BatchGroupMeaningDataloader(Dataset):
            index_shuffle = np.arange(0, index.shape[0])
            np.random.shuffle(index_shuffle)
            index = index[index_shuffle]
        self.index = index
        self.indexBatch = index

    def __len__(self):
        return len(self.index)
        return len(self.indexBatch)

    def __getitem__(self, idx):
        # print("get idx" + str(idx))
        return self.dataset.GetBatch(self.index[idx])
        return self.dataset.GetBatch(self.indexBatch[idx])


if __name__ == "__main__":
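The rename from self.index to self.indexBatch reflects that each entry indexes a whole batch of samples, not a single sample. A minimal sketch of that pattern, with illustrative class name, grouping rule, and shuffle behaviour rather than the real implementation:

import numpy as np
from torch.utils.data import Dataset

class GroupedBatchDataset(Dataset):
    def __init__(self, dataset, batch_size, shuffle=True):
        self.dataset = dataset
        index = np.arange(len(dataset) // batch_size * batch_size)
        index = index.reshape(-1, batch_size)   # one row of sample indices per batch
        if shuffle:
            np.random.shuffle(index)            # shuffle batch order, keep each row intact
        self.indexBatch = index                 # same rename as in the diff above

    def __len__(self):
        return len(self.indexBatch)             # number of batches, not samples

    def __getitem__(self, idx):
        # Each item is a whole batch of underlying samples.
        return [self.dataset[i] for i in self.indexBatch[idx]]

Because __getitem__ already returns a batch, the object can be iterated directly, as the inference script above does with iter(val_dataloader).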

@@ -20,7 +20,7 @@ class SpecialDataset(Dataset):
        z = torch.zeros([size]).long()
        # self.data = torch.stack([a, b, a + b, a + b, a + b * 2]).permute(1, 0)
        # self.data = torch.stack([a, b, a, a + b / 4]).permute(1, 0).long()
        self.data = torch.stack([a, a + a, a + a]).permute(1, 0).long()
        self.data = torch.stack([a, a, a + a]).permute(1, 0).long()
        # self.data = torch.stack([a, b, a]).permute(1, 0).long()
        # self.data = torch.stack([a, b, a, a + a / 8, a + a / 4, a + a / 2, a + a]).permute(1, 0).long()
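For clarity, a standalone sketch of what the stack-and-permute construction yields; only the [a, a, a + a] layout comes from the active line above, and the range of a is an assumption. torch.stack builds a (3, size) tensor whose rows become the columns of the data, and permute(1, 0) turns it into (size, 3) so each row is one sample of the form [a, a, 2a].

import torch

size = 4
a = torch.randint(0, 8, [size])                           # assumed: small random token ids
data = torch.stack([a, a, a + a]).permute(1, 0).long()    # shape (size, 3); rows look like [a, a, 2a]
print(data)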