Witllm/prompt_clue/demo.py

90 lines
4.3 KiB
Python
Raw Normal View History

2024-01-06 21:05:39 +08:00
import torch
2024-01-05 20:33:01 +08:00
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.preprocessors import TextGenerationTransformersPreprocessor
from modeling_t5 import T5ForConditionalGeneration
2024-01-06 21:05:39 +08:00
from modelscope.utils.config import Config
from configuration import T5Config
from modelscope import snapshot_download
from transformers import AutoConfig
2024-01-05 20:33:01 +08:00
2024-01-06 21:05:39 +08:00
seed = 4321
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
model_dir = snapshot_download("ClueAI/PromptCLUE-base-v1-5")
# config = T5Config()
config, kwargs = AutoConfig.from_pretrained(
model_dir,
return_unused_kwargs=True,
trust_remote_code=True,
code_revision=None,
_commit_hash=None,
2024-01-05 20:33:01 +08:00
)
2024-01-06 21:05:39 +08:00
model = T5ForConditionalGeneration(config)
model = model.from_pretrained(model_dir)
preprocessor = TextGenerationTransformersPreprocessor(model_dir)
out = preprocessor._tokenize_text("生成与下列文字相同意")
tokenizer = preprocessor.nlp_tokenizer
response, history = model.chat(tokenizer, "生成与下列文字相同意", history=[])
# model = T5ForConditionalGeneration.from_pretrained(
# "ClueAI/PromptCLUE-base-v1-5", revision="v0.1"
# )
2024-01-05 20:33:01 +08:00
pipeline_t2t = pipeline(
task=Tasks.text2text_generation, model=model, preprocessor=preprocessor
)
print(pipeline_t2t("生成与下列文字相同意思的句子:\n白云遍地无人扫\n答案:", do_sample=True, top_p=0.8))
# {'text': '白云散去无踪,没人扫。'}
# print(pipeline_t2t('改写下面的文字,确保意思相同:\n一个如此藐视本国人民民主权利的人怎么可能捍卫外国人的民权\n答案', do_sample=True, top_p=0.8))
# # {'text': '对一个如此藐视本国人民民主权利的人,怎么能捍卫外国人的民权?'}
# print(pipeline_t2t('根据问题给出答案:\n问题手指发麻的主要可能病因是\n答案'))
# # {'text': '神经损伤,颈椎病,贫血,高血压'}
# print(pipeline_t2t('问答:\n问题黄果悬钩子的目是\n答案'))
# # {'text': '蔷薇目'}
# print(pipeline_t2t('情感分析:\n这个看上去还可以但其实我不喜欢\n选项积极消极'))
# # {'text': '消极'}
# print(pipeline_t2t("下面句子是否表示了相同的语义:\n文本1糖尿病腿麻木怎么办\n文本2糖尿病怎样控制生活方式\n选项相似不相似\n答案"))
# # {'text': '不相似'}
# print(pipeline_t2t('这是关于哪方面的新闻:\n如果日本沉没中国会接收日本难民吗\n选项故事,文化,娱乐,体育,财经,房产,汽车,教育,科技,军事,旅游,国际,股票,农业,游戏'))
# # {'text': '国际'}
# print(pipeline_t2t("阅读文本抽取关键信息:\n张玄武1990年出生中国国籍无境外居留权博士学历现任杭州线锁科技技术总监。\n问题机构人名职位籍贯专业国籍学历种族\n答案"))
# # {'text': '机构杭州线锁科技技术_人名张玄武_职位博士学历'}
# print(pipeline_t2t("翻译成英文:\n杀不死我的只会让我更强大\n答案"))
# # {'text': 'To kill my life only let me stronger'}
# print(pipeline_t2t('为下面的文章生成摘要:\n北京时间9月5日12时52分四川甘孜藏族自治州泸定县发生6.8级地震。地震发生后,领导高度重视并作出重要指示,要求把抢救生命作为首要任务,全力救援受灾群众,最大限度减少人员伤亡'))
# # {'text': '四川甘孜发生6.8级地震'}
# print(pipeline_t2t("推理关系判断:\n前提小明今天在北京\n假设小明在深圳旅游\n选项矛盾蕴含中立\n答案"))
# # {'text': '蕴涵'}
# print(pipeline_t2t('阅读以下对话并回答问题。\n男今天怎么这么晚才来上班啊昨天工作到很晚而且我还感冒了。男那你回去休息吧我帮你请假。女谢谢你。\n问题女的怎么样\n选项正在工作感冒了在打电话要出差。'))
# # {'text': '感冒了'}
# print(pipeline_t2t("文本纠错:\n告诉二营长叫他彻回来我李云龙从不打没有准备的杖\n答案"))
# #{'text''告诉二营长,叫他下来,我李云龙从不打没有准备的仗'}
# print(pipeline_t2t("问答:\n问题小米的创始人是谁\n答案"))
# # {'text': '小米创始人:雷军'}