Witllm/wit/demo.py

32 lines
811 B
Python

import torch
from modelscope import snapshot_download
from modeling_wit import QWenLMHeadModel
from modeling_wit import QwenRunner
from configuration_qwen import QWenConfig
from tokenization_qwen import QWenTokenizer
# Seed the CPU generator and every CUDA generator with the same value so that
# repeated runs of this demo start from the same RNG state.
seed = 4321
for apply_seed in (torch.manual_seed, torch.cuda.manual_seed_all):
    apply_seed(seed)

# Ask ModelScope for the local directory of the Qwen-1.8B-Chat snapshot.
# A copy that is already on disk can be used instead by assigning its path
# directly, e.g.
#   model_dir = "/home/colin/.cache/modelscope/hub/qwen/Qwen-1_8B-Chat"
model_dir = snapshot_download("qwen/Qwen-1_8B-Chat")

# Build the network from a default configuration and dump its repr before
# any pretrained weights are loaded into it.
config = QWenConfig()
model = QWenLMHeadModel(config)
print(model)

# The tokenizer is built from the tiktoken file expected in the current
# working directory.
tokenizer = QWenTokenizer("./qwen.tiktoken")

# Load the pretrained weights, move the model to the GPU and switch it to
# evaluation mode. train() is deliberately not called here.
# NOTE(review): gradient tracking is presumably disabled by @torch.no_grad()
# inside modeling_wit -- confirm there.
model = model.from_pretrained(model_dir).cuda().eval()

# Run one chat turn through the runner; the third positional argument is
# passed as an empty string, exactly as before.
runner = QwenRunner(model)
prompt = "东南亚国家日本的首都是什么市"
chat_result = runner.Chat(tokenizer, prompt, "")
response, history, decode_tokens = chat_result
print(decode_tokens)