gpt-pretrain/generate.py

83 lines
2.6 KiB
Python
Raw Normal View History

2023-05-04 21:52:25 +08:00
import argparse
2023-05-07 13:01:02 +08:00
from typing import List
2023-05-04 21:52:25 +08:00
import torch
2023-05-07 13:01:02 +08:00
from transformers import PreTrainedModel, PreTrainedTokenizer
2023-05-04 21:52:25 +08:00
2023-05-07 13:01:02 +08:00
from utils import load_model, load_tokenizer
2023-05-04 21:52:25 +08:00
def eval_prompts(
    model: "PreTrainedModel",
    tokenizer: "PreTrainedTokenizer",
    prompts: List[str],
    use_tril_attention_mask: bool = False,
    *,
    num_beams: int = 16,
    max_new_tokens: int = 100,
) -> List[str]:
    """Generate a beam-search continuation for each prompt and decode it.

    The prompts are tokenized as one padded batch. Position ids are derived
    from the attention mask: each sequence counts from 0 starting at its
    first non-padded token, and padded slots get a placeholder value of 1.
    This keeps positions correct when the tokenizer pads on the left.

    Args:
        model: Causal LM whose ``generate`` method is called. Inputs are moved
            to ``model.device`` first.
        tokenizer: Tokenizer matching ``model``. It must define a pad token,
            because the batch is tokenized with ``padding=True``.
        prompts: Input texts. An empty list returns ``[]`` without calling
            the tokenizer or the model.
        use_tril_attention_mask: If True, replace the 2-D padding mask with a
            per-example (seq, seq) lower-triangular mask that also zeroes
            padded rows/columns. NOTE(review): presumably this must match how
            the model was trained; confirm the model's ``generate`` path
            accepts a 3-D attention mask.
        num_beams: Beam width for the deterministic (``do_sample=False``)
            beam search. Defaults to the previously hard-coded 16.
        max_new_tokens: Maximum number of tokens generated per prompt.
            Defaults to the previously hard-coded 100.

    Returns:
        One decoded string per prompt, in input order, with special tokens
        removed. The full generated sequence is decoded, which for a typical
        decoder-only model includes the prompt text itself.
    """
    if not prompts:
        # Nothing to generate; avoids tokenizing/generating an empty batch.
        return []

    inputs = tokenizer(prompts, padding=True, return_tensors='pt', return_attention_mask=True)
    padding_mask = inputs.attention_mask

    # Positions start at 0 on the first real token of each row; padded slots
    # get a dummy position (1) since they are masked out anyway.
    position_ids = padding_mask.cumsum(-1) - 1
    position_ids.masked_fill_(padding_mask == 0, 1)
    inputs['position_ids'] = position_ids

    if use_tril_attention_mask:
        # Outer product of the padding mask gives a (batch, seq, seq) mask that
        # blocks padded rows/columns; tril() additionally enforces causality.
        inputs['attention_mask'] = (padding_mask.unsqueeze(1) * padding_mask.unsqueeze(2)).tril()

    inputs = inputs.to(model.device)
    with torch.inference_mode():
        output_ids = model.generate(
            **inputs,
            do_sample=False,
            num_beams=num_beams,
            max_new_tokens=max_new_tokens,
            eos_token_id=tokenizer.eos_token_id,
            # Padding above already required a pad token; pass it explicitly so
            # finished beams are padded with it rather than an implicit fallback.
            pad_token_id=tokenizer.pad_token_id,
            early_stopping=True,
        )
    completes = tokenizer.batch_decode(output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
    return completes
def parse_args():
    """Define this script's command-line options and parse ``sys.argv``.

    Returns:
        An ``argparse.Namespace`` with:
          * ``model_name_or_path`` -- str, defaults to ``'gpt2'``;
          * ``use_tril_attention_mask`` -- bool flag, False unless given;
          * ``tokenizer_name_or_path`` -- str, or None when not supplied
            (the script entry point then reuses the model path).
    """
    # (flag, keyword options for add_argument), registered in this order.
    option_specs = (
        (
            "--model_name_or_path",
            dict(type=str, help="Name of or path to model", default='gpt2'),
        ),
        (
            "--use_tril_attention_mask",
            dict(help="Use tril attention mask during training", action="store_true"),
        ),
        (
            "--tokenizer_name_or_path",
            dict(type=str, help="Name of or path to tokenizer", default=None),
        ),
    )

    cli = argparse.ArgumentParser()
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()
if __name__ == '__main__':
    # Script entry point: load a model/tokenizer, run beam-search generation
    # on a fixed set of English/French/Chinese prompts, and print each
    # prompt ("[p]") next to its decoded output ("[c]").
    args = parse_args()
    if args.tokenizer_name_or_path is None:
        # No explicit tokenizer given: load it from the model's location.
        args.tokenizer_name_or_path = args.model_name_or_path

    # Use the first CUDA device when present; fall back to CPU instead of
    # failing on machines without a GPU (the old torch.device(0) assumed CUDA).
    if torch.cuda.is_available():
        device = torch.device('cuda', 0)
    else:
        device = torch.device('cpu')
    model = load_model(args.model_name_or_path)
    tokenizer = load_tokenizer(args.tokenizer_name_or_path)

    model = model.to(device)
    # Pairs of a full line and a truncated version of the same line, so the
    # output shows how the model continues a known text.
    prompts = [
        "Shall I compare thee to a summer's day? Thou art more lovely and more temperate.",
        "Shall I compare thee to a summer's day? Thou art more lovely and",
        "Belle! C'est un mot qu'on dirait inventé pour elle.",
        "Belle! C'est un mot qu'on dirait inventé",
        "这是一个最好的时代,这是一个最坏的时代。",
        "这是一个最好的时代,这是一个最坏的",
    ]
    completes = eval_prompts(model, tokenizer, prompts, use_tril_attention_mask=args.use_tril_attention_mask)
    for prompt, complete in zip(prompts, completes):
        print("[p]", prompt)
        print("[c]", complete)