Witllm/wit/train.py

46 lines
1.4 KiB
Python
Raw Normal View History

2024-02-25 20:20:32 +08:00
import pytorch_lightning as pl
import torch
2024-03-04 21:41:46 +08:00
2025-02-17 19:41:40 +08:00
from model.lit_module import LitModule
from wit.model.tokenization_qwen import QWenTokenizer
from logger import MLFLogger
2024-02-25 20:20:32 +08:00
2025-02-17 19:41:40 +08:00
import configuration
2025-02-18 14:21:15 +08:00
import dataset.dataset as ds
2024-02-25 20:20:32 +08:00
def main() -> None:
    """Build the model, data loaders and Lightning trainer, then run training.

    All settings come from ``configuration.TrainConfig()``. The model-shape
    overrides below replace whatever ``train_config.model_config`` ships with.
    Training runs on CUDA and logs to ``./log/`` via ``MLFLogger``. It resumes
    from ``train_config.resume_from_ckpt_path`` when that is set.
    """
    train_config = configuration.TrainConfig()
    config = train_config.model_config

    # Seed before the model and data loaders are constructed so any random
    # initialisation they perform is reproducible.
    torch.manual_seed(train_config.seed)

    # Small-model overrides of the configured model shape. The trailing
    # comments appear to list other sizes used previously.
    config.vocab_size = 256
    config.hidden_size = 128  # 128 1024 2048 32
    config.num_hidden_layers = 6  # 6 12 24 3
    config.num_attention_heads = 16  # 8 8 16

    lit_module = LitModule(
        train_config.pretrain_model_name,
        train_config.learning_rate,
        config,
        train_config.use_tril_attention_mask,
    )

    # NOTE(review): the tokenizer is not used below. It is kept because its
    # constructor presumably reads the two .tiktoken files (paths relative to
    # the working directory), so removing it could hide a missing-file error.
    # Also confirm `wit.model.tokenization_qwen` resolves from the same launch
    # directory as `model.lit_module`; the two imports use different roots.
    tokenizer = QWenTokenizer("./model/wit_b64.tiktoken", "./model/wit_char.tiktoken")

    train_dataloader, val_dataloader = ds.InitDataset(train_config)

    # Allow float32 matmuls to use lower-precision internal math ("medium"
    # permits bfloat16 per the PyTorch docs), trading accuracy for speed.
    torch.set_float32_matmul_precision("medium")

    lit_trainer = pl.Trainer(
        accelerator="cuda",
        precision=train_config.precision,
        logger=MLFLogger("./log/", run_name=train_config.name),
        strategy=train_config.strategy,
        max_epochs=train_config.max_epochs,
    )
    lit_trainer.fit(
        lit_module,
        train_dataloaders=train_dataloader,
        val_dataloaders=val_dataloader,
        ckpt_path=train_config.resume_from_ckpt_path,
    )


if __name__ == "__main__":
    main()