Delete kv cache of qwen.

This commit is contained in:
Colin 2024-01-18 20:23:21 +08:00
parent 0a78627e48
commit 3233616aac
1 changed files with 4 additions and 35 deletions

View File

@ -75,7 +75,6 @@ class QWenAttention(nn.Module):
self.attn_dropout = nn.Dropout(config.attn_dropout_prob) self.attn_dropout = nn.Dropout(config.attn_dropout_prob)
self.softmax_in_fp32 = config.softmax_in_fp32 if hasattr(config, "softmax_in_fp32") else False self.softmax_in_fp32 = config.softmax_in_fp32 if hasattr(config, "softmax_in_fp32") else False
self.use_cache_kernel = config.use_cache_kernel if hasattr(config, "use_cache_kernel") else False
cache_dtype = torch.float cache_dtype = torch.float
self.cache_qmax = torch.tensor(torch.iinfo(torch.uint8).max, dtype=cache_dtype) self.cache_qmax = torch.tensor(torch.iinfo(torch.uint8).max, dtype=cache_dtype)
self.cache_qmin = torch.tensor(torch.iinfo(torch.uint8).min, dtype=cache_dtype) self.cache_qmin = torch.tensor(torch.iinfo(torch.uint8).min, dtype=cache_dtype)
@ -95,7 +94,6 @@ class QWenAttention(nn.Module):
self, self,
hidden_states: Optional[Tuple[torch.FloatTensor]], hidden_states: Optional[Tuple[torch.FloatTensor]],
rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None, rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
layer_past: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None,
): ):
mixed_x_layer = self.c_attn(hidden_states) mixed_x_layer = self.c_attn(hidden_states)
@ -112,12 +110,6 @@ class QWenAttention(nn.Module):
query = apply_rotary_pos_emb(query, q_pos_emb) query = apply_rotary_pos_emb(query, q_pos_emb)
key = apply_rotary_pos_emb(key, k_pos_emb) key = apply_rotary_pos_emb(key, k_pos_emb)
if layer_past is not None:
past_key, past_value = layer_past[0], layer_past[1]
key = torch.cat((past_key, key), dim=1)
value = torch.cat((past_value, value), dim=1)
present = (key, value) present = (key, value)
key_size = key.size(1) key_size = key.size(1)
@ -197,7 +189,6 @@ class QWenBlock(nn.Module):
self, self,
hidden_states: Optional[Tuple[torch.FloatTensor]], hidden_states: Optional[Tuple[torch.FloatTensor]],
rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None, rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
layer_past: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None,
): ):
layernorm_output = self.ln_1(hidden_states) layernorm_output = self.ln_1(hidden_states)
@ -205,7 +196,6 @@ class QWenBlock(nn.Module):
attn_outputs = self.attn( attn_outputs = self.attn(
layernorm_output, layernorm_output,
rotary_pos_emb_list, rotary_pos_emb_list,
layer_past=layer_past,
attention_mask=attention_mask, attention_mask=attention_mask,
) )
attn_output = attn_outputs[0] attn_output = attn_outputs[0]
@ -227,7 +217,6 @@ class QWenPreTrainedModel(PreTrainedModel):
is_parallelizable = False is_parallelizable = False
supports_gradient_checkpointing = True supports_gradient_checkpointing = True
_no_split_modules = ["QWenBlock"] _no_split_modules = ["QWenBlock"]
_skip_keys_device_placement = "past_key_values"
def __init__(self, *inputs, **kwargs): def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs) super().__init__(*inputs, **kwargs)
@ -272,7 +261,6 @@ class QWenModel(QWenPreTrainedModel):
def forward( def forward(
self, self,
input_ids: Optional[torch.LongTensor] = None, input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None,
@ -289,9 +277,6 @@ class QWenModel(QWenPreTrainedModel):
else: else:
raise ValueError("You have to specify either input_ids or inputs_embeds") raise ValueError("You have to specify either input_ids or inputs_embeds")
if past_key_values is None:
past_key_values = tuple([None] * len(self.h))
if attention_mask is not None: if attention_mask is not None:
attention_mask = attention_mask.view(batch_size, -1) attention_mask = attention_mask.view(batch_size, -1)
attention_mask = attention_mask[:, None, None, :] attention_mask = attention_mask[:, None, None, :]
@ -305,9 +290,6 @@ class QWenModel(QWenPreTrainedModel):
hidden_states = inputs_embeds hidden_states = inputs_embeds
kv_seq_len = hidden_states.size()[1] kv_seq_len = hidden_states.size()[1]
if past_key_values[0] is not None:
# past key values[0][0] shape: bs * seq_len * head_num * dim
kv_seq_len += past_key_values[0][0].shape[1]
if self.training or not self.use_dynamic_ntk: if self.training or not self.use_dynamic_ntk:
ntk_alpha_list = [1.0] ntk_alpha_list = [1.0]
@ -332,10 +314,9 @@ class QWenModel(QWenPreTrainedModel):
presents = () presents = ()
all_hidden_states = None all_hidden_states = None
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): for i, block in enumerate(self.h):
outputs = block( outputs = block(
hidden_states, hidden_states,
layer_past=layer_past,
rotary_pos_emb_list=rotary_pos_emb_list, rotary_pos_emb_list=rotary_pos_emb_list,
attention_mask=attention_mask, attention_mask=attention_mask,
) )
@ -344,9 +325,7 @@ class QWenModel(QWenPreTrainedModel):
hidden_states = self.ln_f(hidden_states) hidden_states = self.ln_f(hidden_states)
hidden_states = hidden_states.view(output_shape) hidden_states = hidden_states.view(output_shape)
return BaseModelOutputWithPast( return BaseModelOutputWithPast(last_hidden_state=hidden_states, hidden_states=all_hidden_states)
last_hidden_state=hidden_states, past_key_values=presents, hidden_states=all_hidden_states
)
class QWenLMHeadModel(QWenPreTrainedModel): class QWenLMHeadModel(QWenPreTrainedModel):
@ -357,23 +336,16 @@ class QWenLMHeadModel(QWenPreTrainedModel):
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.post_init() self.post_init()
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): def prepare_inputs_for_generation(self, input_ids, inputs_embeds=None, **kwargs):
if past_key_values:
input_ids = input_ids[:, -1].unsqueeze(-1)
if input_ids.size(0) == 1: if input_ids.size(0) == 1:
attention_mask = None attention_mask = None
else: else:
attention_mask = kwargs.get("attention_mask", None) attention_mask = kwargs.get("attention_mask", None)
if inputs_embeds is not None and past_key_values is None: model_inputs = {"input_ids": input_ids}
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs.update( model_inputs.update(
{ {
"past_key_values": past_key_values,
"attention_mask": attention_mask, "attention_mask": attention_mask,
} }
) )
@ -382,7 +354,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
def forward( def forward(
self, self,
input_ids: Optional[torch.LongTensor] = None, input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None,
@ -390,7 +361,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
) -> Union[Tuple, CausalLMOutputWithPast]: ) -> Union[Tuple, CausalLMOutputWithPast]:
transformer_outputs = self.transformer( transformer_outputs = self.transformer(
input_ids, input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask, attention_mask=attention_mask,
head_mask=head_mask, head_mask=head_mask,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
@ -418,7 +388,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
return CausalLMOutputWithPast( return CausalLMOutputWithPast(
loss=loss, loss=loss,
logits=lm_logits, logits=lm_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states, hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions, attentions=transformer_outputs.attentions,
) )