[feature] custom_models
This commit is contained in:
parent
5e6b747baf
commit
216bc4643c
|
@ -0,0 +1,52 @@
|
|||
import importlib
|
||||
from collections import OrderedDict
|
||||
|
||||
from transformers.models.auto import auto_factory, configuration_auto
|
||||
|
||||
|
||||
class _LazyAutoMapping(auto_factory._LazyAutoMapping):
|
||||
def _load_attr_from_module(self, model_type, attr):
|
||||
module_name = auto_factory.model_type_to_module_name(model_type)
|
||||
if module_name not in self._modules:
|
||||
self._modules[module_name] = importlib.import_module(
|
||||
f".{module_name}", "custom_models"
|
||||
)
|
||||
return auto_factory.getattribute_from_module(self._modules[module_name], attr)
|
||||
|
||||
|
||||
MODEL_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
("gpt2", "GPT2Model"),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
("gpt2", "GPT2LMHeadModel"),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
MODEL_MAPPING = _LazyAutoMapping(
|
||||
configuration_auto.CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES
|
||||
)
|
||||
|
||||
|
||||
MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(
|
||||
configuration_auto.CONFIG_MAPPING_NAMES, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
|
||||
)
|
||||
|
||||
|
||||
class AutoModel(auto_factory._BaseAutoModelClass):
|
||||
_model_mapping = MODEL_MAPPING
|
||||
|
||||
|
||||
AutoModel = auto_factory.auto_class_update(AutoModel)
|
||||
|
||||
|
||||
class AutoModelForCausalLM(auto_factory._BaseAutoModelClass):
|
||||
_model_mapping = MODEL_FOR_CAUSAL_LM_MAPPING
|
||||
|
||||
|
||||
AutoModelForCausalLM = auto_factory.auto_class_update(AutoModelForCausalLM)
|
|
@ -0,0 +1,369 @@
|
|||
"""Override transformers GPT2 to support tril attention mask"""
|
||||
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
import transformers
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
|
||||
from transformers.models.gpt2 import (
|
||||
_CHECKPOINT_FOR_DOC,
|
||||
_CONFIG_FOR_DOC,
|
||||
GPT2_INPUTS_DOCSTRING,
|
||||
)
|
||||
from transformers.utils import (
|
||||
ModelOutput,
|
||||
add_code_sample_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
)
|
||||
|
||||
|
||||
class GPT2Model(transformers.models.gpt2.GPT2Model):
|
||||
@add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
|
||||
@add_code_sample_docstrings(
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
output_type=BaseModelOutputWithPastAndCrossAttentions,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
token_type_ids: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError(
|
||||
"You cannot specify both input_ids and inputs_embeds at the same time"
|
||||
)
|
||||
elif input_ids is not None:
|
||||
input_shape = input_ids.size()
|
||||
input_ids = input_ids.view(-1, input_shape[-1])
|
||||
batch_size = input_ids.shape[0]
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = inputs_embeds.size()[:-1]
|
||||
batch_size = inputs_embeds.shape[0]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
|
||||
if token_type_ids is not None:
|
||||
token_type_ids = token_type_ids.view(-1, input_shape[-1])
|
||||
if position_ids is not None:
|
||||
position_ids = position_ids.view(-1, input_shape[-1])
|
||||
|
||||
if past_key_values is None:
|
||||
past_length = 0
|
||||
past_key_values = tuple([None] * len(self.h))
|
||||
else:
|
||||
past_length = past_key_values[0][0].size(-2)
|
||||
if position_ids is None:
|
||||
position_ids = torch.arange(
|
||||
past_length,
|
||||
input_shape[-1] + past_length,
|
||||
dtype=torch.long,
|
||||
device=device,
|
||||
)
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
|
||||
|
||||
# GPT2Attention mask.
|
||||
if attention_mask is not None:
|
||||
if batch_size <= 0:
|
||||
raise ValueError("batch_size has to be defined and > 0")
|
||||
if attention_mask.dim() == 2:
|
||||
# We create a 3D attention mask from a 2D tensor mask.
|
||||
# Sizes are [batch_size, 1, 1, to_seq_length]
|
||||
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
||||
# this attention mask is more simple than the triangular masking of causal attention
|
||||
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
||||
attention_mask = attention_mask[:, None, None, :]
|
||||
elif attention_mask.dim() == 3:
|
||||
attention_mask = attention_mask[:, None, ...]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"attention_mask.dim() is {attention_mask.dim()}, should be 2 or 3"
|
||||
)
|
||||
|
||||
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
||||
# masked positions, this operation will create a tensor which is 0.0 for
|
||||
# positions we want to attend and the dtype's smallest value for masked positions.
|
||||
# Since we are adding it to the raw scores before the softmax, this is
|
||||
# effectively the same as removing these entirely.
|
||||
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
||||
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
|
||||
|
||||
# If a 2D or 3D attention mask is provided for the cross-attention
|
||||
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||
if self.config.add_cross_attention and encoder_hidden_states is not None:
|
||||
(
|
||||
encoder_batch_size,
|
||||
encoder_sequence_length,
|
||||
_,
|
||||
) = encoder_hidden_states.size()
|
||||
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
||||
if encoder_attention_mask is None:
|
||||
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
||||
encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
||||
else:
|
||||
encoder_attention_mask = None
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
# attention_probs has shape bsz x n_heads x N x N
|
||||
# head_mask has shape n_layer x batch x n_heads x N x N
|
||||
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.wte(input_ids)
|
||||
position_embeds = self.wpe(position_ids)
|
||||
hidden_states = inputs_embeds + position_embeds
|
||||
|
||||
if token_type_ids is not None:
|
||||
token_type_embeds = self.wte(token_type_ids)
|
||||
hidden_states = hidden_states + token_type_embeds
|
||||
|
||||
hidden_states = self.drop(hidden_states)
|
||||
|
||||
output_shape = input_shape + (hidden_states.size(-1),)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
presents = () if use_cache else None
|
||||
all_self_attentions = () if output_attentions else None
|
||||
all_cross_attentions = (
|
||||
() if output_attentions and self.config.add_cross_attention else None
|
||||
)
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
|
||||
# Model parallel
|
||||
if self.model_parallel:
|
||||
torch.cuda.set_device(hidden_states.device)
|
||||
# Ensure layer_past is on same device as hidden_states (might not be correct)
|
||||
if layer_past is not None:
|
||||
layer_past = tuple(
|
||||
past_state.to(hidden_states.device) for past_state in layer_past
|
||||
)
|
||||
# Ensure that attention_mask is always on the same device as hidden_states
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask.to(hidden_states.device)
|
||||
if isinstance(head_mask, torch.Tensor):
|
||||
head_mask = head_mask.to(hidden_states.device)
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
# None for past_key_value
|
||||
return module(*inputs, use_cache, output_attentions)
|
||||
|
||||
return custom_forward
|
||||
|
||||
outputs = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
None,
|
||||
attention_mask,
|
||||
head_mask[i],
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
)
|
||||
else:
|
||||
outputs = block(
|
||||
hidden_states,
|
||||
layer_past=layer_past,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask[i],
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
if use_cache is True:
|
||||
presents = presents + (outputs[1],)
|
||||
|
||||
if output_attentions:
|
||||
all_self_attentions = all_self_attentions + (
|
||||
outputs[2 if use_cache else 1],
|
||||
)
|
||||
if self.config.add_cross_attention:
|
||||
all_cross_attentions = all_cross_attentions + (
|
||||
outputs[3 if use_cache else 2],
|
||||
)
|
||||
|
||||
# Model Parallel: If it's the last layer for that device, put things on the next device
|
||||
if self.model_parallel:
|
||||
for k, v in self.device_map.items():
|
||||
if i == v[-1] and "cuda:" + str(k) != self.last_device:
|
||||
hidden_states = hidden_states.to("cuda:" + str(k + 1))
|
||||
|
||||
hidden_states = self.ln_f(hidden_states)
|
||||
|
||||
hidden_states = hidden_states.view(output_shape)
|
||||
# Add last hidden state
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [
|
||||
hidden_states,
|
||||
presents,
|
||||
all_hidden_states,
|
||||
all_self_attentions,
|
||||
all_cross_attentions,
|
||||
]
|
||||
if v is not None
|
||||
)
|
||||
|
||||
return BaseModelOutputWithPastAndCrossAttentions(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=presents,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attentions,
|
||||
cross_attentions=all_cross_attentions,
|
||||
)
|
||||
|
||||
|
||||
class GPT2LMHeadModel(transformers.models.gpt2.GPT2LMHeadModel):
|
||||
def prepare_inputs_for_generation(
|
||||
self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
|
||||
):
|
||||
token_type_ids = kwargs.get("token_type_ids", None)
|
||||
# only last token for inputs_ids if past is defined in kwargs
|
||||
if past_key_values:
|
||||
input_ids = input_ids[:, -1].unsqueeze(-1)
|
||||
if token_type_ids is not None:
|
||||
token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
|
||||
|
||||
attention_mask = kwargs.get("attention_mask", None)
|
||||
position_ids = kwargs.get("position_ids", None)
|
||||
|
||||
if attention_mask is not None and position_ids is None:
|
||||
# create position_ids on the fly for batch generation
|
||||
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||
position_ids.masked_fill_(attention_mask == 0, 1)
|
||||
if past_key_values:
|
||||
position_ids = position_ids[:, -1].unsqueeze(-1)
|
||||
|
||||
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||
if inputs_embeds is not None and past_key_values is None:
|
||||
model_inputs = {"inputs_embeds": inputs_embeds}
|
||||
else:
|
||||
model_inputs = {"input_ids": input_ids}
|
||||
|
||||
model_inputs.update(
|
||||
{
|
||||
"past_key_values": past_key_values,
|
||||
"use_cache": kwargs.get("use_cache"),
|
||||
"position_ids": position_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"token_type_ids": token_type_ids,
|
||||
}
|
||||
)
|
||||
return model_inputs
|
||||
|
||||
def _update_model_kwargs_for_generation(
|
||||
self,
|
||||
outputs: ModelOutput,
|
||||
model_kwargs: Dict[str, Any],
|
||||
is_encoder_decoder: bool = False,
|
||||
standardize_cache_format: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
# update past_key_values
|
||||
model_kwargs["past_key_values"] = self._extract_past_from_model_output(
|
||||
outputs, standardize_cache_format=standardize_cache_format
|
||||
)
|
||||
|
||||
# update token_type_ids with last value
|
||||
if "token_type_ids" in model_kwargs:
|
||||
token_type_ids = model_kwargs["token_type_ids"]
|
||||
model_kwargs["token_type_ids"] = torch.cat(
|
||||
[token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1
|
||||
)
|
||||
|
||||
# update position_ids
|
||||
if "position_ids" in model_kwargs:
|
||||
position_ids = model_kwargs["position_ids"]
|
||||
if model_kwargs["past_key_values"] is not None:
|
||||
model_kwargs["position_ids"] = (position_ids[:, -1] + 1).unsqueeze(-1)
|
||||
else:
|
||||
model_kwargs["position_ids"] = torch.cat(
|
||||
[position_ids, (position_ids[:, -1] + 1).unsqueeze(-1)], dim=-1
|
||||
)
|
||||
|
||||
if not is_encoder_decoder:
|
||||
# update attention mask
|
||||
if "attention_mask" in model_kwargs:
|
||||
attention_mask = model_kwargs["attention_mask"]
|
||||
if attention_mask.dim() == 2:
|
||||
model_kwargs["attention_mask"] = torch.cat(
|
||||
[
|
||||
attention_mask,
|
||||
attention_mask.new_ones((attention_mask.shape[0], 1)),
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
elif attention_mask.dim() == 3:
|
||||
attention_mask = attention_mask[:, -1, :]
|
||||
attention_mask = torch.cat(
|
||||
[
|
||||
attention_mask,
|
||||
attention_mask.new_ones((attention_mask.shape[0], 1)),
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
model_kwargs["attention_mask"] = attention_mask
|
||||
else:
|
||||
raise ValueError(
|
||||
f"attention_mask.dim() is {attention_mask.dim()}, should be 2 or 3"
|
||||
)
|
||||
else:
|
||||
# update decoder attention mask
|
||||
if "decoder_attention_mask" in model_kwargs:
|
||||
decoder_attention_mask = model_kwargs["decoder_attention_mask"]
|
||||
model_kwargs["decoder_attention_mask"] = torch.cat(
|
||||
[
|
||||
decoder_attention_mask,
|
||||
decoder_attention_mask.new_ones(
|
||||
(decoder_attention_mask.shape[0], 1)
|
||||
),
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
return model_kwargs
|
35
utils.py
35
utils.py
|
@ -10,23 +10,38 @@ from transformers import (
|
|||
PreTrainedTokenizer,
|
||||
)
|
||||
|
||||
import custom_models
|
||||
|
||||
|
||||
def init_model(model_name: Union[str, os.PathLike]) -> PreTrainedModel:
|
||||
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
|
||||
try:
|
||||
model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
|
||||
except ValueError:
|
||||
model = AutoModel.from_config(config, trust_remote_code=True)
|
||||
|
||||
if model_name in custom_models.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
|
||||
model = custom_models.AutoModelForCausalLM.from_config(config)
|
||||
elif model_name in custom_models.MODEL_MAPPING_NAMES:
|
||||
model = custom_models.AutoModel.from_config(config)
|
||||
else:
|
||||
try:
|
||||
model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
|
||||
except ValueError:
|
||||
model = AutoModel.from_config(config, trust_remote_code=True)
|
||||
return model
|
||||
|
||||
|
||||
def load_model(model_name_or_path: Union[str, os.PathLike]) -> PreTrainedModel:
|
||||
try:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name_or_path, trust_remote_code=True
|
||||
)
|
||||
except ValueError:
|
||||
model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True)
|
||||
if model_name_or_path in custom_models.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
|
||||
model = custom_models.AutoModelForCausalLM.from_pretrained(model_name_or_path)
|
||||
elif model_name_or_path in custom_models.MODEL_MAPPING_NAMES:
|
||||
model = custom_models.AutoModel.from_pretrained(model_name_or_path)
|
||||
else:
|
||||
try:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name_or_path, trust_remote_code=True
|
||||
)
|
||||
except ValueError:
|
||||
model = AutoModel.from_pretrained(
|
||||
model_name_or_path, trust_remote_code=True
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue