Delete kv cache of qwen.
This commit is contained in:
parent
0a78627e48
commit
3233616aac
qwen
|
@ -75,7 +75,6 @@ class QWenAttention(nn.Module):
|
||||||
|
|
||||||
self.attn_dropout = nn.Dropout(config.attn_dropout_prob)
|
self.attn_dropout = nn.Dropout(config.attn_dropout_prob)
|
||||||
self.softmax_in_fp32 = config.softmax_in_fp32 if hasattr(config, "softmax_in_fp32") else False
|
self.softmax_in_fp32 = config.softmax_in_fp32 if hasattr(config, "softmax_in_fp32") else False
|
||||||
self.use_cache_kernel = config.use_cache_kernel if hasattr(config, "use_cache_kernel") else False
|
|
||||||
cache_dtype = torch.float
|
cache_dtype = torch.float
|
||||||
self.cache_qmax = torch.tensor(torch.iinfo(torch.uint8).max, dtype=cache_dtype)
|
self.cache_qmax = torch.tensor(torch.iinfo(torch.uint8).max, dtype=cache_dtype)
|
||||||
self.cache_qmin = torch.tensor(torch.iinfo(torch.uint8).min, dtype=cache_dtype)
|
self.cache_qmin = torch.tensor(torch.iinfo(torch.uint8).min, dtype=cache_dtype)
|
||||||
|
@ -95,7 +94,6 @@ class QWenAttention(nn.Module):
|
||||||
self,
|
self,
|
||||||
hidden_states: Optional[Tuple[torch.FloatTensor]],
|
hidden_states: Optional[Tuple[torch.FloatTensor]],
|
||||||
rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
|
rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
|
||||||
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
|
||||||
attention_mask: Optional[torch.FloatTensor] = None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
):
|
):
|
||||||
mixed_x_layer = self.c_attn(hidden_states)
|
mixed_x_layer = self.c_attn(hidden_states)
|
||||||
|
@ -112,12 +110,6 @@ class QWenAttention(nn.Module):
|
||||||
query = apply_rotary_pos_emb(query, q_pos_emb)
|
query = apply_rotary_pos_emb(query, q_pos_emb)
|
||||||
key = apply_rotary_pos_emb(key, k_pos_emb)
|
key = apply_rotary_pos_emb(key, k_pos_emb)
|
||||||
|
|
||||||
if layer_past is not None:
|
|
||||||
past_key, past_value = layer_past[0], layer_past[1]
|
|
||||||
|
|
||||||
key = torch.cat((past_key, key), dim=1)
|
|
||||||
value = torch.cat((past_value, value), dim=1)
|
|
||||||
|
|
||||||
present = (key, value)
|
present = (key, value)
|
||||||
|
|
||||||
key_size = key.size(1)
|
key_size = key.size(1)
|
||||||
|
@ -197,7 +189,6 @@ class QWenBlock(nn.Module):
|
||||||
self,
|
self,
|
||||||
hidden_states: Optional[Tuple[torch.FloatTensor]],
|
hidden_states: Optional[Tuple[torch.FloatTensor]],
|
||||||
rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
|
rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
|
||||||
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
|
||||||
attention_mask: Optional[torch.FloatTensor] = None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
):
|
):
|
||||||
layernorm_output = self.ln_1(hidden_states)
|
layernorm_output = self.ln_1(hidden_states)
|
||||||
|
@ -205,7 +196,6 @@ class QWenBlock(nn.Module):
|
||||||
attn_outputs = self.attn(
|
attn_outputs = self.attn(
|
||||||
layernorm_output,
|
layernorm_output,
|
||||||
rotary_pos_emb_list,
|
rotary_pos_emb_list,
|
||||||
layer_past=layer_past,
|
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
)
|
)
|
||||||
attn_output = attn_outputs[0]
|
attn_output = attn_outputs[0]
|
||||||
|
@ -227,7 +217,6 @@ class QWenPreTrainedModel(PreTrainedModel):
|
||||||
is_parallelizable = False
|
is_parallelizable = False
|
||||||
supports_gradient_checkpointing = True
|
supports_gradient_checkpointing = True
|
||||||
_no_split_modules = ["QWenBlock"]
|
_no_split_modules = ["QWenBlock"]
|
||||||
_skip_keys_device_placement = "past_key_values"
|
|
||||||
|
|
||||||
def __init__(self, *inputs, **kwargs):
|
def __init__(self, *inputs, **kwargs):
|
||||||
super().__init__(*inputs, **kwargs)
|
super().__init__(*inputs, **kwargs)
|
||||||
|
@ -272,7 +261,6 @@ class QWenModel(QWenPreTrainedModel):
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: Optional[torch.LongTensor] = None,
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
|
||||||
attention_mask: Optional[torch.FloatTensor] = None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
head_mask: Optional[torch.FloatTensor] = None,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
@ -289,9 +277,6 @@ class QWenModel(QWenPreTrainedModel):
|
||||||
else:
|
else:
|
||||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||||
|
|
||||||
if past_key_values is None:
|
|
||||||
past_key_values = tuple([None] * len(self.h))
|
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
attention_mask = attention_mask.view(batch_size, -1)
|
attention_mask = attention_mask.view(batch_size, -1)
|
||||||
attention_mask = attention_mask[:, None, None, :]
|
attention_mask = attention_mask[:, None, None, :]
|
||||||
|
@ -305,9 +290,6 @@ class QWenModel(QWenPreTrainedModel):
|
||||||
hidden_states = inputs_embeds
|
hidden_states = inputs_embeds
|
||||||
|
|
||||||
kv_seq_len = hidden_states.size()[1]
|
kv_seq_len = hidden_states.size()[1]
|
||||||
if past_key_values[0] is not None:
|
|
||||||
# past key values[0][0] shape: bs * seq_len * head_num * dim
|
|
||||||
kv_seq_len += past_key_values[0][0].shape[1]
|
|
||||||
|
|
||||||
if self.training or not self.use_dynamic_ntk:
|
if self.training or not self.use_dynamic_ntk:
|
||||||
ntk_alpha_list = [1.0]
|
ntk_alpha_list = [1.0]
|
||||||
|
@ -332,10 +314,9 @@ class QWenModel(QWenPreTrainedModel):
|
||||||
|
|
||||||
presents = ()
|
presents = ()
|
||||||
all_hidden_states = None
|
all_hidden_states = None
|
||||||
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
|
for i, block in enumerate(self.h):
|
||||||
outputs = block(
|
outputs = block(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
layer_past=layer_past,
|
|
||||||
rotary_pos_emb_list=rotary_pos_emb_list,
|
rotary_pos_emb_list=rotary_pos_emb_list,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
)
|
)
|
||||||
|
@ -344,9 +325,7 @@ class QWenModel(QWenPreTrainedModel):
|
||||||
|
|
||||||
hidden_states = self.ln_f(hidden_states)
|
hidden_states = self.ln_f(hidden_states)
|
||||||
hidden_states = hidden_states.view(output_shape)
|
hidden_states = hidden_states.view(output_shape)
|
||||||
return BaseModelOutputWithPast(
|
return BaseModelOutputWithPast(last_hidden_state=hidden_states, hidden_states=all_hidden_states)
|
||||||
last_hidden_state=hidden_states, past_key_values=presents, hidden_states=all_hidden_states
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class QWenLMHeadModel(QWenPreTrainedModel):
|
class QWenLMHeadModel(QWenPreTrainedModel):
|
||||||
|
@ -357,23 +336,16 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
||||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||||
self.post_init()
|
self.post_init()
|
||||||
|
|
||||||
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
|
def prepare_inputs_for_generation(self, input_ids, inputs_embeds=None, **kwargs):
|
||||||
if past_key_values:
|
|
||||||
input_ids = input_ids[:, -1].unsqueeze(-1)
|
|
||||||
|
|
||||||
if input_ids.size(0) == 1:
|
if input_ids.size(0) == 1:
|
||||||
attention_mask = None
|
attention_mask = None
|
||||||
else:
|
else:
|
||||||
attention_mask = kwargs.get("attention_mask", None)
|
attention_mask = kwargs.get("attention_mask", None)
|
||||||
|
|
||||||
if inputs_embeds is not None and past_key_values is None:
|
model_inputs = {"input_ids": input_ids}
|
||||||
model_inputs = {"inputs_embeds": inputs_embeds}
|
|
||||||
else:
|
|
||||||
model_inputs = {"input_ids": input_ids}
|
|
||||||
|
|
||||||
model_inputs.update(
|
model_inputs.update(
|
||||||
{
|
{
|
||||||
"past_key_values": past_key_values,
|
|
||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
@ -382,7 +354,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: Optional[torch.LongTensor] = None,
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
|
||||||
attention_mask: Optional[torch.FloatTensor] = None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
head_mask: Optional[torch.FloatTensor] = None,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
@ -390,7 +361,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
||||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||||
transformer_outputs = self.transformer(
|
transformer_outputs = self.transformer(
|
||||||
input_ids,
|
input_ids,
|
||||||
past_key_values=past_key_values,
|
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
|
@ -418,7 +388,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
||||||
return CausalLMOutputWithPast(
|
return CausalLMOutputWithPast(
|
||||||
loss=loss,
|
loss=loss,
|
||||||
logits=lm_logits,
|
logits=lm_logits,
|
||||||
past_key_values=transformer_outputs.past_key_values,
|
|
||||||
hidden_states=transformer_outputs.hidden_states,
|
hidden_states=transformer_outputs.hidden_states,
|
||||||
attentions=transformer_outputs.attentions,
|
attentions=transformer_outputs.attentions,
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in New Issue