71 lines
1.6 KiB
Python
71 lines
1.6 KiB
Python
import torch
|
|
import sys
|
|
from modelscope import snapshot_download
|
|
|
|
from modeling_wit import QWenLMHeadModel
|
|
from modeling_wit import QwenRunner
|
|
from wit.configuration import ModelConfig
|
|
from tokenization_qwen import QWenTokenizer
|
|
|
|
|
|
from qwen_generation_utils import (
|
|
make_context,
|
|
decode_tokens,
|
|
)
|
|
|
|
# Seed PyTorch's default RNG and the CUDA RNGs with one fixed value.
seed = 4321
for _apply_seed in (torch.manual_seed, torch.cuda.manual_seed_all):
    _apply_seed(seed)
|
|
|
|
# Resolve the checkpoint directory through ModelScope.
model_dir = snapshot_download("qwen/Qwen-1_8B-Chat")
# Alternative: point directly at an existing local copy, e.g.
# model_dir = "/home/colin/.cache/modelscope/hub/qwen/Qwen-1_8B-Chat"

# Build the network from the default ModelConfig and print its structure.
config = ModelConfig()
model = QWenLMHeadModel(config)
print(model)

tokenizer = QWenTokenizer(
    "./wit_b64.tiktoken",
    "./wit_char.tiktoken",
)

# The parent of the working directory is put on sys.path before `tools` is
# imported (presumably that is where the package lives).
sys.path.append("..")
from tools import show  # noqa: E402
|
|
|
|
|
|
def Dump_tokens_list(
    model,
    token_count=4096,
    output_path="./temp/qwen_token_list.txt",
):
    """Dump a readable listing of the first ``token_count`` token ids.

    Every id in ``range(token_count)`` is decoded on its own with
    ``decode_tokens`` (using the module-level ``tokenizer``) and rendered as
    ``"<id zero-padded to 7 digits>: <repr of the decoded text>"``. The
    resulting list is handed to ``show.DumpListToFile``.

    Args:
        model: Not used by this function; kept so existing callers
            (``Dump_tokens_list(model)``) keep working.
        token_count: How many ids to dump, starting at 0. Defaults to 4096,
            the value that used to be hard-coded.
        output_path: Destination passed to ``show.DumpListToFile``. Defaults
            to the path that used to be hard-coded.
    """
    lines = []
    for token_id in range(token_count):
        # decode_tokens yields three values here; only the first is listed.
        decoded, _response, _end_reason = decode_tokens(
            [token_id],
            tokenizer,
            raw_text_len=0,
            context_length=0,
            errors="replace",
        )
        lines.append(f"{token_id:07d}: {decoded!r}")
    show.DumpListToFile(lines, output_path)
|
|
|
|
|
|
# Write the token listing before the pretrained weights are loaded.
Dump_tokens_list(model)

# Load the checkpoint, move the model to the GPU and switch it to eval mode.
model = model.from_pretrained(model_dir).cuda().eval()
# model = model.train() # control by @torch.no_grad()

# Optional state-dict round trip, kept for reference:
# state = model.state_dict()
# torch.save(state, "model_params.pth")
# model.load_state_dict(torch.load('model_params.pth'))
|
|
|
|
|
|
runner = QwenRunner(model)

# The third value returned by Chat gets its own name. Binding it to
# `decode_tokens` would shadow the helper imported from qwen_generation_utils
# that Dump_tokens_list calls, breaking any later call to that function.
output_ids, history, decoded_output = runner.Chat(tokenizer, "你好", "", 20)
print(decoded_output)

# Print each generated id decoded on its own, numbered from 001.
for position, token_id in enumerate(output_ids, start=1):
    print(f"{position:03d} : {tokenizer.decode([token_id])!r}")
|