46 lines
1.4 KiB
Python
46 lines
1.4 KiB
Python
import pytorch_lightning as pl
|
|
import torch
|
|
|
|
from model.lit_module import LitModule
|
|
from wit.model.tokenization_qwen import QWenTokenizer
|
|
from logger import MLFLogger
|
|
|
|
import configuration
|
|
import dataset as ds
|
|
|
|
def main():
    """Build the Lightning module and dataloaders from ``TrainConfig`` and train.

    Reads every run setting from ``configuration.TrainConfig()``, overrides a
    few model-size fields on its ``model_config`` with the hard-coded values
    below, then runs ``pl.Trainer.fit`` on CUDA, optionally resuming from
    ``train_config.resume_from_ckpt_path``.
    """
    train_config = configuration.TrainConfig()
    config = train_config.model_config

    # Seed Python's `random`, NumPy and torch in one call. A bare
    # `torch.manual_seed` leaves the first two unseeded, so any randomness the
    # dataset/model code draws from them would differ between runs.
    pl.seed_everything(train_config.seed)

    # Process-global setting: allow float32 matmuls to use faster,
    # lower-precision internal math (per PyTorch docs, "medium" may use
    # bfloat16). Set it once, before anything is built or trained.
    torch.set_float32_matmul_precision("medium")

    # Hard-coded overrides of the configured model size. The trailing comments
    # appear to list alternative sizes tried earlier -- NOTE(review): confirm.
    config.vocab_size = 256
    config.hidden_size = 128  # 128 1024 2048 32
    config.num_hidden_layers = 6  # 6 12 24 3
    config.num_attention_heads = 16  # 8 8 16

    lit_module = LitModule(
        train_config.pretrain_model_name,
        train_config.learning_rate,
        config,
        train_config.use_tril_attention_mask,
    )

    train_dataloader, val_dataloader = ds.InitDataset(train_config)

    lit_trainer = pl.Trainer(
        accelerator="cuda",
        precision=train_config.precision,
        logger=MLFLogger("./log/", run_name=train_config.name),
        strategy=train_config.strategy,
        max_epochs=train_config.max_epochs,
    )
    # Per the Lightning docs, ckpt_path=None trains from scratch; a checkpoint
    # path resumes the full training state from that checkpoint.
    lit_trainer.fit(
        lit_module,
        train_dataloaders=train_dataloader,
        val_dataloaders=val_dataloader,
        ckpt_path=train_config.resume_from_ckpt_path,
    )


if __name__ == "__main__":
    main()