[code] formatter-caused changes

This commit is contained in:
yiqing-zhou 2023-05-28 20:02:56 +08:00
parent 10a88a5012
commit b76d333f39
7 changed files with 30 additions and 92 deletions

View File

@ -8,9 +8,7 @@ class _LazyAutoMapping(auto_factory._LazyAutoMapping):
def _load_attr_from_module(self, model_type, attr): def _load_attr_from_module(self, model_type, attr):
module_name = auto_factory.model_type_to_module_name(model_type) module_name = auto_factory.model_type_to_module_name(model_type)
if module_name not in self._modules: if module_name not in self._modules:
self._modules[module_name] = importlib.import_module( self._modules[module_name] = importlib.import_module(f".{module_name}", "custom_models")
f".{module_name}", "custom_models"
)
return auto_factory.getattribute_from_module(self._modules[module_name], attr) return auto_factory.getattribute_from_module(self._modules[module_name], attr)
@ -28,9 +26,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
) )
MODEL_MAPPING = _LazyAutoMapping( MODEL_MAPPING = _LazyAutoMapping(configuration_auto.CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES)
configuration_auto.CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES
)
MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping( MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(

View File

@ -43,25 +43,15 @@ class GPT2Model(transformers.models.gpt2.GPT2Model):
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
output_attentions = ( output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = ( output_hidden_states = (
output_hidden_states output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
) )
use_cache = use_cache if use_cache is not None else self.config.use_cache use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = ( return_dict = return_dict if return_dict is not None else self.config.use_return_dict
return_dict if return_dict is not None else self.config.use_return_dict
)
if input_ids is not None and inputs_embeds is not None: if input_ids is not None and inputs_embeds is not None:
raise ValueError( raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
"You cannot specify both input_ids and inputs_embeds at the same time"
)
elif input_ids is not None: elif input_ids is not None:
input_shape = input_ids.size() input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1]) input_ids = input_ids.view(-1, input_shape[-1])
@ -107,9 +97,7 @@ class GPT2Model(transformers.models.gpt2.GPT2Model):
elif attention_mask.dim() == 3: elif attention_mask.dim() == 3:
attention_mask = attention_mask[:, None, ...] attention_mask = attention_mask[:, None, ...]
else: else:
raise ValueError( raise ValueError(f"attention_mask.dim() is {attention_mask.dim()}, should be 2 or 3")
f"attention_mask.dim() is {attention_mask.dim()}, should be 2 or 3"
)
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for # masked positions, this operation will create a tensor which is 0.0 for
@ -162,9 +150,7 @@ class GPT2Model(transformers.models.gpt2.GPT2Model):
presents = () if use_cache else None presents = () if use_cache else None
all_self_attentions = () if output_attentions else None all_self_attentions = () if output_attentions else None
all_cross_attentions = ( all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
() if output_attentions and self.config.add_cross_attention else None
)
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
# Model parallel # Model parallel
@ -172,9 +158,7 @@ class GPT2Model(transformers.models.gpt2.GPT2Model):
torch.cuda.set_device(hidden_states.device) torch.cuda.set_device(hidden_states.device)
# Ensure layer_past is on same device as hidden_states (might not be correct) # Ensure layer_past is on same device as hidden_states (might not be correct)
if layer_past is not None: if layer_past is not None:
layer_past = tuple( layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
past_state.to(hidden_states.device) for past_state in layer_past
)
# Ensure that attention_mask is always on the same device as hidden_states # Ensure that attention_mask is always on the same device as hidden_states
if attention_mask is not None: if attention_mask is not None:
attention_mask = attention_mask.to(hidden_states.device) attention_mask = attention_mask.to(hidden_states.device)
@ -218,13 +202,9 @@ class GPT2Model(transformers.models.gpt2.GPT2Model):
presents = presents + (outputs[1],) presents = presents + (outputs[1],)
if output_attentions: if output_attentions:
all_self_attentions = all_self_attentions + ( all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
outputs[2 if use_cache else 1],
)
if self.config.add_cross_attention: if self.config.add_cross_attention:
all_cross_attentions = all_cross_attentions + ( all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
outputs[3 if use_cache else 2],
)
# Model Parallel: If it's the last layer for that device, put things on the next device # Model Parallel: If it's the last layer for that device, put things on the next device
if self.model_parallel: if self.model_parallel:
@ -274,9 +254,7 @@ class GPT2LMHeadModel(transformers.models.gpt2.GPT2LMHeadModel):
# Initialize weights and apply final processing # Initialize weights and apply final processing
self.post_init() self.post_init()
def prepare_inputs_for_generation( def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
):
token_type_ids = kwargs.get("token_type_ids", None) token_type_ids = kwargs.get("token_type_ids", None)
# only last token for inputs_ids if past is defined in kwargs # only last token for inputs_ids if past is defined in kwargs
if past_key_values: if past_key_values:
@ -326,9 +304,7 @@ class GPT2LMHeadModel(transformers.models.gpt2.GPT2LMHeadModel):
# update token_type_ids with last value # update token_type_ids with last value
if "token_type_ids" in model_kwargs: if "token_type_ids" in model_kwargs:
token_type_ids = model_kwargs["token_type_ids"] token_type_ids = model_kwargs["token_type_ids"]
model_kwargs["token_type_ids"] = torch.cat( model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)
[token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1
)
# update position_ids # update position_ids
if "position_ids" in model_kwargs: if "position_ids" in model_kwargs:
@ -363,9 +339,7 @@ class GPT2LMHeadModel(transformers.models.gpt2.GPT2LMHeadModel):
) )
model_kwargs["attention_mask"] = attention_mask model_kwargs["attention_mask"] = attention_mask
else: else:
raise ValueError( raise ValueError(f"attention_mask.dim() is {attention_mask.dim()}, should be 2 or 3")
f"attention_mask.dim() is {attention_mask.dim()}, should be 2 or 3"
)
else: else:
# update decoder attention mask # update decoder attention mask
if "decoder_attention_mask" in model_kwargs: if "decoder_attention_mask" in model_kwargs:
@ -373,9 +347,7 @@ class GPT2LMHeadModel(transformers.models.gpt2.GPT2LMHeadModel):
model_kwargs["decoder_attention_mask"] = torch.cat( model_kwargs["decoder_attention_mask"] = torch.cat(
[ [
decoder_attention_mask, decoder_attention_mask,
decoder_attention_mask.new_ones( decoder_attention_mask.new_ones((decoder_attention_mask.shape[0], 1)),
(decoder_attention_mask.shape[0], 1)
),
], ],
dim=-1, dim=-1,
) )

