Format model of qwen.

This commit is contained in:
Colin 2024-01-13 17:16:43 +08:00
parent 5cf6e8b013
commit d13f7e6c57
1 changed files with 6 additions and 15 deletions

View File

@ -38,6 +38,7 @@ from qwen_generation_utils import (
)
import sys
sys.path.append("..")
from tools import show
@ -144,7 +145,6 @@ class QWenAttention(nn.Module):
else:
attention_mask = causal_mask
# qk = query @ key.transpose(-2, -1)
# qk = qk[0]
# show.DumpTensorToImage(qk,"q_matmul_k_layer_"+str(self.index)+".png")
@ -441,12 +441,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
if max_window_size is None:
max_window_size = generation_config.max_window_size
raw_text, context_tokens = make_context(
tokenizer,
query,
query_assistant,
history=history,
system=system,
max_window_size=max_window_size
tokenizer, query, query_assistant, history=history, system=system, max_window_size=max_window_size
)
stop_words_ids.extend([[tokenizer.im_end_id], [tokenizer.im_start_id]])
@ -469,13 +464,12 @@ class QWenLMHeadModel(QWenPreTrainedModel):
def generate(
self,
inputs: Optional[torch.Tensor] = None,
stop_words_ids = [],
stop_words_ids=[],
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
**kwargs,
) -> Union[GenerateOutput, torch.LongTensor]:
generation_config = self.generation_config
# 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
self._validate_model_class()
@ -486,7 +480,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
# 2. Set generation parameters if not already defined
if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
if model_kwargs.get("attention_mask", None) is None:
logger.warning(
@ -524,7 +517,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
generation_config.max_length = generation_config.max_new_tokens + input_ids_length
self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
stop_words_logits_processor = StopWordsLogitsProcessor(
stop_words_ids=stop_words_ids,
eos_token_id=generation_config.eos_token_id,
@ -541,7 +533,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
negative_prompt_attention_mask=None,
)
# 12. expand input_ids with `num_return_sequences` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
input_ids=input_ids,
@ -552,9 +543,9 @@ class QWenLMHeadModel(QWenPreTrainedModel):
# 13. run sample
pad_token_id=generation_config.pad_token_id
eos_token_id_tensor=torch.tensor([generation_config.eos_token_id]).to(input_ids.device)
pad_token_id = generation_config.pad_token_id
eos_token_id_tensor = torch.tensor([generation_config.eos_token_id]).to(input_ids.device)
# init values
stopping_criteria = self._get_stopping_criteria(
generation_config=generation_config, stopping_criteria=StoppingCriteriaList()