Witllm/chatglm/demo.py

import sys
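# Add the parent directory to the module search path.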
sys.path.append("..")
import json
import torch
from modeling_chatglm import ChatGLMForConditionalGeneration
from tokenization_chatglm import ChatGLMTokenizer
from modelscope import snapshot_download
from transformers import AutoConfig
from tools import show
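# Fix the random seed so the generated response is reproducible.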
seed = 4321
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
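# Fetch the ChatGLM3-6B checkpoint from ModelScope (reuses the local cache if present).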
pretrained_model_name_or_path = snapshot_download("ZhipuAI/chatglm3-6b")
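# Load the model configuration and instantiate the model from the local modeling code.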
config, kwargs = AutoConfig.from_pretrained(
    pretrained_model_name_or_path,
    return_unused_kwargs=True,
    trust_remote_code=True,
    code_revision=None,
    _commit_hash=None,
)
glm = ChatGLMForConditionalGeneration(config)
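# Build the tokenizer by hand from the local tokenizer_config.json and
# SentencePiece vocab file instead of going through AutoTokenizer.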
tokenizer_config_file = "./tokenizer_config.json"
if tokenizer_config_file is not None:
    with open(tokenizer_config_file, encoding="utf-8") as tokenizer_config_handle:
        init_kwargs = json.load(tokenizer_config_handle)
    init_kwargs.pop("tokenizer_class", None)
    init_kwargs.pop("tokenizer_file", None)
    saved_init_inputs = init_kwargs.pop("init_inputs", ())
    init_inputs = saved_init_inputs
    init_kwargs["vocab_file"] = "./tokenizer.model"
    init_kwargs["added_tokens_file"] = None
    init_kwargs["special_tokens_map_file"] = None
    init_kwargs["tokenizer_file"] = None
    init_kwargs["name_or_path"] = pretrained_model_name_or_path
    tokenizer = ChatGLMTokenizer(*init_inputs, **init_kwargs)
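# Load the pretrained weights, cast to fp16 and move the model to the GPU.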
glm = ChatGLMForConditionalGeneration.from_pretrained(pretrained_model_name_or_path).half().cuda()
glm = glm.eval()
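# Run a single chat turn and check that the seeded output matches the expected greeting.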
query = "你好"
response, history = glm.chat(tokenizer, query, history=[])
print(response)
if response.split("\n")[-1] != " 你好!有什么可以帮助您的吗?":
    raise RuntimeError(f"unexpected response: {response!r}")
# query = "colin"
# response, history = glm.chat(tokenizer, query, history=history)
# print(response)
# if response[1:] != " Hello! How can I assist you today":
# raise ()
# response, history = glm.chat(tokenizer, "你是一个心理学专家,请问晚上睡不着应该怎么办", history=history)
# print(response)
# import plotly_express as px
# px.imshow(ron)
# gapminder = px.data.gapminder()
# gapminder2007 = gapminder.query('year == 2007')
# px.scatter(gapminder2007, x='gdpPercap', y='lifeExp')
# from modelscope import AutoTokenizer, AutoModel, snapshot_download
# model_dir = snapshot_download("ZhipuAI/chatglm3-6b", cache_dir="./chatglm", revision="v1.0.0")
# model = AutoModel.from_pretrained(model_dir, trust_remote_code=True).half().cuda()
# tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
# model = model.eval()
# response, history = model.chat(tokenizer, "colin", history=[])
# print(response)
# response, history = model.chat(tokenizer, "你好", history=history)
# print(response)
# # response, history = model.chat(tokenizer, "你是一个心理学专家,请问晚上睡不着应该怎么办", history=history)
# # print(response)