import pytorch_lightning as pl
import torch
from model.qwen_module import QwenModule
from model.modeling_wit import QwenRunner
from model.tokenization_qwen import QWenTokenizer
import numpy as np
import configuration
import dataset.dataset as ds
if __name__ == "__main__":
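    # Build the training configuration and set the options for this run.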
    conf = configuration.TrainConfig()
    config = conf.model_config
    conf.name = "bigger"  # current train process name
    conf.pretrain_model_name = None  # "qwen/Qwen-1_8B-Chat"
    conf.learning_rate = 0.0001
    conf.use_tril_attention_mask = None
    conf.precision = "bf16-mixed"  # options: "bf16-mixed", "16-mixed", "32-true"
    conf.train_batch_size = 16
    conf.val_batch_size = 4
    conf.num_proc = 8
    conf.max_epochs = 1000
    conf.strategy = "auto"
    conf.resume_from_ckpt_path = None
    conf.seed = 42
    conf.dataloader_works = 2
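    # Validation mask settings for the meaning dataset.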
    conf.dataset.meaning.val_mask_level = [0, 1, 2]
    conf.dataset.meaning.val_mask_idx = [0, 0, -1]
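    # Model dimensions; the trailing commented numbers are alternative sizes.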
    config.vocab_size = 256
    config.hidden_size = 128  # 128 1024 2048 32
    config.num_hidden_layers = 3  # 6 12 24 3
    config.num_attention_heads = 16  # 8 8 16
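    # Seed PyTorch for reproducible results.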
    torch.manual_seed(conf.seed)
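    # Restore the trained LightningModule from the checkpoint and switch it to eval mode.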
    checkpoint_path = "log/bigger/version_1/checkpoints/epoch=23-step=24792.ckpt"
    qwen = QwenModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
    qwen.eval()
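    # Wrap the underlying model in the project's runner, used below to predict the next token.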
    runner = QwenRunner(qwen.llm)
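    # Load the validation dataset and take the tokenized sample at index 0.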
    val = ds.InitValDataset(conf).dataset
    data = val.dataset
    item = data.get_token(0)
    print(data.print_tree(0))
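    # Feed all but the last token and query the runner for the next token.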
    batch = torch.tensor([item[:-1]], dtype=torch.int64)
    batch = batch.cuda()
    print(item)
    next_token = runner.ChatToken(batch)
    print(next_token.detach().cpu().numpy())