diff --git a/qwen/modeling_qwen.py b/qwen/modeling_qwen.py
index 631f4fb..44f8806 100644
--- a/qwen/modeling_qwen.py
+++ b/qwen/modeling_qwen.py
@@ -75,7 +75,6 @@ class QWenAttention(nn.Module):
         self.attn_dropout = nn.Dropout(config.attn_dropout_prob)
         self.softmax_in_fp32 = config.softmax_in_fp32 if hasattr(config, "softmax_in_fp32") else False
-        self.use_cache_kernel = config.use_cache_kernel if hasattr(config, "use_cache_kernel") else False
         cache_dtype = torch.float
         self.cache_qmax = torch.tensor(torch.iinfo(torch.uint8).max, dtype=cache_dtype)
         self.cache_qmin = torch.tensor(torch.iinfo(torch.uint8).min, dtype=cache_dtype)
@@ -95,7 +94,6 @@ class QWenAttention(nn.Module):
         self,
         hidden_states: Optional[Tuple[torch.FloatTensor]],
         rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
-        layer_past: Optional[Tuple[torch.Tensor]] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
     ):
         mixed_x_layer = self.c_attn(hidden_states)
@@ -112,12 +110,6 @@ class QWenAttention(nn.Module):
             query = apply_rotary_pos_emb(query, q_pos_emb)
             key = apply_rotary_pos_emb(key, k_pos_emb)

-        if layer_past is not None:
-            past_key, past_value = layer_past[0], layer_past[1]
-
-            key = torch.cat((past_key, key), dim=1)
-            value = torch.cat((past_value, value), dim=1)
-
         present = (key, value)

         key_size = key.size(1)
@@ -197,7 +189,6 @@ class QWenBlock(nn.Module):
         self,
         hidden_states: Optional[Tuple[torch.FloatTensor]],
         rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
-        layer_past: Optional[Tuple[torch.Tensor]] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
     ):
         layernorm_output = self.ln_1(hidden_states)
@@ -205,7 +196,6 @@ class QWenBlock(nn.Module):
         attn_outputs = self.attn(
             layernorm_output,
             rotary_pos_emb_list,
-            layer_past=layer_past,
             attention_mask=attention_mask,
         )
         attn_output = attn_outputs[0]
@@ -227,7 +217,6 @@ class QWenPreTrainedModel(PreTrainedModel):
     is_parallelizable = False
     supports_gradient_checkpointing = True
     _no_split_modules = ["QWenBlock"]
-    _skip_keys_device_placement = "past_key_values"

     def __init__(self, *inputs, **kwargs):
         super().__init__(*inputs, **kwargs)
@@ -272,7 +261,6 @@ class QWenModel(QWenPreTrainedModel):
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         head_mask: Optional[torch.FloatTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
@@ -289,9 +277,6 @@ class QWenModel(QWenPreTrainedModel):
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")

-        if past_key_values is None:
-            past_key_values = tuple([None] * len(self.h))
-
         if attention_mask is not None:
             attention_mask = attention_mask.view(batch_size, -1)
             attention_mask = attention_mask[:, None, None, :]
@@ -305,9 +290,6 @@ class QWenModel(QWenPreTrainedModel):
             hidden_states = inputs_embeds

         kv_seq_len = hidden_states.size()[1]
-        if past_key_values[0] is not None:
-            # past key values[0][0] shape: bs * seq_len * head_num * dim
-            kv_seq_len += past_key_values[0][0].shape[1]

         if self.training or not self.use_dynamic_ntk:
             ntk_alpha_list = [1.0]
@@ -332,10 +314,9 @@ class QWenModel(QWenPreTrainedModel):
         presents = ()
         all_hidden_states = None

-        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
+        for i, block in enumerate(self.h):
             outputs = block(
                 hidden_states,
-                layer_past=layer_past,
                 rotary_pos_emb_list=rotary_pos_emb_list,
                 attention_mask=attention_mask,
             )
@@ -344,9 +325,7 @@ class QWenModel(QWenPreTrainedModel):
         hidden_states = self.ln_f(hidden_states)
         hidden_states = hidden_states.view(output_shape)

-        return BaseModelOutputWithPast(
-            last_hidden_state=hidden_states, past_key_values=presents, hidden_states=all_hidden_states
-        )
+        return BaseModelOutputWithPast(last_hidden_state=hidden_states, hidden_states=all_hidden_states)


 class QWenLMHeadModel(QWenPreTrainedModel):
@@ -357,23 +336,16 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
         self.post_init()

-    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
-        if past_key_values:
-            input_ids = input_ids[:, -1].unsqueeze(-1)
-
+    def prepare_inputs_for_generation(self, input_ids, inputs_embeds=None, **kwargs):
         if input_ids.size(0) == 1:
             attention_mask = None
         else:
             attention_mask = kwargs.get("attention_mask", None)

-        if inputs_embeds is not None and past_key_values is None:
-            model_inputs = {"inputs_embeds": inputs_embeds}
-        else:
-            model_inputs = {"input_ids": input_ids}
+        model_inputs = {"input_ids": input_ids}

         model_inputs.update(
             {
-                "past_key_values": past_key_values,
                 "attention_mask": attention_mask,
             }
         )
@@ -382,7 +354,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         head_mask: Optional[torch.FloatTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
@@ -390,7 +361,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         transformer_outputs = self.transformer(
             input_ids,
-            past_key_values=past_key_values,
             attention_mask=attention_mask,
             head_mask=head_mask,
             inputs_embeds=inputs_embeds,
@@ -418,7 +388,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         return CausalLMOutputWithPast(
             loss=loss,
             logits=lm_logits,
-            past_key_values=transformer_outputs.past_key_values,
             hidden_states=transformer_outputs.hidden_states,
             attentions=transformer_outputs.attentions,
         )
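The patch above removes all key/value caching from the model, so each decoding step must re-run the entire prefix through the network. The following is a minimal sketch, not part of the patch, of how the cache-free QWenLMHeadModel can still be driven; the helper name greedy_generate and the parameters max_new_tokens and eos_token_id are illustrative assumptions, while the call shape (forward taking input_ids and returning CausalLMOutputWithPast with logits) follows the modified code.

import torch

@torch.no_grad()
def greedy_generate(model, input_ids, max_new_tokens=32, eos_token_id=None):
    # Without past_key_values, the full sequence is fed to the model on every
    # step, so per-token cost grows with the current sequence length.
    for _ in range(max_new_tokens):
        logits = model(input_ids=input_ids).logits          # CausalLMOutputWithPast.logits
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_token], dim=-1)
        if eos_token_id is not None and (next_token == eos_token_id).all():
            break
    return input_ids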