Refine qwen model.
This commit is contained in:
		
							parent
							
								
									9386d044b6
								
							
						
					
					
						commit
						5cf6e8b013
					
				
							
								
								
									
										10
									
								
								qwen/demo.py
								
								
								
								
							
							
						
						
									
										10
									
								
								qwen/demo.py
								
								
								
								
							| 
						 | 
				
			
			@ -74,6 +74,16 @@ print(decode_tokens)
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# <|im_start|>system
 | 
			
		||||
# You are a helpful assistant.<|im_end|>
 | 
			
		||||
# <|im_start|>user
 | 
			
		||||
# 你好<|im_end|>
 | 
			
		||||
# <|im_start|>assistant
 | 
			
		||||
# 莎是现代汉语的男性的名字,出自《诗经》中的“采采卷耳
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# <|im_start|>system
 | 
			
		||||
# You are a helpful assistant.<|im_end|>
 | 
			
		||||
# <|im_start|>user
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -37,11 +37,15 @@ from qwen_generation_utils import (
 | 
			
		|||
    StopWordsLogitsProcessor,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
import sys
 | 
			
		||||
sys.path.append("..")
 | 
			
		||||
from tools import show
 | 
			
		||||
 | 
			
		||||
logger = logging.get_logger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class QWenAttention(nn.Module):
 | 
			
		||||
    def __init__(self, config):
 | 
			
		||||
    def __init__(self, config, index):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
 | 
			
		||||
        self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)
 | 
			
		||||
| 
						 | 
				
			
			@ -74,6 +78,7 @@ class QWenAttention(nn.Module):
 | 
			
		|||
        cache_dtype = torch.float
 | 
			
		||||
        self.cache_qmax = torch.tensor(torch.iinfo(torch.uint8).max, dtype=cache_dtype)
 | 
			
		||||
        self.cache_qmin = torch.tensor(torch.iinfo(torch.uint8).min, dtype=cache_dtype)
 | 
			
		||||
        self.index = index
 | 
			
		||||
 | 
			
		||||
    def _split_heads(self, tensor, num_heads, attn_head_size):
 | 
			
		||||
        new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
 | 
			
		||||
| 
						 | 
				
			
			@ -139,6 +144,11 @@ class QWenAttention(nn.Module):
 | 
			
		|||
        else:
 | 
			
		||||
            attention_mask = causal_mask
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
        # qk = query @ key.transpose(-2, -1)
 | 
			
		||||
        # qk = qk[0]
 | 
			
		||||
        # show.DumpTensorToImage(qk,"q_matmul_k_layer_"+str(self.index)+".png")
 | 
			
		||||
 | 
			
		||||
        attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask).transpose(1, 2)
 | 
			
		||||
        context_layer = self._merge_heads(attn_output, self.num_heads, self.head_dim)
 | 
			
		||||
        attn_output = self.c_proj(context_layer)
 | 
			
		||||
