diff --git a/custom_models/__init__.py b/custom_models/__init__.py
new file mode 100644
index 0000000..cb5177e
--- /dev/null
+++ b/custom_models/__init__.py
@@ -0,0 +1,52 @@
+import importlib
+from collections import OrderedDict
+
+from transformers.models.auto import auto_factory, configuration_auto
+
+
+class _LazyAutoMapping(auto_factory._LazyAutoMapping):
+    def _load_attr_from_module(self, model_type, attr):
+        module_name = auto_factory.model_type_to_module_name(model_type)
+        if module_name not in self._modules:
+            self._modules[module_name] = importlib.import_module(
+                f".{module_name}", "custom_models"
+            )
+        return auto_factory.getattribute_from_module(self._modules[module_name], attr)
+
+
+MODEL_MAPPING_NAMES = OrderedDict(
+    [
+        ("gpt2", "GPT2Model"),
+    ]
+)
+
+
+MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
+    [
+        ("gpt2", "GPT2LMHeadModel"),
+    ]
+)
+
+
+MODEL_MAPPING = _LazyAutoMapping(
+    configuration_auto.CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES
+)
+
+
+MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(
+    configuration_auto.CONFIG_MAPPING_NAMES, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
+)
+
+
+class AutoModel(auto_factory._BaseAutoModelClass):
+    _model_mapping = MODEL_MAPPING
+
+
+AutoModel = auto_factory.auto_class_update(AutoModel)
+
+
+class AutoModelForCausalLM(auto_factory._BaseAutoModelClass):
+    _model_mapping = MODEL_FOR_CAUSAL_LM_MAPPING
+
+
+AutoModelForCausalLM = auto_factory.auto_class_update(AutoModelForCausalLM)
diff --git a/custom_models/gpt2/__init__.py b/custom_models/gpt2/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/custom_models/gpt2/__init__.py
@@ -0,0 +1,3 @@
+# Re-export the overrides so that _LazyAutoMapping finds them on `custom_models.gpt2`
+# instead of falling back to the stock transformers classes.
+from .modeling_gpt2 import GPT2LMHeadModel, GPT2Model
diff --git a/custom_models/gpt2/modeling_gpt2.py b/custom_models/gpt2/modeling_gpt2.py
new file mode 100644
index 0000000..f3d7d00
--- /dev/null
+++ b/custom_models/gpt2/modeling_gpt2.py
@@ -0,0 +1,369 @@
+"""Override transformers GPT2 to support tril attention mask"""
+
+from typing import Any, Dict, Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+import transformers
+from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
+from transformers.models.gpt2.modeling_gpt2 import (
+    _CHECKPOINT_FOR_DOC,
+    _CONFIG_FOR_DOC,
+    GPT2_INPUTS_DOCSTRING,
+)
+from transformers.utils import (
+    ModelOutput,
+    add_code_sample_docstrings,
+    add_start_docstrings_to_model_forward,
+    logging,
+)
+
+logger = logging.get_logger(__name__)
+
+
+class GPT2Model(transformers.models.gpt2.GPT2Model):
+    @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=BaseModelOutputWithPastAndCrossAttentions,
+        config_class=_CONFIG_FOR_DOC,
+    )
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
+        output_attentions = (
+            output_attentions
+            if output_attentions is not None
+            else self.config.output_attentions
+        )
+        output_hidden_states = (
+            output_hidden_states
+            if output_hidden_states is not None
+            else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )
+
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError(
+                "You cannot specify both input_ids and inputs_embeds at the same time"
+            )
+        elif input_ids is not None:
+            input_shape = input_ids.size()
+            input_ids = input_ids.view(-1, input_shape[-1])
+            batch_size = input_ids.shape[0]
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+            batch_size = inputs_embeds.shape[0]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+        if token_type_ids is not None:
+            token_type_ids = token_type_ids.view(-1, input_shape[-1])
+        if position_ids is not None:
+            position_ids = position_ids.view(-1, input_shape[-1])
+
+        if past_key_values is None:
+            past_length = 0
+            past_key_values = tuple([None] * len(self.h))
+        else:
+            past_length = past_key_values[0][0].size(-2)
+        if position_ids is None:
+            position_ids = torch.arange(
+                past_length,
+                input_shape[-1] + past_length,
+                dtype=torch.long,
+                device=device,
+            )
+            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
+
+        # GPT2Attention mask.
+        if attention_mask is not None:
+            if batch_size <= 0:
+                raise ValueError("batch_size has to be defined and > 0")
+            if attention_mask.dim() == 2:
+                # We create a 3D attention mask from a 2D tensor mask.
+                # Sizes are [batch_size, 1, 1, to_seq_length]
+                # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
+                # this attention mask is more simple than the triangular masking of causal attention
+                # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
+                attention_mask = attention_mask[:, None, None, :]
+            elif attention_mask.dim() == 3:
+                attention_mask = attention_mask[:, None, ...]
+            else:
+                raise ValueError(
+                    f"attention_mask.dim() is {attention_mask.dim()}, should be 2 or 3"
+                )
+
+            # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+            # masked positions, this operation will create a tensor which is 0.0 for
+            # positions we want to attend and the dtype's smallest value for masked positions.
+            # Since we are adding it to the raw scores before the softmax, this is
+            # effectively the same as removing these entirely.
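+            # At this point the mask is [batch_size, 1, 1, key_length] (2D input) or
+            # [batch_size, 1, query_length, key_length] (3D input, e.g. a per-example
+            # lower-triangular mask); both broadcast against the attention scores inside each block.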
+            attention_mask = attention_mask.to(dtype=self.dtype)  # fp16 compatibility
+            attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
+
+        # If a 2D or 3D attention mask is provided for the cross-attention
+        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
+        if self.config.add_cross_attention and encoder_hidden_states is not None:
+            (
+                encoder_batch_size,
+                encoder_sequence_length,
+                _,
+            ) = encoder_hidden_states.size()
+            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+            if encoder_attention_mask is None:
+                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
+            encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+        else:
+            encoder_attention_mask = None
+
+        # Prepare head mask if needed
+        # 1.0 in head_mask indicate we keep the head
+        # attention_probs has shape bsz x n_heads x N x N
+        # head_mask has shape n_layer x batch x n_heads x N x N
+        head_mask = self.get_head_mask(head_mask, self.config.n_layer)
+
+        if inputs_embeds is None:
+            inputs_embeds = self.wte(input_ids)
+        position_embeds = self.wpe(position_ids)
+        hidden_states = inputs_embeds + position_embeds
+
+        if token_type_ids is not None:
+            token_type_embeds = self.wte(token_type_ids)
+            hidden_states = hidden_states + token_type_embeds
+
+        hidden_states = self.drop(hidden_states)
+
+        output_shape = input_shape + (hidden_states.size(-1),)
+
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
+        presents = () if use_cache else None
+        all_self_attentions = () if output_attentions else None
+        all_cross_attentions = (
+            () if output_attentions and self.config.add_cross_attention else None
+        )
+        all_hidden_states = () if output_hidden_states else None
+        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
+            # Model parallel
+            if self.model_parallel:
+                torch.cuda.set_device(hidden_states.device)
+                # Ensure layer_past is on same device as hidden_states (might not be correct)
+                if layer_past is not None:
+                    layer_past = tuple(
+                        past_state.to(hidden_states.device) for past_state in layer_past
+                    )
+                # Ensure that attention_mask is always on the same device as hidden_states
+                if attention_mask is not None:
+                    attention_mask = attention_mask.to(hidden_states.device)
+                if isinstance(head_mask, torch.Tensor):
+                    head_mask = head_mask.to(hidden_states.device)
+            if output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_states,)
+
+            if self.gradient_checkpointing and self.training:
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        # None for past_key_value
+                        return module(*inputs, use_cache, output_attentions)
+
+                    return custom_forward
+
+                outputs = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    hidden_states,
+                    None,
+                    attention_mask,
+                    head_mask[i],
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                )
+            else:
+                outputs = block(
+                    hidden_states,
+                    layer_past=layer_past,
+                    attention_mask=attention_mask,
+                    head_mask=head_mask[i],
+                    encoder_hidden_states=encoder_hidden_states,
+                    encoder_attention_mask=encoder_attention_mask,
+                    use_cache=use_cache,
+                    output_attentions=output_attentions,
+                )
+
+            hidden_states = outputs[0]
+            if use_cache is True:
+                presents = presents + (outputs[1],)
+
+            if output_attentions:
+                all_self_attentions = all_self_attentions + (
+                    outputs[2 if use_cache else 1],
+                )
+                if self.config.add_cross_attention:
+                    all_cross_attentions = all_cross_attentions + (
+                        outputs[3 if use_cache else 2],
+                    )
+
+            # Model Parallel: If it's the last layer for that device, put things on the next device
+            if self.model_parallel:
+                for k, v in self.device_map.items():
+                    if i == v[-1] and "cuda:" + str(k) != self.last_device:
+                        hidden_states = hidden_states.to("cuda:" + str(k + 1))
+
+        hidden_states = self.ln_f(hidden_states)
+
+        hidden_states = hidden_states.view(output_shape)
+        # Add last hidden state
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states,)
+
+        if not return_dict:
+            return tuple(
+                v
+                for v in [
+                    hidden_states,
+                    presents,
+                    all_hidden_states,
+                    all_self_attentions,
+                    all_cross_attentions,
+                ]
+                if v is not None
+            )
+
+        return BaseModelOutputWithPastAndCrossAttentions(
+            last_hidden_state=hidden_states,
+            past_key_values=presents,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attentions,
+            cross_attentions=all_cross_attentions,
+        )
+
+
+class GPT2LMHeadModel(transformers.models.gpt2.GPT2LMHeadModel):
+    def prepare_inputs_for_generation(
+        self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
+    ):
+        token_type_ids = kwargs.get("token_type_ids", None)
+        # only last token for inputs_ids if past is defined in kwargs
+        if past_key_values:
+            input_ids = input_ids[:, -1].unsqueeze(-1)
+            if token_type_ids is not None:
+                token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
+
+        attention_mask = kwargs.get("attention_mask", None)
+        position_ids = kwargs.get("position_ids", None)
+
+        if attention_mask is not None and position_ids is None:
+            # create position_ids on the fly for batch generation
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            if past_key_values:
+                position_ids = position_ids[:, -1].unsqueeze(-1)
+
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            model_inputs = {"input_ids": input_ids}
+
+        model_inputs.update(
+            {
+                "past_key_values": past_key_values,
+                "use_cache": kwargs.get("use_cache"),
+                "position_ids": position_ids,
+                "attention_mask": attention_mask,
+                "token_type_ids": token_type_ids,
+            }
+        )
+        return model_inputs
+
+    def _update_model_kwargs_for_generation(
+        self,
+        outputs: ModelOutput,
+        model_kwargs: Dict[str, Any],
+        is_encoder_decoder: bool = False,
+        standardize_cache_format: bool = False,
+    ) -> Dict[str, Any]:
+        # update past_key_values
+        model_kwargs["past_key_values"] = self._extract_past_from_model_output(
+            outputs, standardize_cache_format=standardize_cache_format
+        )
+
+        # update token_type_ids with last value
+        if "token_type_ids" in model_kwargs:
+            token_type_ids = model_kwargs["token_type_ids"]
+            model_kwargs["token_type_ids"] = torch.cat(
+                [token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1
+            )
+
+        # update position_ids
+        if "position_ids" in model_kwargs:
+            position_ids = model_kwargs["position_ids"]
+            if model_kwargs["past_key_values"] is not None:
+                model_kwargs["position_ids"] = (position_ids[:, -1] + 1).unsqueeze(-1)
+            else:
+                model_kwargs["position_ids"] = torch.cat(
+                    [position_ids, (position_ids[:, -1] + 1).unsqueeze(-1)], dim=-1
+                )
+
+        if not is_encoder_decoder:
+            # update attention mask
+            if "attention_mask" in model_kwargs:
+                attention_mask = model_kwargs["attention_mask"]
+                if attention_mask.dim() == 2:
+                    model_kwargs["attention_mask"] = torch.cat(
+                        [
+                            attention_mask,
+                            attention_mask.new_ones((attention_mask.shape[0], 1)),
+                        ],
+                        dim=-1,
+                    )
+                elif attention_mask.dim() == 3:
+                    attention_mask = attention_mask[:, -1, :]
+                    attention_mask = torch.cat(
+                        [
+                            attention_mask,
+                            attention_mask.new_ones((attention_mask.shape[0], 1)),
+                        ],
+                        dim=-1,
+                    )
+                    model_kwargs["attention_mask"] = attention_mask
+                else:
+                    raise ValueError(
+                        f"attention_mask.dim() is {attention_mask.dim()}, should be 2 or 3"
+                    )
+        else:
+            # update decoder attention mask
+            if "decoder_attention_mask" in model_kwargs:
+                decoder_attention_mask = model_kwargs["decoder_attention_mask"]
+                model_kwargs["decoder_attention_mask"] = torch.cat(
+                    [
+                        decoder_attention_mask,
+                        decoder_attention_mask.new_ones(
+                            (decoder_attention_mask.shape[0], 1)
+                        ),
+                    ],
+                    dim=-1,
+                )
+
+        return model_kwargs
diff --git a/utils.py b/utils.py
index 5675bd4..5e26fcd 100644
--- a/utils.py
+++ b/utils.py
@@ -10,23 +10,38 @@ from transformers import (
     PreTrainedTokenizer,
 )
 
+import custom_models
+
 
 def init_model(model_name: Union[str, os.PathLike]) -> PreTrainedModel:
     config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
-    try:
-        model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
-    except ValueError:
-        model = AutoModel.from_config(config, trust_remote_code=True)
+
+    if model_name in custom_models.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
+        model = custom_models.AutoModelForCausalLM.from_config(config)
+    elif model_name in custom_models.MODEL_MAPPING_NAMES:
+        model = custom_models.AutoModel.from_config(config)
+    else:
+        try:
+            model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
+        except ValueError:
+            model = AutoModel.from_config(config, trust_remote_code=True)
     return model
 
 
 def load_model(model_name_or_path: Union[str, os.PathLike]) -> PreTrainedModel:
-    try:
-        model = AutoModelForCausalLM.from_pretrained(
-            model_name_or_path, trust_remote_code=True
-        )
-    except ValueError:
-        model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True)
+    if model_name_or_path in custom_models.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
+        model = custom_models.AutoModelForCausalLM.from_pretrained(model_name_or_path)
+    elif model_name_or_path in custom_models.MODEL_MAPPING_NAMES:
+        model = custom_models.AutoModel.from_pretrained(model_name_or_path)
+    else:
+        try:
+            model = AutoModelForCausalLM.from_pretrained(
+                model_name_or_path, trust_remote_code=True
+            )
+        except ValueError:
+            model = AutoModel.from_pretrained(
+                model_name_or_path, trust_remote_code=True
+            )
     return model
 
 
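
A minimal usage sketch of the patched base model, assuming the `custom_models` package above is importable from the working directory and a recent `transformers` release; the small `GPT2Config` values below are illustrative only. It shows the 3D per-example lower-triangular (tril) mask path that the override adds on top of the usual 2D padding mask.

import torch
from transformers import GPT2Config

import custom_models

# Resolves to custom_models.gpt2.modeling_gpt2.GPT2Model through the custom mapping,
# the same route utils.init_model takes for model types listed in MODEL_MAPPING_NAMES.
config = GPT2Config(n_layer=2, n_head=2, n_embd=64)
model = custom_models.AutoModel.from_config(config)

batch_size, seq_len = 2, 8
input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))

# Per-example tril mask of shape [batch, query, key];
# 1.0 = attend, 0.0 = masked, matching the convention in the patched forward().
attention_mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(0).expand(batch_size, -1, -1)

outputs = model(input_ids=input_ids, attention_mask=attention_mask)
print(outputs.last_hidden_state.shape)  # torch.Size([2, 8, 64])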