Witllm/wit/train.py

61 lines
1.9 KiB
Python
Raw Normal View History

2024-02-25 20:20:32 +08:00
import pytorch_lightning as pl
import torch
2024-03-04 21:41:46 +08:00
2025-02-17 19:41:40 +08:00
from model.lit_module import LitModule
from wit.model.tokenization_qwen import QWenTokenizer
2025-02-18 19:35:23 +08:00
from logger import MLFLogger, TBLogger
2024-02-25 20:20:32 +08:00
2025-02-17 19:41:40 +08:00
import configuration
2025-02-18 14:21:15 +08:00
import dataset.dataset as ds
2024-02-25 20:20:32 +08:00
def _configure():
    """Create the training configuration and apply this run's overrides.

    Returns:
        A ``(train_conf, model_conf)`` pair: the ``configuration.TrainConfig``
        instance and the object held in its ``model_config`` attribute.
    """
    train_conf = configuration.TrainConfig()
    # Grab the model sub-config first; it is adjusted in place further down.
    model_conf = train_conf.model_config

    # Run-level settings, applied in the order listed.
    run_settings = {
        "name": "bigger",  # name of the current training run
        "pretrain_model_name": None,  # e.g. "qwen/Qwen-1_8B-Chat"
        "learning_rate": 0.0001,
        "use_tril_attention_mask": None,
        "precision": "32-true",  # options noted: bf16-mixed, 16-mixed, 32-true
        "train_batch_size": 8,
        "val_batch_size": 4,
        "num_proc": 8,
        "max_epochs": 1000,
        "strategy": "auto",
        "resume_from_ckpt_path": None,
        "seed": 42,
        "dataloader_works": 2,
        "mask_level": None,  # e.g. [0, 1, 2]
        "mask_idx": None,  # e.g. [0, 0, -1]
    }
    for field, value in run_settings.items():
        setattr(train_conf, field, value)

    # Model-shape settings; trailing notes list alternative values.
    model_settings = {
        "vocab_size": 256,
        "hidden_size": 128,  # 128 1024 2048 32
        "num_hidden_layers": 6,  # 6 12 24 3
        "num_attention_heads": 16,  # 8 8 16
    }
    for field, value in model_settings.items():
        setattr(model_conf, field, value)

    return train_conf, model_conf


def main():
    """Build the module, data loaders and trainer, then start fitting."""
    train_conf, model_conf = _configure()

    # Seed torch before the module is constructed.
    torch.manual_seed(train_conf.seed)
    module = LitModule(
        train_conf.pretrain_model_name,
        train_conf.learning_rate,
        model_conf,
        train_conf.use_tril_attention_mask,
    )

    # Instantiated here but not referenced again in this script.
    _tokenizer = QWenTokenizer(
        "./model/wit_b64.tiktoken",
        "./model/wit_char.tiktoken",
    )

    train_loader, val_loader = ds.InitDataset(train_conf)
    # Debug aid: loop i over range(len(train_loader)) and print
    # train_loader.print_mapping(i).

    torch.set_float32_matmul_precision("medium")

    # MLFLogger("./log/", run_name=train_conf.name) can be swapped in as the
    # logger instead of TBLogger.
    trainer_kwargs = {
        "accelerator": "cuda",
        "precision": train_conf.precision,
        "logger": TBLogger("./log/", name=train_conf.name),
        "strategy": train_conf.strategy,
        "max_epochs": train_conf.max_epochs,
    }
    trainer = pl.Trainer(**trainer_kwargs)

    trainer.fit(
        module,
        train_dataloaders=train_loader,
        val_dataloaders=val_loader,
        ckpt_path=train_conf.resume_from_ckpt_path,
    )


if __name__ == "__main__":
    main()