import pytorch_lightning as pl
import torch

from model.qwen_module import QwenModule
from model.tokenization_qwen import QWenTokenizer
from logger import MLFLogger, TBLogger

import configuration
import dataset.dataset as ds


if __name__ == "__main__":
    # Training entry point: build the config, model, and dataloaders, then fit with Lightning.
    conf = configuration.TrainConfig()
    config = conf.model_config

    # Run / optimization settings
    conf.name = "bigger"  # name of the current training run
    conf.pretrain_model_name = None  # e.g. "qwen/Qwen-1_8B-Chat" to start from a pretrained checkpoint
    conf.learning_rate = 0.0001
    conf.use_tril_attention_mask = None
    conf.precision = "bf16-mixed"  # options: bf16-mixed, 16-mixed, 32-true
    conf.train_batch_size = 16
    conf.val_batch_size = 4
    conf.num_proc = 8
    conf.max_epochs = 1000
    conf.strategy = "auto"
    conf.resume_from_ckpt_path = None
    conf.seed = 42
    conf.dataloader_works = 2

    # Validation masking settings for the "meaning" dataset
    conf.dataset.meaning.val_mask_level = [0, 1, 2]
    conf.dataset.meaning.val_mask_idx = [0, 0, -1]

    # Model architecture
    config.vocab_size = 256
    config.hidden_size = 128  # e.g. 128, 1024, 2048, 32
    config.num_hidden_layers = 3  # e.g. 6, 12, 24, 3
    config.num_attention_heads = 16  # e.g. 8, 8, 16

    torch.manual_seed(conf.seed)

    qwen = QwenModule(conf)

    train_dataloader, val_dataloader = ds.InitDataset(conf)
    # for i in range(len(train_dataloader)):
    #     print(train_dataloader.print_mapping(i))

    # TensorBoard logger; hyperparameters are logged once at startup
    logger = TBLogger("./log/", name=conf.name)
    logger.log_hyperparams(configuration.class_to_dict(conf))

    torch.set_float32_matmul_precision("medium")
    lit_trainer = pl.Trainer(
        accelerator="cuda",
        precision=conf.precision,
        # logger=MLFLogger("./log/", run_name=conf.name),
        logger=logger,
        strategy=conf.strategy,
        max_epochs=conf.max_epochs,
    )

    lit_trainer.fit(
        qwen,
        train_dataloaders=train_dataloader,
        val_dataloaders=val_dataloader,
        ckpt_path=conf.resume_from_ckpt_path,
    )