From 939be31c1029c2143064e87b138afa7605c9ce64 Mon Sep 17 00:00:00 2001
From: Yiqing-Zhou
Date: Sat, 6 May 2023 21:06:18 +0800
Subject: [PATCH] [feature] new arg use_tril_attention_mask

---
 generate.py | 32 ++++++++++++++++++++++++++------
 train.py    | 19 +++++++++++++------
 2 files changed, 39 insertions(+), 12 deletions(-)

diff --git a/generate.py b/generate.py
index ee39652..3951953 100644
--- a/generate.py
+++ b/generate.py
@@ -32,16 +32,20 @@ def load_tokenizer(model_name_or_path: Union[str, os.PathLike]) -> PreTrainedTok
 
 
 def eval_prompts(
-    model: PreTrainedModel, tokenizer: PreTrainedTokenizer, prompts: List[str]
+    model: PreTrainedModel,
+    tokenizer: PreTrainedTokenizer,
+    prompts: List[str],
+    use_tril_attention_mask: bool = False,
 ) -> List[str]:
     inputs = tokenizer(
         prompts, padding=True, return_tensors='pt', return_attention_mask=True
     )
     inputs['position_ids'] = inputs.attention_mask.cumsum(-1) - 1
     inputs['position_ids'].masked_fill_(inputs.attention_mask == 0, 1)
-    inputs['attention_mask'] = (
-        inputs.attention_mask.unsqueeze(1) * inputs.attention_mask.unsqueeze(2)
-    ).tril()
+    if use_tril_attention_mask:
+        inputs['attention_mask'] = (
+            inputs.attention_mask.unsqueeze(1) * inputs.attention_mask.unsqueeze(2)
+        ).tril()
     inputs = inputs.to(model.device)
     with torch.inference_mode():
         output_ids = model.generate(
@@ -66,6 +70,17 @@ def parse_args():
         help="Name of or path to model",
         default='gpt2',
     )
+    parser.add_argument(
+        "--use_tril_attention_mask",
+        help="Use tril attention mask during generation",
+        action="store_true",
+    )
+    parser.add_argument(
+        "--tokenizer_name_or_path",
+        type=str,
+        help="Name of or path to tokenizer",
+        default=None,
+    )
     args = parser.parse_args()
     return args
 
@@ -73,10 +88,13 @@ def parse_args():
 if __name__ == '__main__':
     args = parse_args()
 
+    if args.tokenizer_name_or_path is None:
+        args.tokenizer_name_or_path = args.model_name_or_path
+
     device = torch.device(0)
 
     model = load_model(args.model_name_or_path)
-    tokenizer = load_tokenizer(args.model_name_or_path)
+    tokenizer = load_tokenizer(args.tokenizer_name_or_path)
     model = model.to(device)
 
     prompts = [
@@ -87,7 +105,9 @@ if __name__ == '__main__':
         "这是一个最好的时代,这是一个最坏的时代。",
         "这是一个最好的时代,这是一个最坏的",
     ]
-    completes = eval_prompts(model, tokenizer, prompts)
+    completes = eval_prompts(
+        model, tokenizer, prompts, use_tril_attention_mask=args.use_tril_attention_mask
+    )
 
     for prompt, complete in zip(prompts, completes):
         print("[p]", prompt)
diff --git a/train.py b/train.py
index 1ebee3e..b6982dc 100644
--- a/train.py
+++ b/train.py
@@ -103,6 +103,11 @@ def parse_args():
         help="Name of or path to model",
         default='gpt2',
     )
+    parser.add_argument(
+        "--use_tril_attention_mask",
+        help="Use tril attention mask during training",
+        action="store_true",
+    )
     parser.add_argument("--fp16", help="Enable fp16", action="store_true")
     parser.add_argument("--bf16", help="Enable bf16", action="store_true")
     parser.add_argument(
@@ -165,10 +170,11 @@ def parse_args():
 
 
 class LitModule(pl.LightningModule):
-    def __init__(self, model_name: str):
+    def __init__(self, model_name: str, use_tril_attention_mask: bool = False):
         super().__init__()
         self.save_hyperparameters()
         self.llm = self.register_core_module(init_model(model_name))
+        self.use_tril_attention_mask = use_tril_attention_mask
         self.metric_loss = torchmetrics.MeanMetric()
         self.metric_accuracy = torchmetrics.Accuracy(
             task='multiclass',
@@ -176,7 +182,7 @@ class LitModule(pl.LightningModule):
         )
 
     @cache
-    def get_tril_matrix(
+    def get_batch_tril_matrix(
         self, block_size: int, batch_size: Optional[int] = None
     ) -> torch.Tensor:
         matrix = torch.ones(block_size, block_size).tril()
@@ -190,9 +196,10 @@ class LitModule(pl.LightningModule):
 
     def training_step(self, batch: Dict[str, torch.Tensor], batch_idx):
         batch_size, block_size = batch['input_ids'].shape
-        batch['attention_mask'] = self.get_tril_matrix(
-            block_size, batch_size=batch_size
-        ).to(self.device)
+        if self.use_tril_attention_mask:
+            batch['attention_mask'] = self.get_batch_tril_matrix(
+                block_size, batch_size=batch_size
+            ).to(self.device)
         outputs = self.llm(**batch, return_dict=True)
         loss = outputs.loss
 
@@ -244,7 +251,7 @@ if __name__ == '__main__':
     set_seed(args.seed)
 
     # lightning module
-    lit_module = LitModule(args.model_name)
+    lit_module = LitModule(args.model_name, args.use_tril_attention_mask)
 
     # datasets
     tokenizer = load_tokenizer(args.tokenizer_name_or_path)
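
Note on the masks this flag gates (a minimal illustrative sketch, not part of
the patch): train.py builds a plain lower-triangular (causal) mask, while
generate.py multiplies the tokenizer's padding mask by its transpose before
taking tril, so pad positions stay masked out. The snippet below mirrors both
constructions; the example attention_mask batch is made up.

    import torch

    # train.py helper: ones on and below the diagonal, shape
    # (block_size, block_size), optionally repeated once per batch element.
    def get_batch_tril_matrix(block_size, batch_size=None):
        matrix = torch.ones(block_size, block_size).tril()
        if batch_size is not None:
            matrix = matrix.repeat(batch_size, 1, 1)  # -> (B, n, n)
        return matrix

    # generate.py construction: mask[b, i, j] == 1 iff tokens i and j are
    # both non-pad and j <= i (causal).
    attention_mask = torch.tensor([[0, 1, 1], [1, 1, 1]])  # hypothetical left-padded batch
    mask = (attention_mask.unsqueeze(1) * attention_mask.unsqueeze(2)).tril()
    print(mask[0])
    # tensor([[0, 0, 0],
    #         [0, 1, 0],
    #         [0, 1, 1]])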