2024-01-03 21:03:27 +08:00
|
|
|
import torch
|
2024-01-03 20:26:26 +08:00
|
|
|
from modelscope import snapshot_download
|
|
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
|
from transformers.generation import GenerationConfig
|
2024-01-03 21:03:27 +08:00
|
|
|
from transformers import AutoConfig
|
|
|
|
|
|
|
|
from modeling_qwen import QWenLMHeadModel
|
2024-01-21 12:45:56 +08:00
|
|
|
from modeling_qwen import QwenRunner
|
2024-01-03 21:03:27 +08:00
|
|
|
|
|
|
|
seed = 4321
|
|
|
|
torch.manual_seed(seed)
|
|
|
|
torch.cuda.manual_seed_all(seed)
|
2024-01-03 20:26:26 +08:00
|
|
|
|
2024-01-05 11:49:35 +08:00
|
|
|
model_dir = snapshot_download("qwen/Qwen-1_8B-Chat")
|
|
|
|
# model_dir = "/home/colin/.cache/modelscope/hub/qwen/Qwen-1_8B-Chat"
|
2024-01-03 20:26:26 +08:00
|
|
|
|
2024-01-03 21:03:27 +08:00
|
|
|
config, kwargs = AutoConfig.from_pretrained(
|
2024-01-07 17:28:15 +08:00
|
|
|
"./",
|
2024-01-03 21:03:27 +08:00
|
|
|
return_unused_kwargs=True,
|
|
|
|
trust_remote_code=True,
|
|
|
|
code_revision=None,
|
|
|
|
_commit_hash=None,
|
|
|
|
)
|
|
|
|
model = QWenLMHeadModel(config)
|
|
|
|
|
2024-01-10 19:35:46 +08:00
|
|
|
print(model)
|
|
|
|
|
2024-01-03 20:26:26 +08:00
|
|
|
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
|
2024-01-20 20:04:45 +08:00
|
|
|
model = model.from_pretrained(model_dir).cuda()
|
2024-01-19 14:54:48 +08:00
|
|
|
|
2024-01-20 20:47:26 +08:00
|
|
|
model = model.eval()
|
|
|
|
# model = model.train() # control by @torch.no_grad()
|
2024-01-03 20:26:26 +08:00
|
|
|
|
|
|
|
# 可指定不同的生成长度、top_p等相关超参
|
2024-01-07 17:28:15 +08:00
|
|
|
# model.generation_config = GenerationConfig.from_pretrained(
|
|
|
|
# model_dir, trust_remote_code=True
|
|
|
|
# )
|
2024-01-03 20:26:26 +08:00
|
|
|
|
2024-01-21 12:45:56 +08:00
|
|
|
runner = QwenRunner(model)
|
|
|
|
|
2024-01-03 20:26:26 +08:00
|
|
|
# 第一轮对话
|
2024-01-21 12:45:56 +08:00
|
|
|
response, history, decode_tokens = runner.Chat(tokenizer, "东南亚国家日本的首都是什么市", "")
|
2024-01-10 19:35:46 +08:00
|
|
|
print(decode_tokens)
|
2024-01-13 16:50:25 +08:00
|
|
|
# <|im_start|>system
|
|
|
|
# You are a helpful assistant.<|im_end|>
|
|
|
|
# <|im_start|>user
|
2024-01-14 17:21:14 +08:00
|
|
|
# 东南亚国家日本的首都是什么市<|im_end|>
|
2024-01-13 16:50:25 +08:00
|
|
|
# <|im_start|>assistant
|
2024-01-14 17:21:14 +08:00
|
|
|
# 日本的首都东京。<|im_end|><|endoftext|>
|
|
|
|
|
|
|
|
# 第二轮对话
|
2024-01-21 12:45:56 +08:00
|
|
|
|
|
|
|
response, history, decode_tokens = runner.Chat(tokenizer, "给我讲一个年轻人奋斗创业最终取得成功的故事。", "")
|
2024-01-20 20:04:45 +08:00
|
|
|
print(decode_tokens)
|
2024-01-13 16:50:25 +08:00
|
|
|
|
2024-01-21 02:33:55 +08:00
|
|
|
if decode_tokens.split("\n")[-2] != """这个故事告诉我们,只要我们有决心和毅力,就一定能够克服困难,实现我们的梦想。<|im_end|>""":
|
|
|
|
raise ()
|