Format model of qwen.
This commit is contained in:
parent
5cf6e8b013
commit
d13f7e6c57
|
@ -38,6 +38,7 @@ from qwen_generation_utils import (
|
|||
)
|
||||
|
||||
import sys
|
||||
|
||||
sys.path.append("..")
|
||||
from tools import show
|
||||
|
||||
|
@ -144,7 +145,6 @@ class QWenAttention(nn.Module):
|
|||
else:
|
||||
attention_mask = causal_mask
|
||||
|
||||
|
||||
# qk = query @ key.transpose(-2, -1)
|
||||
# qk = qk[0]
|
||||
# show.DumpTensorToImage(qk,"q_matmul_k_layer_"+str(self.index)+".png")
|
||||
|
@ -441,12 +441,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
|||
if max_window_size is None:
|
||||
max_window_size = generation_config.max_window_size
|
||||
raw_text, context_tokens = make_context(
|
||||
tokenizer,
|
||||
query,
|
||||
query_assistant,
|
||||
history=history,
|
||||
system=system,
|
||||
max_window_size=max_window_size
|
||||
tokenizer, query, query_assistant, history=history, system=system, max_window_size=max_window_size
|
||||
)
|
||||
|
||||
stop_words_ids.extend([[tokenizer.im_end_id], [tokenizer.im_start_id]])
|
||||
|
@ -469,13 +464,12 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
|||
def generate(
|
||||
self,
|
||||
inputs: Optional[torch.Tensor] = None,
|
||||
stop_words_ids = [],
|
||||
stop_words_ids=[],
|
||||
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
|
||||
**kwargs,
|
||||
) -> Union[GenerateOutput, torch.LongTensor]:
|
||||
generation_config = self.generation_config
|
||||
|
||||
|
||||
# 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
|
||||
self._validate_model_class()
|
||||
|
||||
|
@ -486,7 +480,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
|||
|
||||
# 2. Set generation parameters if not already defined
|
||||
|
||||
|
||||
if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
|
||||
if model_kwargs.get("attention_mask", None) is None:
|
||||
logger.warning(
|
||||
|
@ -524,7 +517,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
|||
generation_config.max_length = generation_config.max_new_tokens + input_ids_length
|
||||
self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
|
||||
|
||||
|
||||
stop_words_logits_processor = StopWordsLogitsProcessor(
|
||||
stop_words_ids=stop_words_ids,
|
||||
eos_token_id=generation_config.eos_token_id,
|
||||
|
@ -541,7 +533,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
|||
negative_prompt_attention_mask=None,
|
||||
)
|
||||
|
||||
|
||||
# 12. expand input_ids with `num_return_sequences` additional sequences per batch
|
||||
input_ids, model_kwargs = self._expand_inputs_for_generation(
|
||||
input_ids=input_ids,
|
||||
|
@ -552,9 +543,9 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
|||
|
||||
# 13. run sample
|
||||
|
||||
pad_token_id=generation_config.pad_token_id
|
||||
eos_token_id_tensor=torch.tensor([generation_config.eos_token_id]).to(input_ids.device)
|
||||
|
||||
pad_token_id = generation_config.pad_token_id
|
||||
eos_token_id_tensor = torch.tensor([generation_config.eos_token_id]).to(input_ids.device)
|
||||
|
||||
# init values
|
||||
stopping_criteria = self._get_stopping_criteria(
|
||||
generation_config=generation_config, stopping_criteria=StoppingCriteriaList()
|
||||
|
|
Loading…
Reference in New Issue