Refine qwen model.

This commit is contained in:
Colin 2024-01-13 16:50:25 +08:00
parent 9386d044b6
commit 5cf6e8b013
2 changed files with 33 additions and 33 deletions

View File

@ -74,6 +74,16 @@ print(decode_tokens)
# <|im_start|>system
# You are a helpful assistant.<|im_end|>
# <|im_start|>user
# 你好<|im_end|>
# <|im_start|>assistant
# 莎是现代汉语的男性的名字,出自《诗经》中的“采采卷耳
# <|im_start|>system
# You are a helpful assistant.<|im_end|>
# <|im_start|>user

View File

@ -37,11 +37,15 @@ from qwen_generation_utils import (
StopWordsLogitsProcessor,
)
import sys
sys.path.append("..")
from tools import show
logger = logging.get_logger(__name__)
class QWenAttention(nn.Module):
def __init__(self, config):
def __init__(self, config, index):
super().__init__()
self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)
@ -74,6 +78,7 @@ class QWenAttention(nn.Module):
cache_dtype = torch.float
self.cache_qmax = torch.tensor(torch.iinfo(torch.uint8).max, dtype=cache_dtype)
self.cache_qmin = torch.tensor(torch.iinfo(torch.uint8).min, dtype=cache_dtype)
self.index = index
def _split_heads(self, tensor, num_heads, attn_head_size):
new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
@ -139,6 +144,11 @@ class QWenAttention(nn.Module):
else:
attention_mask = causal_mask
# qk = query @ key.transpose(-2, -1)
# qk = qk[0]
# show.DumpTensorToImage(qk,"q_matmul_k_layer_"+str(self.index)+".png")
attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask).transpose(1, 2)
context_layer = self._merge_heads(attn_output, self.num_heads, self.head_dim)
attn_output = self.c_proj(context_layer)
@ -163,7 +173,7 @@ class QWenMLP(nn.Module):
class QWenBlock(nn.Module):
def __init__(self, config):
def __init__(self, config, index):
super().__init__()
hidden_size = config.hidden_size
@ -171,12 +181,13 @@ class QWenBlock(nn.Module):
hidden_size,
eps=config.layer_norm_epsilon,
)
self.attn = QWenAttention(config)
self.attn = QWenAttention(config, index)
self.ln_2 = RMSNorm(
hidden_size,
eps=config.layer_norm_epsilon,
)
self.mlp = QWenMLP(config)
self.index = index
def forward(
self,
@ -240,7 +251,7 @@ class QWenModel(QWenPreTrainedModel):
dim = self.rotary_ndims if self.rotary_ndims is not None else config.kv_channels
self.rotary_emb = RotaryEmbedding(dim, base=config.rotary_emb_base)
self.h = nn.ModuleList([QWenBlock(config) for i in range(config.num_hidden_layers)])
self.h = nn.ModuleList([QWenBlock(config, i) for i in range(config.num_hidden_layers)])
self.ln_f = RMSNorm(
self.embed_dim,
eps=config.layer_norm_epsilon,
@ -460,7 +471,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
inputs: Optional[torch.Tensor] = None,
stop_words_ids = [],
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
streamer: Optional["BaseStreamer"] = None,
**kwargs,
) -> Union[GenerateOutput, torch.LongTensor]:
generation_config = self.generation_config
@ -508,9 +518,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
# 5. Prepare `input_ids` which will be used for auto-regressive generation
input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
if streamer is not None:
streamer.put(input_ids.cpu())
# 6. Prepare `max_length` depending on other stopping criteria.
input_ids_length = input_ids.shape[-1]
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
@ -546,10 +553,8 @@ class QWenLMHeadModel(QWenPreTrainedModel):
# 13. run sample
pad_token_id=generation_config.pad_token_id
eos_token_id=generation_config.eos_token_id
streamer=streamer
eos_token_id_tensor=torch.tensor([generation_config.eos_token_id]).to(input_ids.device)
# init values
stopping_criteria = self._get_stopping_criteria(
generation_config=generation_config, stopping_criteria=StoppingCriteriaList()
@ -557,12 +562,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
logits_warper = self._get_logits_warper(generation_config)
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
# init attention / hidden states / scores tuples
scores = None
@ -588,25 +587,19 @@ class QWenLMHeadModel(QWenPreTrainedModel):
probs = nn.functional.softmax(next_token_scores, dim=-1)
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
# finished sentences should have their next token be a padding token
if eos_token_id is not None:
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
# update generated ids, model inputs, and length for next step
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
if streamer is not None:
streamer.put(next_tokens.cpu())
model_kwargs = self._update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False)
# if eos_token was found in one sentence, set sentence to finished
if eos_token_id_tensor is not None:
unfinished_sequences = unfinished_sequences.mul(
next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
)
unfinished_sequences = unfinished_sequences.mul(
next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
)
# stop when each sentence is finished
if unfinished_sequences.max() == 0:
this_peer_finished = True
# stop when each sentence is finished
if unfinished_sequences.max() == 0:
this_peer_finished = True
# stop if we exceed the maximum length
if stopping_criteria(input_ids, scores):
@ -615,9 +608,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
if this_peer_finished:
break
if streamer is not None:
streamer.end()
return input_ids