[feature] new arg use_tril_attention_mask

parent 0324eb4103
commit 939be31c10

generate.py (32 changed lines)
@@ -32,16 +32,20 @@ def load_tokenizer(model_name_or_path: Union[str, os.PathLike]) -> PreTrainedTok
 def eval_prompts(
-    model: PreTrainedModel, tokenizer: PreTrainedTokenizer, prompts: List[str]
+    model: PreTrainedModel,
+    tokenizer: PreTrainedTokenizer,
+    prompts: List[str],
+    use_tril_attention_mask: bool = False,
 ) -> List[str]:
     inputs = tokenizer(
         prompts, padding=True, return_tensors='pt', return_attention_mask=True
     )
     inputs['position_ids'] = inputs.attention_mask.cumsum(-1) - 1
     inputs['position_ids'].masked_fill_(inputs.attention_mask == 0, 1)
-    inputs['attention_mask'] = (
-        inputs.attention_mask.unsqueeze(1) * inputs.attention_mask.unsqueeze(2)
-    ).tril()
+    if use_tril_attention_mask:
+        inputs['attention_mask'] = (
+            inputs.attention_mask.unsqueeze(1) * inputs.attention_mask.unsqueeze(2)
+        ).tril()
     inputs = inputs.to(model.device)
     with torch.inference_mode():
         output_ids = model.generate(
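Note: with the flag set, eval_prompts expands the usual 1-D padding mask into a per-sample 2-D lower-triangular mask before calling model.generate(). A minimal standalone sketch of what that expression produces for one left-padded prompt; the input values are made up for illustration:

import torch

# one prompt of length 4: one pad token followed by three real tokens
attention_mask = torch.tensor([[0, 1, 1, 1]])

# outer product gives shape (batch, seq, seq); tril() then zeroes out future positions
mask_2d = (attention_mask.unsqueeze(1) * attention_mask.unsqueeze(2)).tril()
print(mask_2d)
# tensor([[[0, 0, 0, 0],
#          [0, 1, 0, 0],
#          [0, 1, 1, 0],
#          [0, 1, 1, 1]]])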
@@ -66,6 +70,17 @@ def parse_args():
         help="Name of or path to model",
         default='gpt2',
     )
+    parser.add_argument(
+        "--use_tril_attention_mask",
+        help="Use tril attention mask during training",
+        action="store_true",
+    )
+    parser.add_argument(
+        "--tokenizer_name_or_path",
+        type=str,
+        help="Name of or path to tokenizer",
+        default=None,
+    )
     args = parser.parse_args()
     return args
 
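For reference, a hypothetical invocation of generate.py that exercises both new arguments (the tokenizer name here is only a placeholder):

python generate.py --tokenizer_name_or_path gpt2 --use_tril_attention_mask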
@@ -73,10 +88,13 @@ def parse_args():
 if __name__ == '__main__':
     args = parse_args()
 
+    if args.tokenizer_name_or_path is None:
+        args.tokenizer_name_or_path = args.model_name_or_path
+
     device = torch.device(0)
 
     model = load_model(args.model_name_or_path)
-    tokenizer = load_tokenizer(args.model_name_or_path)
+    tokenizer = load_tokenizer(args.tokenizer_name_or_path)
 
     model = model.to(device)
     prompts = [
@@ -87,7 +105,9 @@ if __name__ == '__main__':
         "这是一个最好的时代,这是一个最坏的时代。",
         "这是一个最好的时代,这是一个最坏的",
     ]
-    completes = eval_prompts(model, tokenizer, prompts)
+    completes = eval_prompts(
+        model, tokenizer, prompts, use_tril_attention_mask=args.use_tril_attention_mask
+    )
 
     for prompt, complete in zip(prompts, completes):
         print("[p]", prompt)
train.py (19 changed lines)

@@ -103,6 +103,11 @@ def parse_args():
         help="Name of or path to model",
         default='gpt2',
     )
+    parser.add_argument(
+        "--use_tril_attention_mask",
+        help="Use tril attention mask during training",
+        action="store_true",
+    )
     parser.add_argument("--fp16", help="Enable fp16", action="store_true")
     parser.add_argument("--bf16", help="Enable bf16", action="store_true")
     parser.add_argument(
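Likewise for training, a hypothetical run that turns the new behaviour on (all other arguments left at their defaults):

python train.py --use_tril_attention_mask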
@@ -165,10 +170,11 @@ def parse_args():
 
 
 class LitModule(pl.LightningModule):
-    def __init__(self, model_name: str):
+    def __init__(self, model_name: str, use_tril_attention_mask: bool = False):
         super().__init__()
         self.save_hyperparameters()
         self.llm = self.register_core_module(init_model(model_name))
+        self.use_tril_attention_mask = use_tril_attention_mask
         self.metric_loss = torchmetrics.MeanMetric()
         self.metric_accuracy = torchmetrics.Accuracy(
             task='multiclass',
@@ -176,7 +182,7 @@ class LitModule(pl.LightningModule):
         )
 
     @cache
-    def get_tril_matrix(
+    def get_batch_tril_matrix(
         self, block_size: int, batch_size: Optional[int] = None
     ) -> torch.Tensor:
         matrix = torch.ones(block_size, block_size).tril()
@@ -190,9 +196,10 @@ class LitModule(pl.LightningModule):
 
     def training_step(self, batch: Dict[str, torch.Tensor], batch_idx):
         batch_size, block_size = batch['input_ids'].shape
-        batch['attention_mask'] = self.get_tril_matrix(
-            block_size, batch_size=batch_size
-        ).to(self.device)
+        if self.use_tril_attention_mask:
+            batch['attention_mask'] = self.get_batch_tril_matrix(
+                block_size, batch_size=batch_size
+            ).to(self.device)
         outputs = self.llm(**batch, return_dict=True)
         loss = outputs.loss
 
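The training path mirrors the change in generate.py: when the flag is set, training_step overwrites the collator's padding mask with a cached lower-triangular mask expanded to the batch. Only the first line of get_batch_tril_matrix's body (the torch.ones(...).tril() call) is visible in the diff, so the sketch below is an assumption about how such a helper might finish the job by repeating the tril matrix across the batch, not the exact body from train.py:

import torch
from functools import cache

@cache
def batch_tril_matrix(block_size: int, batch_size: int) -> torch.Tensor:
    # (block_size, block_size) causal mask, repeated once per sample in the batch
    matrix = torch.ones(block_size, block_size).tril()
    return matrix.unsqueeze(0).expand(batch_size, -1, -1)

print(batch_tril_matrix(4, batch_size=2).shape)  # torch.Size([2, 4, 4])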
@@ -244,7 +251,7 @@ if __name__ == '__main__':
     set_seed(args.seed)
 
     # lightning module
-    lit_module = LitModule(args.model_name)
+    lit_module = LitModule(args.model_name, args.use_tril_attention_mask)
 
     # datasets
     tokenizer = load_tokenizer(args.tokenizer_name_or_path)