Refine the Qwen model for long-sequence eval.
parent 45c2f532ff
commit f96bcc799c

qwen/demo.py | 15
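At a glance, this commit strips the per-layer KV-cache plumbing for evaluation on long sequences: `QWenAttention` and `QWenBlock` now return plain tensors instead of `(output, present)` tuples, `QWenModel` drops the `presents` accumulator, and `chat` runs under `@torch.no_grad()`. Below is a minimal, self-contained sketch of the resulting forward path; `TinyAttention` and `TinyBlock` are hypothetical stand-ins, not the repository's modules.

```python
import torch
from torch import nn
import torch.nn.functional as F


class TinyAttention(nn.Module):
    """Hypothetical attention block: forward() returns a tensor, no `present` cache."""

    def __init__(self, dim: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.qkv = nn.Linear(dim, 3 * dim)
        self.c_proj = nn.Linear(dim, dim)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        b, t, d = hidden_states.shape
        q, k, v = self.qkv(hidden_states).chunk(3, dim=-1)
        # (b, t, d) -> (b, num_heads, t, head_dim)
        q, k, v = (x.reshape(b, t, self.num_heads, self.head_dim).transpose(1, 2) for x in (q, k, v))
        attn = F.scaled_dot_product_attention(q, k, v, is_causal=True).transpose(1, 2)
        return self.c_proj(attn.reshape(b, t, d))  # tensor only, no (output, present) tuple


class TinyBlock(nn.Module):
    """Hypothetical transformer block mirroring the simplified QWenBlock return value."""

    def __init__(self, dim: int, num_heads: int):
        super().__init__()
        self.attn = TinyAttention(dim, num_heads)
        self.mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = hidden_states + self.attn(hidden_states)
        return hidden_states + self.mlp(hidden_states)


# The model loop then reads like the refactored QWenModel: no `presents` tuple is threaded through.
blocks = nn.ModuleList(TinyBlock(64, 4) for _ in range(2))
hidden_states = torch.randn(1, 8, 64)
for block in blocks:
    hidden_states = block(hidden_states)
print(hidden_states.shape)  # torch.Size([1, 8, 64])
```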
@@ -52,11 +52,10 @@ print(model)
 
 
 tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
-model = model.from_pretrained(
-    model_dir, config=config, device_map="auto", trust_remote_code=True
-).train()
-# model.train()
-# model.zero_grad()
+model = model.from_pretrained(model_dir, config=config, device_map="auto", trust_remote_code=True)
+
+# model = model.eval()
+model = model.train()  # controlled by @torch.no_grad()
 
 # 可指定不同的生成长度、top_p等相关超参
 # model.generation_config = GenerationConfig.from_pretrained(
@@ -74,16 +73,14 @@ print(decode_tokens)
 # 日本的首都东京。<|im_end|><|endoftext|>
 
 
-
-
 # # 第一轮对话
 # response, history, decode_tokens = model.chat(tokenizer, "你好", "", history=None)
 # print(decode_tokens)
 # # 你好!很高兴为你提供帮助。
 
 # 第二轮对话
-# response, history = model.chat(tokenizer, "给我讲一个年轻人奋斗创业最终取得成功的故事。", history=None)
-# print(response)
+response, history, decode_tokens = model.chat(tokenizer, "给我讲一个年轻人奋斗创业最终取得成功的故事。", "", history=None)
+print(response)
 
 
 # <|im_start|>system
@@ -41,8 +41,10 @@ import sys
 
 sys.path.append("..")
 from tools import show
+from tools import mem_tracker
 
-logger = logging.get_logger(__name__)
+# tracker = mem_tracker.MemTracker()
+# tracker.track()
 
 
 class QWenAttention(nn.Module):
@@ -110,8 +112,6 @@ class QWenAttention(nn.Module):
         query = apply_rotary_pos_emb(query, q_pos_emb)
         key = apply_rotary_pos_emb(key, k_pos_emb)
 
-        present = (key, value)
-
         key_size = key.size(1)
         if key_size > self.seq_length and not self.training:
             seq_start = key.size(1) - query.size(1)
@@ -148,8 +148,8 @@ class QWenAttention(nn.Module):
         attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask).transpose(1, 2)
         context_layer = self._merge_heads(attn_output, self.num_heads, self.head_dim)
         attn_output = self.c_proj(context_layer)
-        outputs = (attn_output, present)
-        return outputs
+
+        return attn_output
 
 
 class QWenMLP(nn.Module):
@@ -199,7 +199,6 @@ class QWenBlock(nn.Module):
             attention_mask=attention_mask,
         )
         attn_output = attn_outputs[0]
-        outputs = attn_outputs[1:]
         residual = hidden_states
         layernorm_input = attn_output + residual
 
@@ -207,8 +206,7 @@ class QWenBlock(nn.Module):
         residual = layernorm_input
         mlp_output = self.mlp(layernorm_output)
         hidden_states = residual + mlp_output
-        outputs = (hidden_states,) + outputs
-        return outputs
+        return hidden_states
 
 
 class QWenPreTrainedModel(PreTrainedModel):
@@ -312,16 +310,13 @@ class QWenModel(QWenPreTrainedModel):
         hidden_states = self.drop(hidden_states)
         output_shape = input_shape + (hidden_states.size(-1),)
 
-        presents = ()
         all_hidden_states = None
-        for i, block in enumerate(self.h):
-            outputs = block(
+        for block in self.h:
+            hidden_states = block(
                 hidden_states,
                 rotary_pos_emb_list=rotary_pos_emb_list,
                 attention_mask=attention_mask,
             )
-            hidden_states = outputs[0]
-            presents = presents + (outputs[1],)
 
         hidden_states = self.ln_f(hidden_states)
         hidden_states = hidden_states.view(output_shape)
@@ -392,6 +387,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
             attentions=transformer_outputs.attentions,
         )
 
+    @torch.no_grad()
     def chat(
         self,
         tokenizer: PreTrainedTokenizer,
@@ -454,15 +450,9 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         # 2. Set generation parameters if not already defined
 
         if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
-            if model_kwargs.get("attention_mask", None) is None:
-                logger.warning(
-                    "The attention mask and the pad token id were not set. As a consequence, you may observe "
-                    "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
-                )
             eos_token_id = generation_config.eos_token_id
             if isinstance(eos_token_id, list):
                 eos_token_id = eos_token_id[0]
-            logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
             generation_config.pad_token_id = eos_token_id
 
         # 3. Define model inputs
@@ -571,7 +561,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
 
             if this_peer_finished:
                 break
-
         return input_ids
 
 
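Note that the demo now leaves the model in `train()` mode and relies on the `@torch.no_grad()` decorator added to `chat` to disable autograd during inference. A minimal sketch of that pattern follows; `Toy` and its `chat` method are hypothetical, not the repository's classes.

```python
import torch
from torch import nn


class Toy(nn.Module):
    """Hypothetical module: train() mode stays on, but chat() builds no autograd graph."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

    @torch.no_grad()
    def chat(self, x: torch.Tensor) -> torch.Tensor:
        # Runs without gradient tracking even though self.training is True.
        return self.proj(x)


m = Toy().train()
out = m.chat(torch.randn(2, 4))
print(m.training, out.requires_grad)  # True False
```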