View File

@ -13,15 +13,11 @@ def eval_prompts(
prompts: List[str], prompts: List[str],
use_tril_attention_mask: bool = False, use_tril_attention_mask: bool = False,
) -> List[str]: ) -> List[str]:
inputs = tokenizer( inputs = tokenizer(prompts, padding=True, return_tensors='pt', return_attention_mask=True)
prompts, padding=True, return_tensors='pt', return_attention_mask=True
)
inputs['position_ids'] = inputs.attention_mask.cumsum(-1) - 1 inputs['position_ids'] = inputs.attention_mask.cumsum(-1) - 1
inputs['position_ids'].masked_fill_(inputs.attention_mask == 0, 1) inputs['position_ids'].masked_fill_(inputs.attention_mask == 0, 1)
if use_tril_attention_mask: if use_tril_attention_mask:
inputs['attention_mask'] = ( inputs['attention_mask'] = (inputs.attention_mask.unsqueeze(1) * inputs.attention_mask.unsqueeze(2)).tril()
inputs.attention_mask.unsqueeze(1) * inputs.attention_mask.unsqueeze(2)
).tril()
inputs = inputs.to(model.device) inputs = inputs.to(model.device)
with torch.inference_mode(): with torch.inference_mode():
output_ids = model.generate( output_ids = model.generate(
@ -32,9 +28,7 @@ def eval_prompts(
eos_token_id=tokenizer.eos_token_id, eos_token_id=tokenizer.eos_token_id,
early_stopping=True, early_stopping=True,
) )
completes = tokenizer.batch_decode( completes = tokenizer.batch_decode(output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
return completes return completes
@ -81,9 +75,7 @@ if __name__ == '__main__':
"这是一个最好的时代,这是一个最坏的时代。", "这是一个最好的时代,这是一个最坏的时代。",
"这是一个最好的时代,这是一个最坏的", "这是一个最好的时代,这是一个最坏的",
] ]
completes = eval_prompts( completes = eval_prompts(model, tokenizer, prompts, use_tril_attention_mask=args.use_tril_attention_mask)
model, tokenizer, prompts, use_tril_attention_mask=args.use_tril_attention_mask
)
for prompt, complete in zip(prompts, completes): for prompt, complete in zip(prompts, completes):
print("[p]", prompt) print("[p]", prompt)

View File

@ -26,8 +26,6 @@ if __name__ == '__main__':
checkpoint_file_path = next(lightning_logs_dir_path.glob("checkpoints/*.ckpt")) checkpoint_file_path = next(lightning_logs_dir_path.glob("checkpoints/*.ckpt"))
lit_module = LitModule.load_from_checkpoint( lit_module = LitModule.load_from_checkpoint(checkpoint_file_path, map_location='cpu')
checkpoint_file_path, map_location='cpu'
)
model: PreTrainedModel = lit_module.__core_module__ model: PreTrainedModel = lit_module.__core_module__
model.save_pretrained(exports_dir_path) model.save_pretrained(exports_dir_path)

View File

@ -27,9 +27,7 @@ class LitModule(pl.LightningModule):
) )
@cache @cache
def get_batch_tril_matrix( def get_batch_tril_matrix(self, block_size: int, batch_size: Optional[int] = None) -> torch.Tensor:
self, block_size: int, batch_size: Optional[int] = None
) -> torch.Tensor:
matrix = torch.ones(block_size, block_size).tril() matrix = torch.ones(block_size, block_size).tril()
if batch_size is not None: if batch_size is not None:
matrix = matrix.repeat(batch_size, 1, 1) matrix = matrix.repeat(batch_size, 1, 1)
@ -42,9 +40,7 @@ class LitModule(pl.LightningModule):
def training_step(self, batch: Dict[str, torch.Tensor], batch_idx): def training_step(self, batch: Dict[str, torch.Tensor], batch_idx):
batch_size, block_size = batch['input_ids'].shape batch_size, block_size = batch['input_ids'].shape
if self.use_tril_attention_mask: if self.use_tril_attention_mask:
batch['attention_mask'] = self.get_batch_tril_matrix( batch['attention_mask'] = self.get_batch_tril_matrix(block_size, batch_size=batch_size).to(self.device)
block_size, batch_size=batch_size
).to(self.device)
outputs = self.llm(**batch, return_dict=True) outputs = self.llm(**batch, return_dict=True)
loss = outputs.loss loss = outputs.loss
@ -80,9 +76,7 @@ class LitModule(pl.LightningModule):
self.trainer.model.parameters(), lr=self.learning_rate self.trainer.model.parameters(), lr=self.learning_rate
) )
return optimizer return optimizer
optimizer = torch.optim.AdamW( optimizer = torch.optim.AdamW(self.trainer.model.parameters(), lr=self.learning_rate)
self.trainer.model.parameters(), lr=self.learning_rate
)
return optimizer return optimizer
def configure_callbacks(self): def configure_callbacks(self):

