Witllm/qwen/demo.py

46 lines
1.3 KiB
Python
Raw Normal View History

2024-01-03 21:03:27 +08:00
import torch
2024-01-03 20:26:26 +08:00
from modelscope import snapshot_download
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
2024-01-03 21:03:27 +08:00
from transformers import AutoConfig
from modeling_qwen import QWenLMHeadModel
# Seed torch's random number generators with one fixed value so that later
# random draws are repeatable from run to run.
seed = 4321
for _seed_rng in (torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_rng(seed)
2024-01-03 20:26:26 +08:00
2024-01-05 11:49:35 +08:00
# Resolve the checkpoint directory through ModelScope's `snapshot_download`.
model_id = "qwen/Qwen-1_8B-Chat"
model_dir = snapshot_download(model_id)
# Alternative: point directly at an existing local copy instead.
# model_dir = "/home/colin/.cache/modelscope/hub/qwen/Qwen-1_8B-Chat"
2024-01-03 20:26:26 +08:00
2024-01-03 21:03:27 +08:00
# Load the checkpoint's configuration. With `return_unused_kwargs=True` the
# call returns a `(config, unused_kwargs)` pair. Both names are kept at module
# level for inspection even though the model load below does not need them.
config, kwargs = AutoConfig.from_pretrained(
    model_dir,
    return_unused_kwargs=True,
    trust_remote_code=True,
    code_revision=None,
    _commit_hash=None,
)

tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)

# Load the pretrained weights through the local `QWenLMHeadModel` class.
# `from_pretrained` is a classmethod, so it is called on the class itself:
# constructing `QWenLMHeadModel(config)` first would allocate and randomly
# initialise a full extra model whose parameters are discarded immediately.
model = QWenLMHeadModel.from_pretrained(
    model_dir, device_map="auto", trust_remote_code=True
).train()
# NOTE(review): the model is put in training mode although only `model.chat`
# calls follow in this script -- confirm this is intended; `.eval()` is the
# usual mode for generation.
# model.train()
# model.zero_grad()
2024-01-03 20:26:26 +08:00
# Different generation lengths, top_p and other related hyperparameters can be
# specified through the generation config.
generation_config = GenerationConfig.from_pretrained(
    model_dir, trust_remote_code=True
)
model.generation_config = generation_config

# Run a two-turn conversation: the first turn starts with no history, and each
# later turn is given the history returned by the previous one.
# The first prompt is a greeting; a sample reply recorded by the author was
# "你好!很高兴为你提供帮助。" ("Hello! Glad to be of help.").
# The second prompt asks for a story about a young person who works hard at
# starting a business and finally succeeds.
history = None
for prompt in ("你好", "给我讲一个年轻人奋斗创业最终取得成功的故事。"):
    response, history = model.chat(tokenizer, prompt, history=history)
    print(response)