"""Run batched beam-search generation over a few fixed prompts and print the results.

The prompts are left-padded into one batch. Explicit ``position_ids`` and a
3-D attention mask are built from the padding mask before calling
``model.generate``.
"""
import argparse
import os
from typing import List, Optional, Union

import torch
from transformers import (
    AutoModel,
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizer,
)


def load_model(model_name_or_path: Union[str, os.PathLike]) -> PreTrainedModel:
    """Load a model, preferring a causal-LM head.

    Tries ``AutoModelForCausalLM`` first. If that raises ``ValueError``
    (presumably because the checkpoint's config has no causal-LM mapping),
    falls back to the plain ``AutoModel``. Both calls pass
    ``trust_remote_code=True``, so modelling code shipped with the checkpoint
    is allowed to run.
    """
    try:
        return AutoModelForCausalLM.from_pretrained(
            model_name_or_path, trust_remote_code=True
        )
    except ValueError:
        return AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True)


def load_tokenizer(model_name_or_path: Union[str, os.PathLike]) -> PreTrainedTokenizer:
    """Load a tokenizer set up for batched generation.

    Padding goes on the left, so every prompt in a batch ends at the same
    position. If the tokenizer defines no pad token, the EOS token is reused
    for padding.
    """
    tokenizer = AutoTokenizer.from_pretrained(
        model_name_or_path, padding_side='left', trust_remote_code=True
    )
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer


def eval_prompts(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    prompts: List[str],
    *,
    num_beams: int = 16,
    max_new_tokens: int = 100,
) -> List[str]:
    """Complete each prompt with deterministic (non-sampling) beam search.

    Args:
        model: model whose ``generate`` is called with explicit
            ``position_ids`` and a 3-D attention mask (see note in the body).
        tokenizer: tokenizer with left padding and a pad token, e.g. from
            ``load_tokenizer``.
        prompts: prompts to complete. An empty list returns ``[]`` without
            touching the model.
        num_beams: beam width passed to ``generate``.
        max_new_tokens: maximum number of new tokens per prompt.

    Returns:
        One decoded string per prompt, in input order. This is the full
        ``generate`` output with special tokens stripped. For decoder-only
        models it therefore starts with the prompt text.
    """
    if not prompts:
        return []

    inputs = tokenizer(
        prompts, padding=True, return_tensors='pt', return_attention_mask=True
    )
    mask = inputs.attention_mask

    # Count positions over real tokens only, so left padding does not shift
    # the prompt's positions. Padded slots get a dummy position of 1.
    position_ids = mask.cumsum(-1) - 1
    position_ids.masked_fill_(mask == 0, 1)
    inputs['position_ids'] = position_ids

    # Expand the per-token padding mask into a lower-triangular pairwise
    # mask: entry [b, i, j] is 1 only when j <= i and tokens i and j are
    # both real (non-pad).
    # NOTE(review): not every architecture accepts a 3-D attention mask or
    # explicit position_ids in generate(); confirm the target model's
    # (possibly remote) code supports them.
    inputs['attention_mask'] = (mask.unsqueeze(1) * mask.unsqueeze(2)).tril()
    inputs = inputs.to(model.device)

    with torch.inference_mode():
        output_ids = model.generate(
            **inputs,
            do_sample=False,
            num_beams=num_beams,
            max_new_tokens=max_new_tokens,
            eos_token_id=tokenizer.eos_token_id,
            early_stopping=True,
        )
    return tokenizer.batch_decode(
        output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse command-line options. ``argv=None`` means ``sys.argv[1:]``."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        help="Name of or path to model",
        default='gpt2',
    )
    parser.add_argument(
        "--device",
        type=str,
        default=None,
        help="Torch device string, e.g. 'cuda:0' or 'cpu' "
        "(default: cuda:0 if CUDA is available, else cpu)",
    )
    parser.add_argument(
        "--num_beams", type=int, default=16, help="Beam width for generation"
    )
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=100,
        help="Maximum number of new tokens per prompt",
    )
    return parser.parse_args(argv)


def _select_device(requested: Optional[str]) -> torch.device:
    """Return the requested device, else CUDA device 0 if available, else CPU."""
    if requested is not None:
        return torch.device(requested)
    if torch.cuda.is_available():
        return torch.device(0)
    return torch.device('cpu')


def main() -> None:
    """Load the model and tokenizer, complete the sample prompts, print results."""
    args = parse_args()
    device = _select_device(args.device)
    model = load_model(args.model_name_or_path)
    tokenizer = load_tokenizer(args.model_name_or_path)
    model = model.to(device)

    # Each pair is a sentence followed by a truncated prefix of it, in
    # English, French and Chinese.
    prompts = [
        "Shall I compare thee to a summer's day? \nThou art more lovely and more temperate.",
        "Shall I compare thee to a summer's day? Thou art more lovely and",
        "Belle! C'est un mot qu'on dirait inventé pour elle.",
        "Belle! C'est un mot qu'on dirait inventé",
        "这是一个最好的时代,这是一个最坏的时代。",
        "这是一个最好的时代,这是一个最坏的",
    ]
    completes = eval_prompts(
        model,
        tokenizer,
        prompts,
        num_beams=args.num_beams,
        max_new_tokens=args.max_new_tokens,
    )
    for prompt, complete in zip(prompts, completes):
        print("[p]", prompt)
        print("[c]", complete)


if __name__ == '__main__':
    main()