Delete kv cache of qwen.
This commit is contained in:
parent
0a78627e48
commit
3233616aac
|
@ -75,7 +75,6 @@ class QWenAttention(nn.Module):
|
|||
|
||||
self.attn_dropout = nn.Dropout(config.attn_dropout_prob)
|
||||
self.softmax_in_fp32 = config.softmax_in_fp32 if hasattr(config, "softmax_in_fp32") else False
|
||||
self.use_cache_kernel = config.use_cache_kernel if hasattr(config, "use_cache_kernel") else False
|
||||
cache_dtype = torch.float
|
||||
self.cache_qmax = torch.tensor(torch.iinfo(torch.uint8).max, dtype=cache_dtype)
|
||||
self.cache_qmin = torch.tensor(torch.iinfo(torch.uint8).min, dtype=cache_dtype)
|
||||
|
@ -95,7 +94,6 @@ class QWenAttention(nn.Module):
|
|||
self,
|
||||
hidden_states: Optional[Tuple[torch.FloatTensor]],
|
||||
rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
|
||||
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
):
|
||||
mixed_x_layer = self.c_attn(hidden_states)
|
||||
|
@ -112,12 +110,6 @@ class QWenAttention(nn.Module):
|
|||
query = apply_rotary_pos_emb(query, q_pos_emb)
|
||||
key = apply_rotary_pos_emb(key, k_pos_emb)
|
||||
|
||||
if layer_past is not None:
|
||||
past_key, past_value = layer_past[0], layer_past[1]
|
||||
|
||||
key = torch.cat((past_key, key), dim=1)
|
||||
value = torch.cat((past_value, value), dim=1)
|
||||
|
||||
present = (key, value)
|
||||
|
||||
key_size = key.size(1)
|
||||
|
@ -197,7 +189,6 @@ class QWenBlock(nn.Module):
|
|||
self,
|
||||
hidden_states: Optional[Tuple[torch.FloatTensor]],
|
||||
rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
|
||||
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
):
|
||||
layernorm_output = self.ln_1(hidden_states)
|
||||
|
@ -205,7 +196,6 @@ class QWenBlock(nn.Module):
|
|||
attn_outputs = self.attn(
|
||||
layernorm_output,
|
||||
rotary_pos_emb_list,
|
||||
layer_past=layer_past,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
attn_output = attn_outputs[0]
|
||||
|
@ -227,7 +217,6 @@ class QWenPreTrainedModel(PreTrainedModel):
|
|||
is_parallelizable = False
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["QWenBlock"]
|
||||
_skip_keys_device_placement = "past_key_values"
|
||||
|
||||
def __init__(self, *inputs, **kwargs):
|
||||
super().__init__(*inputs, **kwargs)
|
||||
|
@ -272,7 +261,6 @@ class QWenModel(QWenPreTrainedModel):
|
|||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
|
@ -289,9 +277,6 @@ class QWenModel(QWenPreTrainedModel):
|
|||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
if past_key_values is None:
|
||||
past_key_values = tuple([None] * len(self.h))
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask.view(batch_size, -1)
|
||||
attention_mask = attention_mask[:, None, None, :]
|
||||
|
@ -305,9 +290,6 @@ class QWenModel(QWenPreTrainedModel):
|
|||
hidden_states = inputs_embeds
|
||||
|
||||
kv_seq_len = hidden_states.size()[1]
|
||||
if past_key_values[0] is not None:
|
||||
# past key values[0][0] shape: bs * seq_len * head_num * dim
|
||||
kv_seq_len += past_key_values[0][0].shape[1]
|
||||
|
||||
if self.training or not self.use_dynamic_ntk:
|
||||
ntk_alpha_list = [1.0]
|
||||
|
@ -332,10 +314,9 @@ class QWenModel(QWenPreTrainedModel):
|
|||
|
||||
presents = ()
|
||||
all_hidden_states = None
|
||||
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
|
||||
for i, block in enumerate(self.h):
|
||||
outputs = block(
|
||||
hidden_states,
|
||||
layer_past=layer_past,
|
||||
rotary_pos_emb_list=rotary_pos_emb_list,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
|
@ -344,9 +325,7 @@ class QWenModel(QWenPreTrainedModel):
|
|||
|
||||
hidden_states = self.ln_f(hidden_states)
|
||||
hidden_states = hidden_states.view(output_shape)
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states, past_key_values=presents, hidden_states=all_hidden_states
|
||||
)
|
||||
return BaseModelOutputWithPast(last_hidden_state=hidden_states, hidden_states=all_hidden_states)
|
||||
|
||||
|
||||
class QWenLMHeadModel(QWenPreTrainedModel):
|
||||
|
@ -357,23 +336,16 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
|||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
self.post_init()
|
||||
|
||||
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
|
||||
if past_key_values:
|
||||
input_ids = input_ids[:, -1].unsqueeze(-1)
|
||||
|
||||
def prepare_inputs_for_generation(self, input_ids, inputs_embeds=None, **kwargs):
|
||||
if input_ids.size(0) == 1:
|
||||
attention_mask = None
|
||||
else:
|
||||
attention_mask = kwargs.get("attention_mask", None)
|
||||
|
||||
if inputs_embeds is not None and past_key_values is None:
|
||||
model_inputs = {"inputs_embeds": inputs_embeds}
|
||||
else:
|
||||
model_inputs = {"input_ids": input_ids}
|
||||
|
||||
model_inputs.update(
|
||||
{
|
||||
"past_key_values": past_key_values,
|
||||
"attention_mask": attention_mask,
|
||||
}
|
||||
)
|
||||
|
@ -382,7 +354,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
|||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
|
@ -390,7 +361,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
|||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
transformer_outputs = self.transformer(
|
||||
input_ids,
|
||||
past_key_values=past_key_values,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
|
@ -418,7 +388,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
|||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=lm_logits,
|
||||
past_key_values=transformer_outputs.past_key_values,
|
||||
hidden_states=transformer_outputs.hidden_states,
|
||||
attentions=transformer_outputs.attentions,
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue