import pytorch_lightning as pl
import torch

from model.lit_module import LitModule
from wit.model.tokenization_qwen import QWenTokenizer
from logger import MLFLogger
import configuration
import dataset.dataset as ds

if __name__ == "__main__":
    train_config = configuration.TrainConfig()
    config = train_config.model_config
    torch.manual_seed(train_config.seed)

    # Override the model hyperparameters with a small configuration;
    # the trailing comments list alternative values.
    config.vocab_size = 256
    config.hidden_size = 128  # 128 1024 2048 32
    config.num_hidden_layers = 6  # 6 12 24 3
    config.num_attention_heads = 16  # 8 8 16

    lit_module = LitModule(
        train_config.pretrain_model_name,
        train_config.learning_rate,
        config,
        train_config.use_tril_attention_mask,
    )
    tokenizer = QWenTokenizer("./model/wit_b64.tiktoken", "./model/wit_char.tiktoken")

    train_dataloader, val_dataloader = ds.InitDataset(train_config)
    # for i in range(len(train_dataloader)):
    #     print(train_dataloader.print_mapping(i))

    # Trade a little float32 matmul precision for speed on Tensor Core GPUs.
    torch.set_float32_matmul_precision("medium")
    lit_trainer = pl.Trainer(
        accelerator="cuda",
        precision=train_config.precision,
        logger=MLFLogger("./log/", run_name=train_config.name),
        strategy=train_config.strategy,
        max_epochs=train_config.max_epochs,
    )
    lit_trainer.fit(
        lit_module,
        train_dataloaders=train_dataloader,
        val_dataloaders=val_dataloader,
        ckpt_path=train_config.resume_from_ckpt_path,
    )