# Witllm/seqgpt/demo.py
#
# (Header captured from the GitHub file viewer: 70 lines, 2.5 KiB, Python.)
#
# Note: GitHub flags this file as containing ambiguous Unicode characters,
# i.e. characters that might be confused with other characters. These
# presumably come from the Chinese prompt strings and comments below; if that
# is intentional, the warning can be ignored.

# from modelscope.utils.constant import Tasks
# from modelscope.pipelines import pipeline
# prompt = "输入: 中国的首都在哪里\n输出: "
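# # task accepts 抽取 (extraction) or 分类 (classification); text is the text to analyse; labels is the list of label types, separated by Chinese (full-width) commas.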
# inputs = {'task': '抽取', 'text': '杭州欢迎你。', 'labels': '地名'}
# # Keep PROMPT_TEMPLATE unchanged.
# PROMPT_TEMPLATE = '输入: {text}\n{task}: {labels}\n输出: '
# # prompt = PROMPT_TEMPLATE.format(**inputs)
# pipeline_ins = pipeline(task=Tasks.text_generation, model='damo/nlp_seqgpt-560m', model_revision = 'v1.0.1', run_kwargs={'gen_token': '[GEN]'})
# print(pipeline_ins(prompt))
import torch
import json
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.preprocessors import TextGenerationTransformersPreprocessor
from modelscope.utils.config import Config
from modelscope import snapshot_download
from transformers import AutoConfig
from modeling_bloom import BloomForCausalLM
from tokenization_bloom_fast import BloomTokenizerFast
# Seed torch's CPU RNG and the RNG of every CUDA device with the same value.
# torch.cuda.manual_seed_all is silently ignored when CUDA is unavailable.
seed = 4321
for _seed_rng in (torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_rng(seed)
del _seed_rng
# Fetch the SeqGPT-560M checkpoint from the ModelScope hub; snapshot_download
# returns the local directory that the config, weights and tokenizer are
# loaded from below.
model_dir = snapshot_download("damo/nlp_seqgpt-560m")
# NOTE(review): trust_remote_code=True allows executing code shipped with the
# checkpoint repository; keep it only if that repository is trusted.
config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
# Build the project-local BloomTokenizerFast by hand from tokenizer files.
# All "./..." paths are resolved relative to the current working directory,
# not model_dir, so the script must be run from the directory holding them.
tokenizer_config_file = "./tokenizer_config.json"
with open(tokenizer_config_file, encoding="utf-8") as tokenizer_config_handle:
    init_kwargs = json.load(tokenizer_config_handle)

# These saved-config entries are not passed through as-is: the class is fixed
# (BloomTokenizerFast) and tokenizer_file is set explicitly below.
init_kwargs.pop("tokenizer_class", None)
init_kwargs.pop("tokenizer_file", None)
# Positional constructor arguments stored in the config, if any.
init_inputs = init_kwargs.pop("init_inputs", ())

# The fast tokenizer is constructed from tokenizer.json; the vocab and
# added-tokens files are explicitly disabled.
init_kwargs.update(
    vocab_file=None,
    added_tokens_file=None,
    special_tokens_map_file="./special_tokens_map.json",
    tokenizer_file="./tokenizer.json",
    name_or_path=model_dir,
)
tokenizer = BloomTokenizerFast(*init_inputs, **init_kwargs)
# Load the pretrained weights directly into the project-local
# BloomForCausalLM. from_pretrained is a classmethod, so no throwaway
# randomly-initialised instance is needed. The model is only used for
# generation, so put it in eval() mode (dropout off), and fall back to the CPU
# when no GPU is present.
model = BloomForCausalLM.from_pretrained(model_dir, config=config)
model = model.to("cuda" if torch.cuda.is_available() else "cpu").eval()
# Prompt format: "输入: <question>\n输出: " ("Input: ... / Output: ").
# Alternative prompt: "输入: 中国的首都在哪里\n输出: " (capital of China).
prompt = "输入: 美国的首都在哪里\n输出: "  # "Where is the capital of the USA?"

# tokenizer(...) returns a BatchEncoding (input_ids, and an attention_mask if
# the tokenizer produces one); move it to whatever device the model is on.
encoded = tokenizer(
    prompt, return_tensors="pt", padding=True, truncation=True, max_length=1024
).to(model.device)

# Deterministic beam search (4 beams, no sampling), up to 256 new tokens.
outputs = model.generate(
    input_ids=encoded["input_ids"],
    attention_mask=encoded.get("attention_mask"),
    num_beams=4,
    do_sample=False,
    max_new_tokens=256,
)

# The output sequences include the prompt tokens; decode and print the
# single (batch-size-1) result.
decoded_sentences = tokenizer.batch_decode(outputs, skip_special_tokens=True)
print(decoded_sentences[0])