diff --git a/custom_models/__init__.py b/custom_models/__init__.py index cb5177e..8401ed2 100644 --- a/custom_models/__init__.py +++ b/custom_models/__init__.py @@ -3,14 +3,22 @@ from collections import OrderedDict from transformers.models.auto import auto_factory, configuration_auto +CONFIG_MAPPING_NAMES = OrderedDict([]) + + +def register_custom_configs(): + for model_type, map_name in CONFIG_MAPPING_NAMES.items(): + module_name = configuration_auto.model_type_to_module_name(model_type) + module = importlib.import_module(f".{module_name}", "custom_models") + mapping = getattr(module, map_name) + configuration_auto.AutoConfig.register(model_type, mapping) + class _LazyAutoMapping(auto_factory._LazyAutoMapping): def _load_attr_from_module(self, model_type, attr): module_name = auto_factory.model_type_to_module_name(model_type) if module_name not in self._modules: - self._modules[module_name] = importlib.import_module( - f".{module_name}", "custom_models" - ) + self._modules[module_name] = importlib.import_module(f".{module_name}", "custom_models") return auto_factory.getattribute_from_module(self._modules[module_name], attr) @@ -29,12 +37,12 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict( MODEL_MAPPING = _LazyAutoMapping( - configuration_auto.CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES + {**CONFIG_MAPPING_NAMES, **configuration_auto.CONFIG_MAPPING_NAMES}, MODEL_MAPPING_NAMES ) MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping( - configuration_auto.CONFIG_MAPPING_NAMES, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES + {**CONFIG_MAPPING_NAMES, **configuration_auto.CONFIG_MAPPING_NAMES}, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES ) diff --git a/custom_models/gpt2/modeling_gpt2.py b/custom_models/gpt2/modeling_gpt2.py index 98a5d4e..607d590 100644 --- a/custom_models/gpt2/modeling_gpt2.py +++ b/custom_models/gpt2/modeling_gpt2.py @@ -43,25 +43,15 @@ class GPT2Model(transformers.models.gpt2.GPT2Model): output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict if input_ids is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both input_ids and inputs_embeds at the same time" - ) + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is not None: input_shape = input_ids.size() input_ids = input_ids.view(-1, input_shape[-1]) @@ -107,9 +97,7 @@ class GPT2Model(transformers.models.gpt2.GPT2Model): elif attention_mask.dim() == 3: attention_mask = attention_mask[:, None, ...] else: - raise ValueError( - f"attention_mask.dim() is {attention_mask.dim()}, should be 2 or 3" - ) + raise ValueError(f"attention_mask.dim() is {attention_mask.dim()}, should be 2 or 3") # Since attention_mask is 1.0 for positions we want to attend and 0.0 for # masked positions, this operation will create a tensor which is 0.0 for @@ -162,9 +150,7 @@ class GPT2Model(transformers.models.gpt2.GPT2Model): presents = () if use_cache else None all_self_attentions = () if output_attentions else None - all_cross_attentions = ( - () if output_attentions and self.config.add_cross_attention else None - ) + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None all_hidden_states = () if output_hidden_states else None for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): # Model parallel @@ -172,9 +158,7 @@ class GPT2Model(transformers.models.gpt2.GPT2Model): torch.cuda.set_device(hidden_states.device) # Ensure layer_past is on same device as hidden_states (might not be correct) if layer_past is not None: - layer_past = tuple( - past_state.to(hidden_states.device) for past_state in layer_past - ) + layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) # Ensure that attention_mask is always on the same device as hidden_states if attention_mask is not None: attention_mask = attention_mask.to(hidden_states.device) @@ -218,13 +202,9 @@ class GPT2Model(transformers.models.gpt2.GPT2Model): presents = presents + (outputs[1],) if output_attentions: - all_self_attentions = all_self_attentions + ( - outputs[2 if use_cache else 1], - ) + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) if self.config.add_cross_attention: - all_cross_attentions = all_cross_attentions + ( - outputs[3 if use_cache else 2], - ) + all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) # Model Parallel: If it's the last layer for that device, put things on the next device if self.model_parallel: @@ -274,9 +254,7 @@ class GPT2LMHeadModel(transformers.models.gpt2.GPT2LMHeadModel): # Initialize weights and apply final processing self.post_init() - def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs - ): + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): token_type_ids = kwargs.get("token_type_ids", None) # only last token for inputs_ids if past is defined in kwargs if past_key_values: @@ -326,9 +304,7 @@ class GPT2LMHeadModel(transformers.models.gpt2.GPT2LMHeadModel): # update token_type_ids with last value if "token_type_ids" in model_kwargs: token_type_ids = model_kwargs["token_type_ids"] - model_kwargs["token_type_ids"] = torch.cat( - [token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1 - ) + model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1) # update position_ids if "position_ids" in model_kwargs: @@ -363,9 +339,7 @@ class GPT2LMHeadModel(transformers.models.gpt2.GPT2LMHeadModel): ) model_kwargs["attention_mask"] = attention_mask else: - raise ValueError( - f"attention_mask.dim() is {attention_mask.dim()}, should be 2 or 3" - ) + raise ValueError(f"attention_mask.dim() is {attention_mask.dim()}, should be 2 or 3") else: # update decoder attention mask if "decoder_attention_mask" in model_kwargs: @@ -373,9 +347,7 @@ class GPT2LMHeadModel(transformers.models.gpt2.GPT2LMHeadModel): model_kwargs["decoder_attention_mask"] = torch.cat( [ decoder_attention_mask, - decoder_attention_mask.new_ones( - (decoder_attention_mask.shape[0], 1) - ), + decoder_attention_mask.new_ones((decoder_attention_mask.shape[0], 1)), ], dim=-1, ) diff --git a/generate.py b/generate.py index 378214e..b9c6da0 100644 --- a/generate.py +++ b/generate.py @@ -13,15 +13,11 @@ def eval_prompts( prompts: List[str], use_tril_attention_mask: bool = False, ) -> List[str]: - inputs = tokenizer( - prompts, padding=True, return_tensors='pt', return_attention_mask=True - ) + inputs = tokenizer(prompts, padding=True, return_tensors='pt', return_attention_mask=True) inputs['position_ids'] = inputs.attention_mask.cumsum(-1) - 1 inputs['position_ids'].masked_fill_(inputs.attention_mask == 0, 1) if use_tril_attention_mask: - inputs['attention_mask'] = ( - inputs.attention_mask.unsqueeze(1) * inputs.attention_mask.unsqueeze(2) - ).tril() + inputs['attention_mask'] = (inputs.attention_mask.unsqueeze(1) * inputs.attention_mask.unsqueeze(2)).tril() inputs = inputs.to(model.device) with torch.inference_mode(): output_ids = model.generate( @@ -32,9 +28,7 @@ def eval_prompts( eos_token_id=tokenizer.eos_token_id, early_stopping=True, ) - completes = tokenizer.batch_decode( - output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False - ) + completes = tokenizer.batch_decode(output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False) return completes @@ -81,9 +75,7 @@ if __name__ == '__main__': "这是一个最好的时代,这是一个最坏的时代。", "这是一个最好的时代,这是一个最坏的", ] - completes = eval_prompts( - model, tokenizer, prompts, use_tril_attention_mask=args.use_tril_attention_mask - ) + completes = eval_prompts(model, tokenizer, prompts, use_tril_attention_mask=args.use_tril_attention_mask) for prompt, complete in zip(prompts, completes): print("[p]", prompt) diff --git a/lit_export.py b/lit_export.py index 07cac08..25f760d 100644 --- a/lit_export.py +++ b/lit_export.py @@ -26,8 +26,6 @@ if __name__ == '__main__': checkpoint_file_path = next(lightning_logs_dir_path.glob("checkpoints/*.ckpt")) - lit_module = LitModule.load_from_checkpoint( - checkpoint_file_path, map_location='cpu' - ) + lit_module = LitModule.load_from_checkpoint(checkpoint_file_path, map_location='cpu') model: PreTrainedModel = lit_module.__core_module__ model.save_pretrained(exports_dir_path) diff --git a/lit_module.py b/lit_module.py index c1c5393..50f0e73 100644 --- a/lit_module.py +++ b/lit_module.py @@ -27,9 +27,7 @@ class LitModule(pl.LightningModule): ) @cache - def get_batch_tril_matrix( - self, block_size: int, batch_size: Optional[int] = None - ) -> torch.Tensor: + def get_batch_tril_matrix(self, block_size: int, batch_size: Optional[int] = None) -> torch.Tensor: matrix = torch.ones(block_size, block_size).tril() if batch_size is not None: matrix = matrix.repeat(batch_size, 1, 1) @@ -42,9 +40,7 @@ class LitModule(pl.LightningModule): def training_step(self, batch: Dict[str, torch.Tensor], batch_idx): batch_size, block_size = batch['input_ids'].shape if self.use_tril_attention_mask: - batch['attention_mask'] = self.get_batch_tril_matrix( - block_size, batch_size=batch_size - ).to(self.device) + batch['attention_mask'] = self.get_batch_tril_matrix(block_size, batch_size=batch_size).to(self.device) outputs = self.llm(**batch, return_dict=True) loss = outputs.loss @@ -80,9 +76,7 @@ class LitModule(pl.LightningModule): self.trainer.model.parameters(), lr=self.learning_rate ) return optimizer - optimizer = torch.optim.AdamW( - self.trainer.model.parameters(), lr=self.learning_rate - ) + optimizer = torch.optim.AdamW(self.trainer.model.parameters(), lr=self.learning_rate) return optimizer def configure_callbacks(self): diff --git a/lit_train.py b/lit_train.py index 560a2e7..9a5968c 100644 --- a/lit_train.py +++ b/lit_train.py @@ -25,16 +25,12 @@ def split_raw_dataset( if 'validation' in raw_dataset: train_dataset, val_dataset = raw_dataset['train'], raw_dataset['validation'] else: - raw_dataset = raw_dataset['train'].train_test_split( - test_size=0.05, seed=args.seed - ) + raw_dataset = raw_dataset['train'].train_test_split(test_size=0.05, seed=args.seed) train_dataset, val_dataset = raw_dataset['train'], raw_dataset['test'] return train_dataset, val_dataset -def process_dataset( - dataset: datasets.Dataset, tokenizer: PreTrainedTokenizer -) -> datasets.Dataset: +def process_dataset(dataset: datasets.Dataset, tokenizer: PreTrainedTokenizer) -> datasets.Dataset: def group_texts(examples: Dict[str, list], block_size: int = 512) -> BatchEncoding: concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} total_length = len(concatenated_examples[list(examples.keys())[0]]) @@ -167,9 +163,7 @@ if __name__ == '__main__': set_seed(args.seed) # lightning module - lit_module = LitModule( - args.model_name, args.learning_rate, args.use_tril_attention_mask - ) + lit_module = LitModule(args.model_name, args.learning_rate, args.use_tril_attention_mask) # datasets tokenizer = load_tokenizer(args.tokenizer_name_or_path) diff --git a/utils.py b/utils.py index 5e26fcd..65d3acc 100644 --- a/utils.py +++ b/utils.py @@ -12,9 +12,11 @@ from transformers import ( import custom_models +custom_models.register_custom_configs() -def init_model(model_name: Union[str, os.PathLike]) -> PreTrainedModel: - config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) + +def init_model(model_name: str) -> PreTrainedModel: + config = AutoConfig.for_model(model_type=model_name) if model_name in custom_models.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: model = custom_models.AutoModelForCausalLM.from_config(config) @@ -35,22 +37,14 @@ def load_model(model_name_or_path: Union[str, os.PathLike]) -> PreTrainedModel: model = custom_models.AutoModel.from_pretrained(model_name_or_path) else: try: - model = AutoModelForCausalLM.from_pretrained( - model_name_or_path, trust_remote_code=True - ) + model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True) except ValueError: - model = AutoModel.from_pretrained( - model_name_or_path, trust_remote_code=True - ) + model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True) return model -def load_tokenizer( - tokenizer_name_or_path: Union[str, os.PathLike] -) -> PreTrainedTokenizer: - tokenizer = AutoTokenizer.from_pretrained( - tokenizer_name_or_path, padding_side='left', trust_remote_code=True - ) +def load_tokenizer(tokenizer_name_or_path: Union[str, os.PathLike]) -> PreTrainedTokenizer: + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, padding_side='left', trust_remote_code=True) if tokenizer.pad_token_id is None: tokenizer.pad_token = tokenizer.eos_token return tokenizer