Witllm/wit/inference.py

import pytorch_lightning as pl
import torch

from model.qwen_module import QwenModule
from model.modeling_wit import QwenRunner
from model.tokenization_qwen import QWenTokenizer

import configuration
import dataset.dataset as ds

if __name__ == "__main__":
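    # Rebuild the training configuration; the values below are assumed to mirror
    # the run that produced the checkpoint loaded further down.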
    conf = configuration.TrainConfig()
    config = conf.model_config
    conf.name = "bigger"  # current training run name
    conf.pretrain_model_name = None  # e.g. "qwen/Qwen-1_8B-Chat"
    conf.learning_rate = 0.0001
    conf.use_tril_attention_mask = None
    conf.precision = "bf16-mixed"  # one of "bf16-mixed", "16-mixed", "32-true"
    conf.train_batch_size = 16
    conf.val_batch_size = 4
    conf.num_proc = 8
    conf.max_epochs = 1000
    conf.strategy = "auto"
    conf.resume_from_ckpt_path = None
    conf.seed = 42
    conf.dataloader_works = 2
    conf.dataset.meaning.val_mask_level = [0, 1, 2]
    conf.dataset.meaning.val_mask_idx = [0, 0, -1]
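
    # Model geometry for this run (vocab size, width, depth, attention heads).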
    config.vocab_size = 256
    config.hidden_size = 128  # 128 1024 2048 32
    config.num_hidden_layers = 3  # 6 12 24 3
    config.num_attention_heads = 16  # 8 8 16

    torch.manual_seed(conf.seed)
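
    # Load the trained module from its checkpoint and switch to eval mode for inference.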
    qwen = QwenModule.load_from_checkpoint(
        checkpoint_path="log/bigger/version_1/checkpoints/epoch=26-step=27891.ckpt"
    )
    qwen.eval()
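
    # QwenRunner wraps the underlying LLM and exposes ChatToken() for next-token prediction.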
    runner = QwenRunner(qwen.llm)
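
    # Build the dataloaders from the same configuration and grab one validation batch.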
    train_dataloader, val_dataloader = ds.InitDataset(conf)
    it = iter(val_dataloader)
    batch = next(it)
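
    # Feed all but the last token of the first validation sample to the model
    # and print the token it predicts next.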
    input_ids = batch["input_ids"]
    print(input_ids.numpy())
    print(input_ids[0:1, :-1].numpy())
    next_token = runner.ChatToken(input_ids[0:1, :-1].cuda())
    print(next_token.detach().cpu().numpy())