Refine the Qwen model.

This commit is contained in:
Colin 2024-01-07 22:36:55 +08:00
parent 4c0991a409
commit 94ecf0f561
1 changed file with 16 additions and 116 deletions

View File

@ -29,14 +29,8 @@ from transformers.modeling_outputs import (
from transformers.modeling_utils import PreTrainedModel from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging from transformers.utils import logging
try:
from einops import rearrange
except ImportError:
rearrange = None
from torch import nn from torch import nn
SUPPORT_CUDA = torch.cuda.is_available()
from configuration_qwen import QWenConfig from configuration_qwen import QWenConfig
from qwen_generation_utils import ( from qwen_generation_utils import (
HistoryType, HistoryType,
@ -48,8 +42,6 @@ from qwen_generation_utils import (
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
_SENTINEL = object()
class QWenAttention(nn.Module): class QWenAttention(nn.Module):
def __init__(self, config): def __init__(self, config):
@ -104,38 +96,20 @@ class QWenAttention(nn.Module):
rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None, rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
layer_past: Optional[Tuple[torch.Tensor]] = None, layer_past: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = False,
): ):
mixed_x_layer = self.c_attn(hidden_states) mixed_x_layer = self.c_attn(hidden_states)
query, key, value = mixed_x_layer.split(self.split_size, dim=2) query, key, value = mixed_x_layer.split(self.split_size, dim=2)
query = self._split_heads(query, self.num_heads, self.head_dim) query = self._split_heads(query, self.num_heads, self.head_dim)
key = self._split_heads(key, self.num_heads, self.head_dim) key = self._split_heads(key, self.num_heads, self.head_dim)
value = self._split_heads(value, self.num_heads, self.head_dim) value = self._split_heads(value, self.num_heads, self.head_dim)
if rotary_pos_emb_list is not None:
cur_len = query.shape[1]
if len(rotary_pos_emb_list) == 1:
rotary_pos_emb = rotary_pos_emb_list[0] rotary_pos_emb = rotary_pos_emb_list[0]
rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb] rotary_pos_emb = [i[:, -query.shape[1] :, :, :] for i in rotary_pos_emb]
rotary_pos_emb = (rotary_pos_emb,) * 2 rotary_pos_emb = (rotary_pos_emb,) * 2
q_pos_emb, k_pos_emb = rotary_pos_emb q_pos_emb, k_pos_emb = rotary_pos_emb
# Slice the pos emb for current inference # Slice the pos emb for current inference
query = apply_rotary_pos_emb(query, q_pos_emb) query = apply_rotary_pos_emb(query, q_pos_emb)
key = apply_rotary_pos_emb(key, k_pos_emb) key = apply_rotary_pos_emb(key, k_pos_emb)
else:
query_list = []
key_list = []
for i, rotary_pos_emb in enumerate(rotary_pos_emb_list):
rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
rotary_pos_emb = (rotary_pos_emb,) * 2
q_pos_emb, k_pos_emb = rotary_pos_emb
# Slice the pos emb for current inference
query_list += [apply_rotary_pos_emb(query[i : i + 1, :, :], q_pos_emb)]
key_list += [apply_rotary_pos_emb(key[i : i + 1, :, :], k_pos_emb)]
query = torch.cat(query_list, dim=0)
key = torch.cat(key_list, dim=0)
if layer_past is not None: if layer_past is not None:
past_key, past_value = layer_past[0], layer_past[1] past_key, past_value = layer_past[0], layer_past[1]
@ -143,10 +117,7 @@ class QWenAttention(nn.Module):
key = torch.cat((past_key, key), dim=1) key = torch.cat((past_key, key), dim=1)
value = torch.cat((past_value, value), dim=1) value = torch.cat((past_value, value), dim=1)
if use_cache:
present = (key, value) present = (key, value)
else:
present = None
key_size = key.size(1) key_size = key.size(1)
if key_size > self.seq_length and self.use_logn_attn and not self.training: if key_size > self.seq_length and self.use_logn_attn and not self.training:
@ -186,9 +157,9 @@ class QWenAttention(nn.Module):
class QWenMLP(nn.Module): class QWenMLP(nn.Module):
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()
self.w1 = nn.Linear(config.hidden_size, config.intermediate_size // 2, bias=not config.no_bias)
self.w2 = nn.Linear(config.hidden_size, config.intermediate_size // 2, bias=not config.no_bias)
ff_dim_in = config.intermediate_size // 2 ff_dim_in = config.intermediate_size // 2
self.w1 = nn.Linear(config.hidden_size, ff_dim_in, bias=not config.no_bias)
self.w2 = nn.Linear(config.hidden_size, ff_dim_in, bias=not config.no_bias)
self.c_proj = nn.Linear(ff_dim_in, config.hidden_size, bias=not config.no_bias) self.c_proj = nn.Linear(ff_dim_in, config.hidden_size, bias=not config.no_bias)
def forward(self, hidden_states): def forward(self, hidden_states):
@ -222,11 +193,6 @@ class QWenBlock(nn.Module):
rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None, rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
layer_past: Optional[Tuple[torch.Tensor]] = None, layer_past: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
): ):
layernorm_output = self.ln_1(hidden_states) layernorm_output = self.ln_1(hidden_states)
@ -235,7 +201,6 @@ class QWenBlock(nn.Module):
rotary_pos_emb_list, rotary_pos_emb_list,
layer_past=layer_past, layer_past=layer_past,
attention_mask=attention_mask, attention_mask=attention_mask,
use_cache=use_cache,
) )
attn_output = attn_outputs[0] attn_output = attn_outputs[0]
@ -251,10 +216,7 @@ class QWenBlock(nn.Module):
mlp_output = self.mlp(layernorm_output) mlp_output = self.mlp(layernorm_output)
hidden_states = residual + mlp_output hidden_states = residual + mlp_output
if use_cache:
outputs = (hidden_states,) + outputs outputs = (hidden_states,) + outputs
else:
outputs = (hidden_states,) + outputs[1:]
return outputs return outputs
@ -314,14 +276,7 @@ class QWenModel(QWenPreTrainedModel):
attention_mask: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
): ):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
use_cache = use_cache if use_cache is not None else self.config.use_cache
if input_ids is not None and inputs_embeds is not None: if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None: elif input_ids is not None:
@ -343,7 +298,6 @@ class QWenModel(QWenPreTrainedModel):
attention_mask = attention_mask.to(dtype=self.dtype) attention_mask = attention_mask.to(dtype=self.dtype)
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
encoder_attention_mask = None
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
if inputs_embeds is None: if inputs_embeds is None:
@ -376,8 +330,7 @@ class QWenModel(QWenPreTrainedModel):
hidden_states = self.drop(hidden_states) hidden_states = self.drop(hidden_states)
output_shape = input_shape + (hidden_states.size(-1),) output_shape = input_shape + (hidden_states.size(-1),)
presents = () if use_cache else None presents = ()
all_self_attentions = () if output_attentions else None
all_hidden_states = None all_hidden_states = None
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
outputs = block( outputs = block(
@ -385,27 +338,14 @@ class QWenModel(QWenPreTrainedModel):
layer_past=layer_past, layer_past=layer_past,
rotary_pos_emb_list=rotary_pos_emb_list, rotary_pos_emb_list=rotary_pos_emb_list,
attention_mask=attention_mask, attention_mask=attention_mask,
head_mask=head_mask[i],
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
) )
hidden_states = outputs[0] hidden_states = outputs[0]
if use_cache is True:
presents = presents + (outputs[1],) presents = presents + (outputs[1],)
if output_attentions:
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
hidden_states = self.ln_f(hidden_states) hidden_states = self.ln_f(hidden_states)
hidden_states = hidden_states.view(output_shape) hidden_states = hidden_states.view(output_shape)
return BaseModelOutputWithPast( return BaseModelOutputWithPast(
last_hidden_state=hidden_states, last_hidden_state=hidden_states, past_key_values=presents, hidden_states=all_hidden_states
past_key_values=presents,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
) )
@ -434,7 +374,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
model_inputs.update( model_inputs.update(
{ {
"past_key_values": past_key_values, "past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask, "attention_mask": attention_mask,
} }
) )
@ -447,11 +386,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
attention_mask: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None, labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]: ) -> Union[Tuple, CausalLMOutputWithPast]:
transformer_outputs = self.transformer( transformer_outputs = self.transformer(
input_ids, input_ids,
@ -459,10 +394,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
attention_mask=attention_mask, attention_mask=attention_mask,
head_mask=head_mask, head_mask=head_mask,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
) )
hidden_states = transformer_outputs[0] hidden_states = transformer_outputs[0]
@ -498,19 +429,15 @@ class QWenLMHeadModel(QWenPreTrainedModel):
query: str, query: str,
history: Optional[HistoryType], history: Optional[HistoryType],
system: str = "You are a helpful assistant.", system: str = "You are a helpful assistant.",
stream: Optional[bool] = _SENTINEL,
stop_words_ids: Optional[List[List[int]]] = None, stop_words_ids: Optional[List[List[int]]] = None,
generation_config: Optional[GenerationConfig] = None, generation_config: Optional[GenerationConfig] = None,
**kwargs, **kwargs,
) -> Tuple[str, HistoryType]: ) -> Tuple[str, HistoryType]:
generation_config = generation_config if generation_config is not None else self.generation_config generation_config = generation_config if generation_config is not None else self.generation_config
assert stream is _SENTINEL
assert generation_config.chat_format == "chatml"
if history is None: if history is None:
history = [] history = []
else: else:
# make a copy of the user's input such that it is left untouched
history = copy.deepcopy(history) history = copy.deepcopy(history)
if stop_words_ids is None: if stop_words_ids is None:
@ -536,7 +463,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
generation_config=generation_config, generation_config=generation_config,
**kwargs, **kwargs,
) )
response = decode_tokens( response = decode_tokens(
outputs[0], outputs[0],
tokenizer, tokenizer,
@ -546,13 +472,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
verbose=False, verbose=False,
errors="replace", errors="replace",
) )
# as history is a copy of the user inputs,
# we can always return the new turn to the user.
# separating input history and output history also enables the user
# to implement more complex history management
history.append((query, response)) history.append((query, response))
return response, history return response, history
def generate( def generate(
@ -562,7 +482,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
logits_processor: Optional[LogitsProcessorList] = None, logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None, stopping_criteria: Optional[StoppingCriteriaList] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
synced_gpus: Optional[bool] = None,
assistant_model: Optional["PreTrainedModel"] = None, assistant_model: Optional["PreTrainedModel"] = None,
streamer: Optional["BaseStreamer"] = None, streamer: Optional["BaseStreamer"] = None,
**kwargs, **kwargs,
@ -592,7 +511,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
logits_processor=logits_processor, logits_processor=logits_processor,
stopping_criteria=stopping_criteria, stopping_criteria=stopping_criteria,
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
synced_gpus=synced_gpus,
assistant_model=assistant_model, assistant_model=assistant_model,
streamer=streamer, streamer=streamer,
**kwargs, **kwargs,
@ -605,7 +523,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
logits_processor: Optional[LogitsProcessorList] = None, logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None, stopping_criteria: Optional[StoppingCriteriaList] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
synced_gpus: Optional[bool] = None,
assistant_model: Optional["PreTrainedModel"] = None, assistant_model: Optional["PreTrainedModel"] = None,
streamer: Optional["BaseStreamer"] = None, streamer: Optional["BaseStreamer"] = None,
negative_prompt_ids: Optional[torch.Tensor] = None, negative_prompt_ids: Optional[torch.Tensor] = None,
@ -637,21 +554,10 @@ class QWenLMHeadModel(QWenPreTrainedModel):
generation_config.pad_token_id = eos_token_id generation_config.pad_token_id = eos_token_id
# 3. Define model inputs # 3. Define model inputs
# inputs_tensor has to be defined
# model_input_name is defined if model-specific keyword input is passed
# otherwise model_input_name is None
# all model-specific keyword inputs are removed from `model_kwargs`
inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs( inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
inputs, generation_config.bos_token_id, model_kwargs inputs, generation_config.bos_token_id, model_kwargs
) )
# 4. Define other model kwargs # 4. Define other model kwargs
model_kwargs["output_attentions"] = generation_config.output_attentions
# decoder-only models with inputs_embeds forwarding must use caching (otherwise we can't detect whether we are
# generating the first new token or not, and we only want to use the embeddings for the first new token)
if not self.config.is_encoder_decoder and model_input_name == "inputs_embeds":
model_kwargs["use_cache"] = True
else:
model_kwargs["use_cache"] = generation_config.use_cache
accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys()) accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
requires_attention_mask = "encoder_outputs" not in model_kwargs requires_attention_mask = "encoder_outputs" not in model_kwargs
@ -713,7 +619,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
pad_token_id=generation_config.pad_token_id, pad_token_id=generation_config.pad_token_id,
eos_token_id=generation_config.eos_token_id, eos_token_id=generation_config.eos_token_id,
output_scores=generation_config.output_scores, output_scores=generation_config.output_scores,
synced_gpus=synced_gpus,
streamer=streamer, streamer=streamer,
**model_kwargs, **model_kwargs,
) )
@ -727,9 +632,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
max_length: Optional[int] = None, max_length: Optional[int] = None,
pad_token_id: Optional[int] = None, pad_token_id: Optional[int] = None,
eos_token_id: Optional[Union[int, List[int]]] = None, eos_token_id: Optional[Union[int, List[int]]] = None,
output_attentions: Optional[bool] = None,
output_scores: Optional[bool] = None, output_scores: Optional[bool] = None,
synced_gpus: bool = False,
streamer: Optional["BaseStreamer"] = None, streamer: Optional["BaseStreamer"] = None,
**model_kwargs, **model_kwargs,
): ):
@ -744,9 +647,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
eos_token_id = [eos_token_id] eos_token_id = [eos_token_id]
eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
output_attentions = (
output_attentions if output_attentions is not None else self.generation_config.output_attentions
)
# init attention / hidden states / scores tuples # init attention / hidden states / scores tuples
scores = None scores = None
@ -754,14 +654,14 @@ class QWenLMHeadModel(QWenPreTrainedModel):
# keep track of which sequences are already finished # keep track of which sequences are already finished
unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device) unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
this_peer_finished = False # used by synced_gpus only this_peer_finished = False
# auto-regressive generation # auto-regressive generation
while True: while True:
# prepare model inputs # prepare model inputs
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
# forward pass to get next token # forward pass to get next token
outputs = self(**model_inputs, output_attentions=output_attentions) outputs = self(**model_inputs)
next_token_logits = outputs.logits[:, -1, :] next_token_logits = outputs.logits[:, -1, :]