# Witllm/wit/demo.py (71 lines, 1.6 KiB, Python)

import torch
import sys
from modelscope import snapshot_download
from modeling_wit import QWenLMHeadModel
from modeling_wit import QwenRunner
from wit.configuration import ModelConfig
from tokenization_qwen import QWenTokenizer
from qwen_generation_utils import (
make_context,
decode_tokens,
)
# --- Reproducibility --------------------------------------------------------
seed = 4321
torch.manual_seed(seed)            # CPU RNG
torch.cuda.manual_seed_all(seed)   # RNGs of all CUDA devices

# --- Checkpoint location -----------------------------------------------------
# Resolve the Qwen-1.8B-Chat checkpoint directory through modelscope.
# A local copy can be used instead by assigning the path directly, e.g.
#   model_dir = "/home/colin/.cache/modelscope/hub/qwen/Qwen-1_8B-Chat"
model_dir = snapshot_download("qwen/Qwen-1_8B-Chat")

# --- Model skeleton and tokenizer ---------------------------------------------
# Architecture comes from ModelConfig; pretrained weights are loaded further
# down via from_pretrained.
config = ModelConfig()
model = QWenLMHeadModel(config)
print(model)

tokenizer = QWenTokenizer("./wit_b64.tiktoken", "./wit_char.tiktoken")

# Put the parent of the working directory on the import path so that the
# `tools` package can be imported; the import must follow this tweak.
sys.path.append("..")
from tools import show  # noqa: E402
def Dump_tokens_list(model, count=4096, path="./temp/qwen_token_list.txt"):
    """Decode token ids 0..count-1 one at a time and dump them to a file.

    Each id is decoded on its own with the module-level ``tokenizer`` through
    ``decode_tokens`` and written as one line of the form ``0000042: 'text'``
    using ``show.DumpListToFile``.

    NOTE: this relies on the module-level names ``decode_tokens`` (which must
    still be the qwen_generation_utils function when this runs), ``tokenizer``
    and ``show``.

    Args:
        model: Not used; kept so existing ``Dump_tokens_list(model)`` calls
            keep working.
        count: Number of token ids to dump, starting at 0. Defaults to 4096,
            the value that used to be hard-coded.
        path: Output file handed to ``show.DumpListToFile``. Defaults to the
            path that used to be hard-coded.
    """
    tokens = []
    for token in range(count):
        # decode_tokens yields a 3-tuple here; only the decoded text is used.
        decoded, _, _ = decode_tokens(
            [token],
            tokenizer,
            raw_text_len=0,
            context_length=0,
            errors="replace",
        )
        # Zero-padded 7-digit id followed by the repr() of the decoded text.
        tokens.append(f"{token:07d}: {decoded!r}")
    show.DumpListToFile(tokens, path)
# Dump the token id -> text table (Dump_tokens_list ignores `model` and uses
# the module-level tokenizer).
Dump_tokens_list(model)

# Load weights from model_dir via from_pretrained and move the model to CUDA.
model = model.from_pretrained(model_dir).cuda()
# Alternative kept for reference: persist / restore the raw parameters.
#   torch.save(model.state_dict(), "model_params.pth")
#   model.load_state_dict(torch.load("model_params.pth"))
model = model.eval()
# model.train() is not needed here; gradient tracking is controlled by
# @torch.no_grad().

runner = QwenRunner(model)
# One chat call with prompt "你好". The trailing "" and 20 are passed straight
# to QwenRunner.Chat (presumably history/system text and a generation
# length; TODO confirm against QwenRunner.Chat).
# Do not bind the third result to `decode_tokens`: that would shadow the
# qwen_generation_utils function Dump_tokens_list depends on.
output_ids, history, chat_response = runner.Chat(tokenizer, "你好", "", 20)
print(chat_response)

# Print every output id decoded on its own, numbered from 001.
for position, token in enumerate(output_ids, start=1):
    token_text = tokenizer.decode([token])
    print(f"{position:03d} : {token_text!r}")