View File

@ -25,16 +25,12 @@ def split_raw_dataset(
if 'validation' in raw_dataset: if 'validation' in raw_dataset:
train_dataset, val_dataset = raw_dataset['train'], raw_dataset['validation'] train_dataset, val_dataset = raw_dataset['train'], raw_dataset['validation']
else: else:
raw_dataset = raw_dataset['train'].train_test_split( raw_dataset = raw_dataset['train'].train_test_split(test_size=0.05, seed=args.seed)
test_size=0.05, seed=args.seed
)
train_dataset, val_dataset = raw_dataset['train'], raw_dataset['test'] train_dataset, val_dataset = raw_dataset['train'], raw_dataset['test']
return train_dataset, val_dataset return train_dataset, val_dataset
def process_dataset( def process_dataset(dataset: datasets.Dataset, tokenizer: PreTrainedTokenizer) -> datasets.Dataset:
dataset: datasets.Dataset, tokenizer: PreTrainedTokenizer
) -> datasets.Dataset:
def group_texts(examples: Dict[str, list], block_size: int = 512) -> BatchEncoding: def group_texts(examples: Dict[str, list], block_size: int = 512) -> BatchEncoding:
concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
total_length = len(concatenated_examples[list(examples.keys())[0]]) total_length = len(concatenated_examples[list(examples.keys())[0]])
@ -167,9 +163,7 @@ if __name__ == '__main__':
set_seed(args.seed) set_seed(args.seed)
# lightning module # lightning module
lit_module = LitModule( lit_module = LitModule(args.model_name, args.learning_rate, args.use_tril_attention_mask)
args.model_name, args.learning_rate, args.use_tril_attention_mask
)
# datasets # datasets
tokenizer = load_tokenizer(args.tokenizer_name_or_path) tokenizer = load_tokenizer(args.tokenizer_name_or_path)

View File

@ -35,22 +35,14 @@ def load_model(model_name_or_path: Union[str, os.PathLike]) -> PreTrainedModel:
model = custom_models.AutoModel.from_pretrained(model_name_or_path) model = custom_models.AutoModel.from_pretrained(model_name_or_path)
else: else:
try: try:
model = AutoModelForCausalLM.from_pretrained( model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True)
model_name_or_path, trust_remote_code=True
)
except ValueError: except ValueError:
model = AutoModel.from_pretrained( model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True)
model_name_or_path, trust_remote_code=True
)
return model return model
def load_tokenizer( def load_tokenizer(tokenizer_name_or_path: Union[str, os.PathLike]) -> PreTrainedTokenizer:
tokenizer_name_or_path: Union[str, os.PathLike] tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, padding_side='left', trust_remote_code=True)
) -> PreTrainedTokenizer:
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_name_or_path, padding_side='left', trust_remote_code=True
)
if tokenizer.pad_token_id is None: if tokenizer.pad_token_id is None:
tokenizer.pad_token = tokenizer.eos_token tokenizer.pad_token = tokenizer.eos_token
return tokenizer return tokenizer