Delete kv cache of qwen.

This commit is contained in:
Colin 2024-01-18 20:23:21 +08:00
parent 0a78627e48
commit 3233616aac
1 changed files with 4 additions and 35 deletions

View File

@ -75,7 +75,6 @@ class QWenAttention(nn.Module):
self.attn_dropout = nn.Dropout(config.attn_dropout_prob)
self.softmax_in_fp32 = config.softmax_in_fp32 if hasattr(config, "softmax_in_fp32") else False
self.use_cache_kernel = config.use_cache_kernel if hasattr(config, "use_cache_kernel") else False
cache_dtype = torch.float
self.cache_qmax = torch.tensor(torch.iinfo(torch.uint8).max, dtype=cache_dtype)
self.cache_qmin = torch.tensor(torch.iinfo(torch.uint8).min, dtype=cache_dtype)
@ -95,7 +94,6 @@ class QWenAttention(nn.Module):
self,
hidden_states: Optional[Tuple[torch.FloatTensor]],
rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
layer_past: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
):
mixed_x_layer = self.c_attn(hidden_states)
@ -112,12 +110,6 @@ class QWenAttention(nn.Module):
query = apply_rotary_pos_emb(query, q_pos_emb)
key = apply_rotary_pos_emb(key, k_pos_emb)
if layer_past is not None:
past_key, past_value = layer_past[0], layer_past[1]
key = torch.cat((past_key, key), dim=1)
value = torch.cat((past_value, value), dim=1)
present = (key, value)
key_size = key.size(1)
@ -197,7 +189,6 @@ class QWenBlock(nn.Module):
self,
hidden_states: Optional[Tuple[torch.FloatTensor]],
rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
layer_past: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
):
layernorm_output = self.ln_1(hidden_states)
@ -205,7 +196,6 @@ class QWenBlock(nn.Module):
attn_outputs = self.attn(
layernorm_output,
rotary_pos_emb_list,
layer_past=layer_past,
attention_mask=attention_mask,
)
attn_output = attn_outputs[0]
@ -227,7 +217,6 @@ class QWenPreTrainedModel(PreTrainedModel):
is_parallelizable = False
supports_gradient_checkpointing = True
_no_split_modules = ["QWenBlock"]
_skip_keys_device_placement = "past_key_values"
def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
@ -272,7 +261,6 @@ class QWenModel(QWenPreTrainedModel):
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
@ -289,9 +277,6 @@ class QWenModel(QWenPreTrainedModel):
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if past_key_values is None:
past_key_values = tuple([None] * len(self.h))
if attention_mask is not None:
attention_mask = attention_mask.view(batch_size, -1)
attention_mask = attention_mask[:, None, None, :]
@ -305,9 +290,6 @@ class QWenModel(QWenPreTrainedModel):
hidden_states = inputs_embeds
kv_seq_len = hidden_states.size()[1]
if past_key_values[0] is not None:
# past key values[0][0] shape: bs * seq_len * head_num * dim
kv_seq_len += past_key_values[0][0].shape[1]
if self.training or not self.use_dynamic_ntk:
ntk_alpha_list = [1.0]
@ -332,10 +314,9 @@ class QWenModel(QWenPreTrainedModel):
presents = ()
all_hidden_states = None
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
for i, block in enumerate(self.h):
outputs = block(
hidden_states,
layer_past=layer_past,
rotary_pos_emb_list=rotary_pos_emb_list,
attention_mask=attention_mask,
)
@ -344,9 +325,7 @@ class QWenModel(QWenPreTrainedModel):
hidden_states = self.ln_f(hidden_states)
hidden_states = hidden_states.view(output_shape)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states, past_key_values=presents, hidden_states=all_hidden_states
)
return BaseModelOutputWithPast(last_hidden_state=hidden_states, hidden_states=all_hidden_states)
class QWenLMHeadModel(QWenPreTrainedModel):
@ -357,23 +336,16 @@ class QWenLMHeadModel(QWenPreTrainedModel):
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.post_init()
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
if past_key_values:
input_ids = input_ids[:, -1].unsqueeze(-1)
def prepare_inputs_for_generation(self, input_ids, inputs_embeds=None, **kwargs):
if input_ids.size(0) == 1:
attention_mask = None
else:
attention_mask = kwargs.get("attention_mask", None)
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs.update(
{
"past_key_values": past_key_values,
"attention_mask": attention_mask,
}
)
@ -382,7 +354,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
@ -390,7 +361,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
) -> Union[Tuple, CausalLMOutputWithPast]:
transformer_outputs = self.transformer(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
@ -418,7 +388,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
return CausalLMOutputWithPast(
loss=loss,
logits=lm_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)