Format model of qwen.
This commit is contained in:
parent
5cf6e8b013
commit
d13f7e6c57
|
@ -38,6 +38,7 @@ from qwen_generation_utils import (
|
||||||
)
|
)
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
sys.path.append("..")
|
sys.path.append("..")
|
||||||
from tools import show
|
from tools import show
|
||||||
|
|
||||||
|
@ -144,7 +145,6 @@ class QWenAttention(nn.Module):
|
||||||
else:
|
else:
|
||||||
attention_mask = causal_mask
|
attention_mask = causal_mask
|
||||||
|
|
||||||
|
|
||||||
# qk = query @ key.transpose(-2, -1)
|
# qk = query @ key.transpose(-2, -1)
|
||||||
# qk = qk[0]
|
# qk = qk[0]
|
||||||
# show.DumpTensorToImage(qk,"q_matmul_k_layer_"+str(self.index)+".png")
|
# show.DumpTensorToImage(qk,"q_matmul_k_layer_"+str(self.index)+".png")
|
||||||
|
@ -441,12 +441,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
||||||
if max_window_size is None:
|
if max_window_size is None:
|
||||||
max_window_size = generation_config.max_window_size
|
max_window_size = generation_config.max_window_size
|
||||||
raw_text, context_tokens = make_context(
|
raw_text, context_tokens = make_context(
|
||||||
tokenizer,
|
tokenizer, query, query_assistant, history=history, system=system, max_window_size=max_window_size
|
||||||
query,
|
|
||||||
query_assistant,
|
|
||||||
history=history,
|
|
||||||
system=system,
|
|
||||||
max_window_size=max_window_size
|
|
||||||
)
|
)
|
||||||
|
|
||||||
stop_words_ids.extend([[tokenizer.im_end_id], [tokenizer.im_start_id]])
|
stop_words_ids.extend([[tokenizer.im_end_id], [tokenizer.im_start_id]])
|
||||||
|
@ -469,13 +464,12 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
||||||
def generate(
|
def generate(
|
||||||
self,
|
self,
|
||||||
inputs: Optional[torch.Tensor] = None,
|
inputs: Optional[torch.Tensor] = None,
|
||||||
stop_words_ids = [],
|
stop_words_ids=[],
|
||||||
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
|
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Union[GenerateOutput, torch.LongTensor]:
|
) -> Union[GenerateOutput, torch.LongTensor]:
|
||||||
generation_config = self.generation_config
|
generation_config = self.generation_config
|
||||||
|
|
||||||
|
|
||||||
# 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
|
# 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
|
||||||
self._validate_model_class()
|
self._validate_model_class()
|
||||||
|
|
||||||
|
@ -486,7 +480,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
||||||
|
|
||||||
# 2. Set generation parameters if not already defined
|
# 2. Set generation parameters if not already defined
|
||||||
|
|
||||||
|
|
||||||
if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
|
if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
|
||||||
if model_kwargs.get("attention_mask", None) is None:
|
if model_kwargs.get("attention_mask", None) is None:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
@ -524,7 +517,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
||||||
generation_config.max_length = generation_config.max_new_tokens + input_ids_length
|
generation_config.max_length = generation_config.max_new_tokens + input_ids_length
|
||||||
self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
|
self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
|
||||||
|
|
||||||
|
|
||||||
stop_words_logits_processor = StopWordsLogitsProcessor(
|
stop_words_logits_processor = StopWordsLogitsProcessor(
|
||||||
stop_words_ids=stop_words_ids,
|
stop_words_ids=stop_words_ids,
|
||||||
eos_token_id=generation_config.eos_token_id,
|
eos_token_id=generation_config.eos_token_id,
|
||||||
|
@ -541,7 +533,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
||||||
negative_prompt_attention_mask=None,
|
negative_prompt_attention_mask=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# 12. expand input_ids with `num_return_sequences` additional sequences per batch
|
# 12. expand input_ids with `num_return_sequences` additional sequences per batch
|
||||||
input_ids, model_kwargs = self._expand_inputs_for_generation(
|
input_ids, model_kwargs = self._expand_inputs_for_generation(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
|
@ -552,9 +543,9 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
||||||
|
|
||||||
# 13. run sample
|
# 13. run sample
|
||||||
|
|
||||||
pad_token_id=generation_config.pad_token_id
|
pad_token_id = generation_config.pad_token_id
|
||||||
eos_token_id_tensor=torch.tensor([generation_config.eos_token_id]).to(input_ids.device)
|
eos_token_id_tensor = torch.tensor([generation_config.eos_token_id]).to(input_ids.device)
|
||||||
|
|
||||||
# init values
|
# init values
|
||||||
stopping_criteria = self._get_stopping_criteria(
|
stopping_criteria = self._get_stopping_criteria(
|
||||||
generation_config=generation_config, stopping_criteria=StoppingCriteriaList()
|
generation_config=generation_config, stopping_criteria=StoppingCriteriaList()
|
||||||
|
|
Loading…
Reference in New Issue