| 
						 | 
				
			
			@ -163,7 +173,7 @@ class QWenMLP(nn.Module):
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
class QWenBlock(nn.Module):
 | 
			
		||||
    def __init__(self, config):
 | 
			
		||||
    def __init__(self, config, index):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        hidden_size = config.hidden_size
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -171,12 +181,13 @@ class QWenBlock(nn.Module):
 | 
			
		|||
            hidden_size,
 | 
			
		||||
            eps=config.layer_norm_epsilon,
 | 
			
		||||
        )
 | 
			
		||||
        self.attn = QWenAttention(config)
 | 
			
		||||
        self.attn = QWenAttention(config, index)
 | 
			
		||||
        self.ln_2 = RMSNorm(
 | 
			
		||||
            hidden_size,
 | 
			
		||||
            eps=config.layer_norm_epsilon,
 | 
			
		||||
        )
 | 
			
		||||
        self.mlp = QWenMLP(config)
 | 
			
		||||
        self.index = index
 | 
			
		||||
 | 
			
		||||
    def forward(
 | 
			
		||||
        self,
 | 
			
		||||
| 
						 | 
				
			
			@ -240,7 +251,7 @@ class QWenModel(QWenPreTrainedModel):
 | 
			
		|||
        dim = self.rotary_ndims if self.rotary_ndims is not None else config.kv_channels
 | 
			
		||||
        self.rotary_emb = RotaryEmbedding(dim, base=config.rotary_emb_base)
 | 
			
		||||
 | 
			
		||||
        self.h = nn.ModuleList([QWenBlock(config) for i in range(config.num_hidden_layers)])
 | 
			
		||||
        self.h = nn.ModuleList([QWenBlock(config, i) for i in range(config.num_hidden_layers)])
 | 
			
		||||
        self.ln_f = RMSNorm(
 | 
			
		||||
            self.embed_dim,
 | 
			
		||||
            eps=config.layer_norm_epsilon,
 | 
			
		||||
| 
						 | 
				
			
			@ -460,7 +471,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
 | 
			
		|||
        inputs: Optional[torch.Tensor] = None,
 | 
			
		||||
        stop_words_ids = [],
 | 
			
		||||
        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
 | 
			
		||||
        streamer: Optional["BaseStreamer"] = None,
 | 
			
		||||
        **kwargs,
 | 
			
		||||
    ) -> Union[GenerateOutput, torch.LongTensor]:
 | 
			
		||||
        generation_config = self.generation_config
 | 
			
		||||
| 
						 | 
				
			
			@ -508,9 +518,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
 | 
			
		|||
        # 5. Prepare `input_ids` which will be used for auto-regressive generation
 | 
			
		||||
        input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
 | 
			
		||||
 | 
			
		||||
        if streamer is not None:
 | 
			
		||||
            streamer.put(input_ids.cpu())
 | 
			
		||||
 | 
			
		||||
        # 6. Prepare `max_length` depending on other stopping criteria.
 | 
			
		||||
        input_ids_length = input_ids.shape[-1]
 | 
			
		||||
        has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
 | 
			
		||||
| 
						 | 
				
			
			@ -546,10 +553,8 @@ class QWenLMHeadModel(QWenPreTrainedModel):
 | 
			
		|||
        # 13. run sample
 | 
			
		||||
 | 
			
		||||
        pad_token_id=generation_config.pad_token_id
 | 
			
		||||
        eos_token_id=generation_config.eos_token_id
 | 
			
		||||
        streamer=streamer
 | 
			
		||||
        eos_token_id_tensor=torch.tensor([generation_config.eos_token_id]).to(input_ids.device)
 | 
			
		||||
        
 | 
			
		||||
 | 
			
		||||
        # init values
 | 
			
		||||
        stopping_criteria = self._get_stopping_criteria(
 | 
			
		||||
            generation_config=generation_config, stopping_criteria=StoppingCriteriaList()
 | 
			
		||||
| 
						 | 
				
			
			@ -557,12 +562,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
 | 
			
		|||
 | 
			
		||||
        logits_warper = self._get_logits_warper(generation_config)
 | 
			
		||||
 | 
			
		||||
        pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
 | 
			
		||||
        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
 | 
			
		||||
        if isinstance(eos_token_id, int):
 | 
			
		||||
            eos_token_id = [eos_token_id]
 | 
			
		||||
        eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
 | 
			
		||||
 | 
			
		||||
        # init attention / hidden states / scores tuples
 | 
			
		||||
        scores = None
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -588,25 +587,19 @@ class QWenLMHeadModel(QWenPreTrainedModel):
 | 
			
		|||
            probs = nn.functional.softmax(next_token_scores, dim=-1)
 | 
			
		||||
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
 | 
			
		||||
 | 
			
		||||
            # finished sentences should have their next token be a padding token
 | 
			
		||||
            if eos_token_id is not None:
 | 
			
		||||
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
 | 
			
		||||
            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
 | 
			
		||||
 | 
			
		||||
            # update generated ids, model inputs, and length for next step
 | 
			
		||||
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
 | 
			
		||||
            if streamer is not None:
 | 
			
		||||
                streamer.put(next_tokens.cpu())
 | 
			
		||||
            model_kwargs = self._update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False)
 | 
			
		||||
 | 
			
		||||
            # if eos_token was found in one sentence, set sentence to finished
 | 
			
		||||
            if eos_token_id_tensor is not None:
 | 
			
		||||
                unfinished_sequences = unfinished_sequences.mul(
 | 
			
		||||
                    next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
 | 
			
		||||
                )
 | 
			
		||||
            unfinished_sequences = unfinished_sequences.mul(
 | 
			
		||||
                next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
                # stop when each sentence is finished
 | 
			
		||||
                if unfinished_sequences.max() == 0:
 | 
			
		||||
                    this_peer_finished = True
 | 
			
		||||
            # stop when each sentence is finished
 | 
			
		||||
            if unfinished_sequences.max() == 0:
 | 
			
		||||
                this_peer_finished = True
 | 
			
		||||
 | 
			
		||||
            # stop if we exceed the maximum length
 | 
			
		||||
            if stopping_criteria(input_ids, scores):
 | 
			
		||||
| 
						 | 
				
			
			@ -615,9 +608,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
 | 
			
		|||
            if this_peer_finished:
 | 
			
		||||
                break
 | 
			
		||||
 | 
			
		||||
        if streamer is not None:
 | 
			
		||||
            streamer.end()
 | 
			
		||||
 | 
			
		||||
        return input_ids
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue