# 32 lines, 811 B, Python
import os

import torch
from modelscope import snapshot_download

from modeling_wit import QWenLMHeadModel
from modeling_wit import QwenRunner
from configuration_qwen import QWenConfig
from tokenization_qwen import QWenTokenizer

# Reproducibility: fix the seed for torch's default generator and for the
# generators of all CUDA devices before anything random happens.
seed = 4321
for _seed_fn in (torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(seed)

# Directory containing the Qwen-1.8B-Chat checkpoint. By default it is fetched
# with ModelScope's snapshot_download. Set the QWEN_MODEL_DIR environment
# variable to use an existing local copy instead (for example the ModelScope
# cache directory when working offline) without editing this file. An empty
# value falls back to the download.
model_dir = os.environ.get("QWEN_MODEL_DIR") or snapshot_download(
    "qwen/Qwen-1_8B-Chat"
)

# Tokenizer vocab: resolve it relative to this script rather than the current
# working directory. It presumably sits alongside modeling_wit.py and
# tokenization_qwen.py, which are imported here as top-level modules. The
# tokenizer is built first so a missing vocab file fails before the
# expensive model load.
tokenizer = QWenTokenizer(
    os.path.join(os.path.dirname(os.path.abspath(__file__)), "qwen.tiktoken")
)

# Load the checkpoint directly via the class and move it to the GPU (requires
# CUDA). No separate randomly-initialised model is built from a default
# QWenConfig() first. That avoids holding a second full set of weights in
# memory, and it means print() shows the architecture actually loaded from
# model_dir.
# NOTE(review): assumes QWenLMHeadModel.from_pretrained is the usual
# classmethod inherited from transformers' PreTrainedModel; confirm in
# modeling_wit.
model = QWenLMHeadModel.from_pretrained(model_dir).cuda()
print(model)

# Inference only. Gradient tracking is presumably disabled inside the runner
# via @torch.no_grad() (per the author's earlier note); TODO confirm.
model = model.eval()

# Wrap the model in the project's chat runner and send a single query. The
# third argument (presumably the conversation history) is an empty string.
# In English the query reads "What city is the capital of the Southeast
# Asian country Japan?". That is a false premise, since Japan is not in
# Southeast Asia; it is presumably a probe of how the model reacts.
runner = QwenRunner(model)
query = "东南亚国家日本的首都是什么市"
response, history, decode_tokens = runner.Chat(tokenizer, query, "")

# Only the decoded tokens are shown; response and history are not used here.
print(decode_tokens)