[feature] new arg use_tril_attention_mask

Yiqing-Zhou 2023-05-06 21:06:18 +08:00
parent 0324eb4103
commit 939be31c10
2 changed files with 39 additions and 12 deletions

View File

@@ -32,13 +32,17 @@ def load_tokenizer(model_name_or_path: Union[str, os.PathLike]) -> PreTrainedTok
 def eval_prompts(
-    model: PreTrainedModel, tokenizer: PreTrainedTokenizer, prompts: List[str]
+    model: PreTrainedModel,
+    tokenizer: PreTrainedTokenizer,
+    prompts: List[str],
+    use_tril_attention_mask: bool = False,
 ) -> List[str]:
     inputs = tokenizer(
         prompts, padding=True, return_tensors='pt', return_attention_mask=True
     )
     inputs['position_ids'] = inputs.attention_mask.cumsum(-1) - 1
     inputs['position_ids'].masked_fill_(inputs.attention_mask == 0, 1)
-    inputs['attention_mask'] = (
-        inputs.attention_mask.unsqueeze(1) * inputs.attention_mask.unsqueeze(2)
-    ).tril()
+    if use_tril_attention_mask:
+        inputs['attention_mask'] = (
+            inputs.attention_mask.unsqueeze(1) * inputs.attention_mask.unsqueeze(2)
+        ).tril()
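
For readers unfamiliar with the masking trick gated behind the new flag, here is a minimal, self-contained sketch of what that branch computes, run on a toy padded batch; it uses plain PyTorch only and none of the repository's helpers.

    import torch

    # Toy padding mask for a batch of two sequences (1 = real token, 0 = padding),
    # as produced by tokenizer(..., padding=True, return_attention_mask=True).
    attention_mask = torch.tensor([[1, 1, 1, 1],
                                   [0, 0, 1, 1]])

    # Position ids: running count of real tokens, shifted to start at 0;
    # padded positions are overwritten with a dummy value of 1.
    position_ids = attention_mask.cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 1)
    print(position_ids)
    # tensor([[0, 1, 2, 3],
    #         [1, 1, 0, 1]])

    # Tril attention mask: the outer product keeps only (query, key) pairs where
    # both tokens are real, and .tril() removes attention to future positions,
    # giving one (seq_len, seq_len) causal mask per sample.
    tril_mask = (attention_mask.unsqueeze(1) * attention_mask.unsqueeze(2)).tril()
    print(tril_mask.shape)  # torch.Size([2, 4, 4])
    print(tril_mask[1])
    # tensor([[0, 0, 0, 0],
    #         [0, 0, 0, 0],
    #         [0, 0, 1, 0],
    #         [0, 0, 1, 1]])
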
@@ -66,6 +70,17 @@ def parse_args():
         help="Name of or path to model",
         default='gpt2',
     )
+    parser.add_argument(
+        "--use_tril_attention_mask",
+        help="Use tril attention mask during training",
+        action="store_true",
+    )
+    parser.add_argument(
+        "--tokenizer_name_or_path",
+        type=str,
+        help="Name of or path to tokenizer",
+        default=None,
+    )
     args = parser.parse_args()
     return args
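
The flag is declared with action="store_true", so it defaults to False and flips to True only when passed on the command line. A quick standalone check of that behaviour (the parser below is a throwaway sketch mirroring just the two new options, not the repository's parse_args):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--use_tril_attention_mask",
                        help="Use tril attention mask during training",
                        action="store_true")
    parser.add_argument("--tokenizer_name_or_path", type=str,
                        help="Name of or path to tokenizer", default=None)

    args = parser.parse_args([])
    print(args.use_tril_attention_mask, args.tokenizer_name_or_path)  # False None

    args = parser.parse_args(["--use_tril_attention_mask", "--tokenizer_name_or_path", "gpt2"])
    print(args.use_tril_attention_mask, args.tokenizer_name_or_path)  # True gpt2
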
@@ -73,10 +88,13 @@ def parse_args():
 if __name__ == '__main__':
     args = parse_args()
+    if args.tokenizer_name_or_path is None:
+        args.tokenizer_name_or_path = args.model_name_or_path
     device = torch.device(0)
     model = load_model(args.model_name_or_path)
-    tokenizer = load_tokenizer(args.model_name_or_path)
+    tokenizer = load_tokenizer(args.tokenizer_name_or_path)
     model = model.to(device)
     prompts = [
@@ -87,7 +105,9 @@ if __name__ == '__main__':
         "这是一个最好的时代,这是一个最坏的时代。",
         "这是一个最好的时代,这是一个最坏的",
     ]
-    completes = eval_prompts(model, tokenizer, prompts)
+    completes = eval_prompts(
+        model, tokenizer, prompts, use_tril_attention_mask=args.use_tril_attention_mask
+    )
     for prompt, complete in zip(prompts, completes):
         print("[p]", prompt)

View File

@@ -103,6 +103,11 @@ def parse_args():
         help="Name of or path to model",
         default='gpt2',
     )
+    parser.add_argument(
+        "--use_tril_attention_mask",
+        help="Use tril attention mask during training",
+        action="store_true",
+    )
     parser.add_argument("--fp16", help="Enable fp16", action="store_true")
     parser.add_argument("--bf16", help="Enable bf16", action="store_true")
     parser.add_argument(
@@ -165,10 +170,11 @@ def parse_args():
 class LitModule(pl.LightningModule):
-    def __init__(self, model_name: str):
+    def __init__(self, model_name: str, use_tril_attention_mask: str = False):
         super().__init__()
         self.save_hyperparameters()
         self.llm = self.register_core_module(init_model(model_name))
+        self.use_tril_attention_mask = use_tril_attention_mask
         self.metric_loss = torchmetrics.MeanMetric()
         self.metric_accuracy = torchmetrics.Accuracy(
             task='multiclass',
@@ -176,7 +182,7 @@ class LitModule(pl.LightningModule):
         )
     @cache
-    def get_tril_matrix(
+    def get_batch_tril_matrix(
         self, block_size: int, batch_size: Optional[int] = None
     ) -> torch.Tensor:
         matrix = torch.ones(block_size, block_size).tril()
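
Only the first line of the renamed get_batch_tril_matrix body is visible in this hunk, so the batch handling in the sketch below is an assumption about what the unchanged remainder of the method does; it is written as a plain cached function rather than a LitModule method. Note that the matrix starts from torch.ones(), i.e. it is a shared causal mask that does not fold in any padding information.

    from functools import cache
    from typing import Optional

    import torch

    @cache
    def get_batch_tril_matrix(block_size: int, batch_size: Optional[int] = None) -> torch.Tensor:
        # Shared (block_size, block_size) lower-triangular causal mask.
        matrix = torch.ones(block_size, block_size).tril()
        if batch_size is not None:
            # Assumed batch expansion to (batch_size, block_size, block_size).
            matrix = matrix.unsqueeze(0).expand(batch_size, -1, -1)
        return matrix

    print(get_batch_tril_matrix(4, batch_size=2).shape)  # torch.Size([2, 4, 4])
    print(get_batch_tril_matrix(3))
    # tensor([[1., 0., 0.],
    #         [1., 1., 0.],
    #         [1., 1., 1.]])
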
@@ -190,7 +196,8 @@ class LitModule(pl.LightningModule):
     def training_step(self, batch: Dict[str, torch.Tensor], batch_idx):
         batch_size, block_size = batch['input_ids'].shape
-        batch['attention_mask'] = self.get_tril_matrix(
-            block_size, batch_size=batch_size
-        ).to(self.device)
+        if self.use_tril_attention_mask:
+            batch['attention_mask'] = self.get_batch_tril_matrix(
+                block_size, batch_size=batch_size
+            ).to(self.device)
         outputs = self.llm(**batch, return_dict=True)
@@ -244,7 +251,7 @@ if __name__ == '__main__':
     set_seed(args.seed)
     # lightning module
-    lit_module = LitModule(args.model_name)
+    lit_module = LitModule(args.model_name, args.use_tril_attention_mask)
     # datasets
     tokenizer = load_tokenizer(args.tokenizer_name_or_path)