From 4c7fdbe8171aab14f5fe75905ec4fe77d7a6584c Mon Sep 17 00:00:00 2001 From: Colin Date: Mon, 25 Mar 2024 13:20:17 +0800 Subject: [PATCH] Add GPU stress test. --- wit/stress.py | 82 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 82 insertions(+) create mode 100644 wit/stress.py diff --git a/wit/stress.py b/wit/stress.py new file mode 100644 index 0000000..c41f5fe --- /dev/null +++ b/wit/stress.py @@ -0,0 +1,82 @@ +import pytorch_lightning as pl +import torch +from torch.utils.data import DataLoader, Dataset, random_split + +from lit_module import LitModule +from logger import TBLogger + +from wit.configuration import ModelConfig + +pretrain_model_name = None # "qwen/Qwen-1_8B-Chat" +learning_rate = 0.0001 +use_tril_attention_mask = None +precision = "32-true" # "precision:bf16-mixed,16-mixed,32-true" +train_batch_size = 4 +val_batch_size = 8 +num_proc = 8 +max_epochs = 1000 +strategy = "auto" +resume_from_ckpt_path = None +seed = 42 + + +class StressDataset(Dataset): + def __init__(self, start=1, end=128, size=32768): # 1048576 32768 + self.size = size + self.features = [] + self.data = torch.randint(start, end, [size, 2048]).long() + + def __len__(self): + return self.size + + def __getitem__(self, idx): + output = {} + data = self.data[idx] + output["input_ids"] = data + output["labels"] = data.clone() + output["token_type_ids"] = torch.zeros(data.shape) + return output + + +if __name__ == "__main__": + torch.manual_seed(seed) + + config = ModelConfig() + config.vocab_size = 4096 + config.hidden_size = 1024 # 128 1024 2048 32 + config.num_hidden_layers = 6 # 6 12 24 3 + config.num_attention_heads = 8 # 8 8 16 + + lit_module = LitModule(pretrain_model_name, learning_rate, config, use_tril_attention_mask) + + raw_dataset = StressDataset() + train_dataset, val_dataset = random_split(raw_dataset, [0.95, 0.05]) + + train_dataloader = DataLoader( + train_dataset, + batch_size=train_batch_size, + num_workers=num_proc, + persistent_workers=True, + shuffle=True, + ) + val_dataloader = DataLoader( + val_dataset, + batch_size=val_batch_size, + num_workers=num_proc, + persistent_workers=True, + ) + + lit_trainer = pl.Trainer( + accelerator="gpu", + devices=2, + precision=precision, + logger=TBLogger("./", default_hp_metric=False), + strategy=strategy, + max_epochs=max_epochs, + ) + lit_trainer.fit( + lit_module, + train_dataloaders=train_dataloader, + val_dataloaders=val_dataloader, + ckpt_path=resume_from_ckpt_path, + )