# Witllm/wit/train.py (46 lines, 1.4 KiB, Python)

import pytorch_lightning as pl
import torch
from model.lit_module import LitModule
from wit.model.tokenization_qwen import QWenTokenizer
from logger import MLFLogger
import configuration
import dataset as ds
def main():
    """Configure a small model, build the dataloaders and run Lightning training.

    All hyper-parameters come from ``configuration.TrainConfig`` except the
    model-size overrides applied below. The module is then fitted with
    ``pl.Trainer`` on CUDA, logging to ``./log/``.
    """
    train_config = configuration.TrainConfig()
    config = train_config.model_config

    # Seed before the module is constructed so weight initialisation is
    # reproducible across runs.
    torch.manual_seed(train_config.seed)

    # Small-model overrides on top of TrainConfig.model_config. The trailing
    # numbers are the alternative settings the author noted (presumably
    # earlier experiments).
    config.vocab_size = 256
    config.hidden_size = 128  # noted: 128 1024 2048 32
    config.num_hidden_layers = 6  # noted: 6 12 24 3
    config.num_attention_heads = 16  # noted: 8 8 16

    lit_module = LitModule(
        train_config.pretrain_model_name,
        train_config.learning_rate,
        config,
        train_config.use_tril_attention_mask,
    )

    train_dataloader, val_dataloader = ds.InitDataset(train_config)

    # Allow reduced-precision internal computation for float32 matmuls,
    # trading some accuracy for speed on GPUs that support it.
    torch.set_float32_matmul_precision("medium")

    lit_trainer = pl.Trainer(
        accelerator="cuda",
        precision=train_config.precision,
        logger=MLFLogger("./log/", run_name=train_config.name),
        strategy=train_config.strategy,
        max_epochs=train_config.max_epochs,
    )
    # ckpt_path: presumably None to start fresh, or a checkpoint path to
    # resume from — confirm against TrainConfig.
    lit_trainer.fit(
        lit_module,
        train_dataloaders=train_dataloader,
        val_dataloaders=val_dataloader,
        ckpt_path=train_config.resume_from_ckpt_path,
    )


if __name__ == "__main__":
    main()