import sys

import torch
from modelscope import snapshot_download

from modeling_wit import QWenLMHeadModel, QwenRunner
from wit.configuration import ModelConfig
from tokenization_qwen import QWenTokenizer
from qwen_generation_utils import (
    make_context,
    decode_tokens,
)

# Fix the random seed so generation is reproducible across runs.
seed = 4321
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# Download the Qwen-1.8B-Chat checkpoint from ModelScope (or point to a local copy).
model_dir = snapshot_download("qwen/Qwen-1_8B-Chat")
# model_dir = "/home/colin/.cache/modelscope/hub/qwen/Qwen-1_8B-Chat"

# Build the model from the local config and print its module structure.
config = ModelConfig()
model = QWenLMHeadModel(config)
print(model)

tokenizer = QWenTokenizer("./wit_b64.tiktoken", "./wit_char.tiktoken")

sys.path.append("..")
from tools import show


def Dump_tokens_list(tokenizer):
    # Decode the first 4096 token ids one by one and dump them to a text file for inspection.
    tokens = []
    for token in range(4096):
        decoded, response, end_reason = decode_tokens(
            [token],
            tokenizer,
            raw_text_len=0,
            context_length=0,
            errors="replace",
        )
        tokens.append(str(token).zfill(7) + ": " + repr(decoded))
    show.DumpListToFile(tokens, "./temp/qwen_token_list.txt")


Dump_tokens_list(tokenizer)

# Load the pretrained weights into the model and move it to the GPU.
model = model.from_pretrained(model_dir).cuda()

# state = model.state_dict()
# torch.save(state, "model_params.pth")
# model.load_state_dict(torch.load("model_params.pth"))

model = model.eval()
# model = model.train()  # gradients are controlled by @torch.no_grad()

# Run a short chat ("你好" / "Hello") with generation capped at 20 new tokens.
runner = QwenRunner(model)
output_ids, history, decoded_text = runner.Chat(tokenizer, "你好", "", 20)
print(decoded_text)

# Print each generated token id decoded individually.
for i, token in enumerate(output_ids):
    de = tokenizer.decode([token])
    de = str(i + 1).zfill(3) + " : " + repr(de)
    print(de)