Delete kv cache of qwen.

parent 0a78627e48
commit 3233616aac
@@ -75,7 +75,6 @@ class QWenAttention(nn.Module):
 
         self.attn_dropout = nn.Dropout(config.attn_dropout_prob)
         self.softmax_in_fp32 = config.softmax_in_fp32 if hasattr(config, "softmax_in_fp32") else False
-        self.use_cache_kernel = config.use_cache_kernel if hasattr(config, "use_cache_kernel") else False
         cache_dtype = torch.float
         self.cache_qmax = torch.tensor(torch.iinfo(torch.uint8).max, dtype=cache_dtype)
         self.cache_qmin = torch.tensor(torch.iinfo(torch.uint8).min, dtype=cache_dtype)
@@ -95,7 +94,6 @@ class QWenAttention(nn.Module):
         self,
         hidden_states: Optional[Tuple[torch.FloatTensor]],
         rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
-        layer_past: Optional[Tuple[torch.Tensor]] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
     ):
         mixed_x_layer = self.c_attn(hidden_states)
@@ -112,12 +110,6 @@ class QWenAttention(nn.Module):
         query = apply_rotary_pos_emb(query, q_pos_emb)
         key = apply_rotary_pos_emb(key, k_pos_emb)
 
-        if layer_past is not None:
-            past_key, past_value = layer_past[0], layer_past[1]
-
-            key = torch.cat((past_key, key), dim=1)
-            value = torch.cat((past_value, value), dim=1)
 
-        present = (key, value)
 
         key_size = key.size(1)
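
The block removed above is the standard per-layer KV-cache append: keys and values from earlier decoding steps are concatenated with the new ones along the sequence dimension, and the pair is handed back as present for reuse at the next step. A minimal standalone sketch of that pattern; the function name and tensor shapes below are illustrative, not part of the model:

import torch

def append_kv_cache(key, value, layer_past=None):
    # key/value hold only the new tokens, laid out as
    # (batch, seq_len, num_heads, head_dim), matching the
    # "bs * seq_len * head_num * dim" layout noted later in the diff.
    if layer_past is not None:
        past_key, past_value = layer_past[0], layer_past[1]
        # Append along the sequence dimension so attention for the new tokens
        # sees every previously processed position without recomputing it.
        key = torch.cat((past_key, key), dim=1)
        value = torch.cat((past_value, value), dim=1)
    present = (key, value)  # handed back to the caller for the next step
    return key, value, present

k = v = torch.zeros(1, 1, 16, 128)                                   # one new token
cache = (torch.zeros(1, 28, 16, 128), torch.zeros(1, 28, 16, 128))   # 28 cached positions
key, value, present = append_kv_cache(k, v, cache)                   # key.shape[1] == 29
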
@@ -197,7 +189,6 @@ class QWenBlock(nn.Module):
         self,
         hidden_states: Optional[Tuple[torch.FloatTensor]],
         rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
-        layer_past: Optional[Tuple[torch.Tensor]] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
     ):
         layernorm_output = self.ln_1(hidden_states)
@@ -205,7 +196,6 @@ class QWenBlock(nn.Module):
         attn_outputs = self.attn(
             layernorm_output,
             rotary_pos_emb_list,
-            layer_past=layer_past,
             attention_mask=attention_mask,
         )
         attn_output = attn_outputs[0]
@@ -227,7 +217,6 @@ class QWenPreTrainedModel(PreTrainedModel):
     is_parallelizable = False
     supports_gradient_checkpointing = True
     _no_split_modules = ["QWenBlock"]
-    _skip_keys_device_placement = "past_key_values"
 
     def __init__(self, *inputs, **kwargs):
         super().__init__(*inputs, **kwargs)
@@ -272,7 +261,6 @@ class QWenModel(QWenPreTrainedModel):
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         head_mask: Optional[torch.FloatTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
@@ -289,9 +277,6 @@ class QWenModel(QWenPreTrainedModel):
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")
 
-        if past_key_values is None:
-            past_key_values = tuple([None] * len(self.h))
-
         if attention_mask is not None:
             attention_mask = attention_mask.view(batch_size, -1)
             attention_mask = attention_mask[:, None, None, :]
@@ -305,9 +290,6 @@ class QWenModel(QWenPreTrainedModel):
         hidden_states = inputs_embeds
 
         kv_seq_len = hidden_states.size()[1]
-        if past_key_values[0] is not None:
-            # past key values[0][0] shape: bs * seq_len * head_num * dim
-            kv_seq_len += past_key_values[0][0].shape[1]
 
         if self.training or not self.use_dynamic_ntk:
             ntk_alpha_list = [1.0]
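
The lines removed in this hunk did the cache's length bookkeeping: kv_seq_len started from the new tokens in hidden_states and was extended by the cached prefix length, which the dynamic-NTK branch visible just below appears to use when choosing its scaling. A small standalone sketch with made-up sizes, following the layout in the removed comment:

import torch

hidden_states = torch.zeros(1, 4, 2048)                   # 4 new tokens this step
past_key_values = ((torch.zeros(1, 28, 16, 128),          # 28 cached key positions
                    torch.zeros(1, 28, 16, 128)),)        # 28 cached value positions

kv_seq_len = hidden_states.size()[1]                      # 4
if past_key_values[0] is not None:
    kv_seq_len += past_key_values[0][0].shape[1]          # 4 + 28 = 32 total positions
# Without a cache, kv_seq_len is simply the length of the current input.
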
@@ -332,10 +314,9 @@ class QWenModel(QWenPreTrainedModel):
 
         presents = ()
         all_hidden_states = None
-        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
+        for i, block in enumerate(self.h):
             outputs = block(
                 hidden_states,
-                layer_past=layer_past,
                 rotary_pos_emb_list=rotary_pos_emb_list,
                 attention_mask=attention_mask,
             )
@@ -344,9 +325,7 @@ class QWenModel(QWenPreTrainedModel):
 
         hidden_states = self.ln_f(hidden_states)
         hidden_states = hidden_states.view(output_shape)
-        return BaseModelOutputWithPast(
-            last_hidden_state=hidden_states, past_key_values=presents, hidden_states=all_hidden_states
-        )
+        return BaseModelOutputWithPast(last_hidden_state=hidden_states, hidden_states=all_hidden_states)
 
 
 class QWenLMHeadModel(QWenPreTrainedModel):
@@ -357,23 +336,16 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
         self.post_init()
 
-    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
-        if past_key_values:
-            input_ids = input_ids[:, -1].unsqueeze(-1)
-
+    def prepare_inputs_for_generation(self, input_ids, inputs_embeds=None, **kwargs):
         if input_ids.size(0) == 1:
             attention_mask = None
         else:
             attention_mask = kwargs.get("attention_mask", None)
 
-        if inputs_embeds is not None and past_key_values is None:
-            model_inputs = {"inputs_embeds": inputs_embeds}
-        else:
-            model_inputs = {"input_ids": input_ids}
+        model_inputs = {"input_ids": input_ids}
 
         model_inputs.update(
             {
-                "past_key_values": past_key_values,
                 "attention_mask": attention_mask,
             }
         )
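
With the cache gone, prepare_inputs_for_generation no longer trims input_ids down to the newest token, so every decoding step re-feeds the whole sequence produced so far. A hypothetical illustration of the difference; the tensor and its token ids are made up:

import torch

input_ids = torch.tensor([[101, 17, 42, 8]])          # prompt + tokens generated so far

# Old behaviour: once past_key_values existed, only the newest token was passed,
# because keys/values for everything earlier were already cached.
last_token_only = input_ids[:, -1].unsqueeze(-1)      # shape (1, 1)

# New behaviour: the full sequence is passed on every step, and attention
# recomputes keys/values for all of it each time.
full_sequence = input_ids                             # shape (1, 4)
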
@@ -382,7 +354,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         head_mask: Optional[torch.FloatTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
@@ -390,7 +361,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         transformer_outputs = self.transformer(
             input_ids,
-            past_key_values=past_key_values,
             attention_mask=attention_mask,
             head_mask=head_mask,
             inputs_embeds=inputs_embeds,
@@ -418,7 +388,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         return CausalLMOutputWithPast(
             loss=loss,
             logits=lm_logits,
-            past_key_values=transformer_outputs.past_key_values,
             hidden_states=transformer_outputs.hidden_states,
             attentions=transformer_outputs.attentions,
         )
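
For a rough sense of what the deletion costs at generation time (my own back-of-the-envelope estimate, not something measured in this commit): with a cache, each step processes one new token and reuses cached keys/values; without it, each step re-processes the prompt plus everything generated so far, so the total number of token positions pushed through the model grows roughly quadratically in the number of generated tokens.

# Hypothetical, illustrative-only comparison of positions processed while
# generating new_tokens continuations after a prompt of length prompt_len.
def positions_with_cache(prompt_len, new_tokens):
    # prompt encoded once, then a single new token per step (cache supplies the rest)
    return prompt_len + new_tokens

def positions_without_cache(prompt_len, new_tokens):
    # step t re-feeds the prompt plus the t tokens generated so far
    return sum(prompt_len + t for t in range(1, new_tokens + 1))

print(positions_with_cache(32, 100))      # 132
print(positions_without_cache(32, 100))   # 8250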