import torch
from modelscope import snapshot_download
from modeling_wit import QWenLMHeadModel
from modeling_wit import QwenRunner
from configuration_qwen import QWenConfig
from tokenization_qwen import QWenTokenizer

# Fix the random seeds so generation is reproducible across runs.
seed = 4321
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# Download the checkpoint from ModelScope, or point at a local cache directory.
model_dir = snapshot_download("qwen/Qwen-1_8B-Chat")
# model_dir = "/home/colin/.cache/modelscope/hub/qwen/Qwen-1_8B-Chat"

# Build the model from its default config and print the module tree.
config = QWenConfig()
model = QWenLMHeadModel(config)
print(model)

tokenizer = QWenTokenizer("./qwen.tiktoken")

# Load the pretrained weights, move the model to the GPU, and switch to eval mode.
model = model.from_pretrained(model_dir).cuda()
model = model.eval()
# model = model.train()  # controlled by @torch.no_grad()

# Ask a deliberately misleading question:
# "What is the capital city of the Southeast Asian country Japan?"
runner = QwenRunner(model)
response, history, decode_tokens = runner.Chat(tokenizer, "东南亚国家日本的首都是什么市", "")
print(decode_tokens)
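
# Optional: a minimal reproducibility check, added as a sketch. It assumes
# generation depends only on torch's RNG state, so re-seeding before each
# call should yield identical token sequences. chat_deterministic is a
# hypothetical helper, not part of the original script.
def chat_deterministic(query):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    return runner.Chat(tokenizer, query, "")

_, _, tokens_a = chat_deterministic("东南亚国家日本的首都是什么市")
_, _, tokens_b = chat_deterministic("东南亚国家日本的首都是什么市")
assert tokens_a == tokens_b, "generation is not deterministic under a fixed seed"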