Compare commits
2 Commits: a8f2fbbff5 ... 82ac3e4863
Author | SHA1 | Date
---|---|---
Colin | 82ac3e4863 |
Colin | 3f8ea9db07 |
@@ -48,21 +48,7 @@ from qwen_generation_utils import (
```python
logger = logging.get_logger(__name__)

QWen_PRETRAINED_MODEL_ARCHIVE_LIST = ["qwen-7b"]

_ERROR_BAD_CHAT_FORMAT = """\
We detect you are probably using the pretrained model (rather than chat model) for chatting, since the chat_format in generation_config is not "chatml".
If you are directly using the model downloaded from Huggingface, please make sure you are using our "Qwen/Qwen-7B-Chat" Huggingface model (rather than "Qwen/Qwen-7B") when you call model.chat().
我们检测到您可能在使用预训练模型(而非chat模型)进行多轮chat,因为您当前在generation_config指定的chat_format,并未设置为我们在对话中所支持的"chatml"格式。
如果您在直接使用我们从Huggingface提供的模型,请确保您在调用model.chat()时,使用的是"Qwen/Qwen-7B-Chat"模型(而非"Qwen/Qwen-7B"预训练模型)。
"""

_SENTINEL = object()

_ERROR_STREAM_IN_CHAT = """\
Pass argument `stream` to model.chat() is buggy, deprecated, and marked for removal. Please use model.chat_stream(...) instead of model.chat(..., stream=True).
向model.chat()传入参数stream的用法可能存在Bug,该用法已被废弃,将在未来被移除。请使用model.chat_stream(...)代替model.chat(..., stream=True)。
"""


class QWenAttention(nn.Module):
    def __init__(self, config):
```
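The two error strings describe the intended calling convention: `model.chat()` is meant for the chat-tuned checkpoint (where `generation_config.chat_format == "chatml"`), and streaming goes through `model.chat_stream(...)` rather than `model.chat(..., stream=True)`. A minimal usage sketch under those assumptions; the checkpoint name and `trust_remote_code` loading are taken from the error text and standard Hugging Face usage, not from this diff:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Sketch only: model name follows the error strings above.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-7B-Chat", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen-7B-Chat", trust_remote_code=True
).eval()

# Non-streaming: returns the full response plus the updated history.
response, history = model.chat(tokenizer, "Hello!", history=None)

# Streaming: iterate over chat_stream() instead of passing stream=True to chat().
for partial in model.chat_stream(tokenizer, "Hello!", history=[]):
    print(partial)
```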
@@ -209,7 +195,6 @@ class QWenAttention(nn.Module):
```python
        attn_output = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask
        ).transpose(1, 2)
        attn_weight = None

        context_layer = self._merge_heads(attn_output, self.num_heads, self.head_dim)
```
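For context, `F.scaled_dot_product_attention` (available in PyTorch 2.0+) takes tensors laid out as `(batch, heads, seq, head_dim)` and returns output in the same layout, which is why the result is transposed back before `_merge_heads`. A standalone sketch with made-up shapes:

```python
import torch
import torch.nn.functional as F

batch, heads, seq, head_dim = 2, 8, 16, 64
query = torch.randn(batch, heads, seq, head_dim)
key = torch.randn(batch, heads, seq, head_dim)
value = torch.randn(batch, heads, seq, head_dim)

# Fused attention; attn_mask=None means full (unmasked) attention here.
out = F.scaled_dot_product_attention(query, key, value, attn_mask=None)
print(out.shape)                  # torch.Size([2, 8, 16, 64])
print(out.transpose(1, 2).shape)  # torch.Size([2, 16, 8, 64]) -> ready for head merging
```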
@@ -339,13 +324,7 @@ class QWenPreTrainedModel(PreTrainedModel):
```python
                    ),
                )

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, QWenModel):
            module.gradient_checkpointing = value


class QWenModel(QWenPreTrainedModel):
    _keys_to_ignore_on_load_missing = ["attn.masked_bias"]

    def __init__(self, config):
        super().__init__(config)
```
@@ -353,7 +332,6 @@ class QWenModel(QWenPreTrainedModel):
```python
        self.num_hidden_layers = config.num_hidden_layers
        self.embed_dim = config.hidden_size

        self.gradient_checkpointing = False
        self.use_dynamic_ntk = config.use_dynamic_ntk
        self.seq_length = config.seq_length
```
@@ -381,12 +359,6 @@ class QWenModel(QWenPreTrainedModel):
```python
        self.post_init()

    def get_input_embeddings(self):
        return self.wte

    def set_input_embeddings(self, new_embeddings):
        self.wte = new_embeddings

    def get_ntk_alpha(self, true_seq_len):
        context_value = math.log(true_seq_len / self.seq_length, 2) + 1
        ntk_alpha = 2 ** math.ceil(context_value) - 1
```
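As a worked example of the NTK-scaling formula above (assuming `seq_length = 2048`, the usual Qwen-7B training length): for `true_seq_len = 5000`, `context_value = log2(5000 / 2048) + 1 ≈ 2.29`, which ceils to 3, giving `ntk_alpha = 2**3 - 1 = 7`. A quick standalone check:

```python
import math

def get_ntk_alpha(true_seq_len, seq_length=2048):
    # Mirrors the two lines above; seq_length=2048 is an assumed training length.
    context_value = math.log(true_seq_len / seq_length, 2) + 1
    return 2 ** math.ceil(context_value) - 1

print(get_ntk_alpha(5000))   # 7
print(get_ntk_alpha(16384))  # 2**4 - 1 = 15
```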
@@ -398,8 +370,6 @@ class QWenModel(QWenPreTrainedModel):
```python
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
```
@@ -434,26 +404,8 @@ class QWenModel(QWenPreTrainedModel):
```python
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if token_type_ids is not None:
            token_type_ids = token_type_ids.view(-1, input_shape[-1])
        if position_ids is not None:
            position_ids = position_ids.view(-1, input_shape[-1])

        if past_key_values is None:
            past_length = 0
            past_key_values = tuple([None] * len(self.h))
        else:
            past_length = past_key_values[0][0].size(-2)
        if position_ids is None:
            position_ids = torch.arange(
                past_length,
                input_shape[-1] + past_length,
                dtype=torch.long,
                device=device,
            )
            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])

        if attention_mask is not None:
            if batch_size <= 0:
```
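To illustrate the position-id arithmetic above: when a KV cache is present, `past_length` is the number of cached key/value positions, and the newly fed tokens get positions starting there. A small standalone sketch (the numbers are illustrative):

```python
import torch

past_length = 5   # e.g. 5 tokens already held in the KV cache
new_tokens = 3    # tokens fed in this forward pass

position_ids = torch.arange(past_length, new_tokens + past_length, dtype=torch.long)
position_ids = position_ids.unsqueeze(0)  # add the batch dimension
print(position_ids)                       # tensor([[5, 6, 7]])
```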
@@ -504,13 +456,6 @@ class QWenModel(QWenPreTrainedModel):
```python
        hidden_states = self.drop(hidden_states)
        output_shape = input_shape + (hidden_states.size(-1),)

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        presents = () if use_cache else None
        all_self_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None
```
@@ -518,37 +463,17 @@ class QWenModel(QWenPreTrainedModel):
```python
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs, use_cache, output_attentions)

                    return custom_forward

                outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    rotary_pos_emb_list,
                    None,
                    attention_mask,
                    head_mask[i],
                    encoder_hidden_states,
                    encoder_attention_mask,
                )
            else:
                outputs = block(
                    hidden_states,
                    layer_past=layer_past,
                    rotary_pos_emb_list=rotary_pos_emb_list,
                    attention_mask=attention_mask,
                    head_mask=head_mask[i],
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                )
            outputs = block(
                hidden_states,
                layer_past=layer_past,
                rotary_pos_emb_list=rotary_pos_emb_list,
                attention_mask=attention_mask,
                head_mask=head_mask[i],
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                use_cache=use_cache,
                output_attentions=output_attentions,
            )

            hidden_states = outputs[0]
            if use_cache is True:
```
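The removed branch above follows the usual `torch.utils.checkpoint` pattern: wrap the block in a closure so extra arguments can be passed positionally, and let the backward pass recompute activations instead of storing them. A minimal, self-contained sketch of that trade-off (a toy stand-in, not the Qwen block):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Toy stand-in for a transformer block.
layer = nn.Sequential(nn.Linear(16, 16), nn.GELU(), nn.Linear(16, 16))
x = torch.randn(4, 16, requires_grad=True)

# Activations inside `layer` are not stored; they are recomputed during backward,
# trading compute for memory. This is also why use_cache is turned off under
# checkpointing in the warning shown earlier.
y = checkpoint(layer, x, use_reentrant=False)
y.sum().backward()
print(x.grad.shape)  # torch.Size([4, 16])
```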
@@ -574,8 +499,6 @@ class QWenModel(QWenPreTrainedModel):
```python
class QWenLMHeadModel(QWenPreTrainedModel):
    _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.rotary_emb\.inv_freq"]
    _keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.masked_bias"]

    def __init__(self, config):
        super().__init__(config)
```
@@ -584,12 +507,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
```python
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
    ):
```
@@ -620,8 +537,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
```python
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
```
@@ -637,8 +552,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
```python
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
```
@@ -694,8 +607,8 @@ class QWenLMHeadModel(QWenPreTrainedModel):
```python
            else self.generation_config
        )

        assert stream is _SENTINEL, _ERROR_STREAM_IN_CHAT
        assert generation_config.chat_format == "chatml", _ERROR_BAD_CHAT_FORMAT
        assert stream is _SENTINEL
        assert generation_config.chat_format == "chatml"
        if history is None:
            history = []
        else:
```
@@ -724,7 +637,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
```python
        outputs = self.generate(
            input_ids,
            stop_words_ids=stop_words_ids,
            return_dict_in_generate=False,
            generation_config=generation_config,
            **kwargs,
        )
```
@@ -747,82 +659,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
```python
        return response, history

    def chat_stream(
        self,
        tokenizer: PreTrainedTokenizer,
        query: str,
        history: Optional[HistoryType],
        system: str = "You are a helpful assistant.",
        stop_words_ids: Optional[List[List[int]]] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        generation_config: Optional[GenerationConfig] = None,
        **kwargs,
    ) -> Generator[str, Any, None]:
        generation_config = (
            generation_config
            if generation_config is not None
            else self.generation_config
        )
        assert generation_config.chat_format == "chatml", _ERROR_BAD_CHAT_FORMAT
        if history is None:
            history = []
        if stop_words_ids is None:
            stop_words_ids = []

        max_window_size = kwargs.get("max_window_size", None)
        if max_window_size is None:
            max_window_size = generation_config.max_window_size
        raw_text, context_tokens = make_context(
            tokenizer,
            query,
            history=history,
            system=system,
            max_window_size=max_window_size,
            chat_format=generation_config.chat_format,
        )

        stop_words_ids.extend(
            get_stop_words_ids(generation_config.chat_format, tokenizer)
        )
        if stop_words_ids is not None:
            stop_words_logits_processor = StopWordsLogitsProcessor(
                stop_words_ids=stop_words_ids,
                eos_token_id=generation_config.eos_token_id,
            )
            if logits_processor is None:
                logits_processor = LogitsProcessorList([stop_words_logits_processor])
            else:
                logits_processor.append(stop_words_logits_processor)
        input_ids = torch.tensor([context_tokens]).to(self.device)

        from transformers_stream_generator.main import (
            NewGenerationMixin,
            StreamGenerationConfig,
        )

        self.__class__.generate_stream = NewGenerationMixin.generate
        self.__class__.sample_stream = NewGenerationMixin.sample_stream
        stream_config = StreamGenerationConfig(
            **generation_config.to_dict(), do_stream=True
        )

        def stream_generator():
            outputs = []
            for token in self.generate_stream(
                input_ids,
                return_dict_in_generate=False,
                generation_config=stream_config,
                logits_processor=logits_processor,
                seed=-1,
                **kwargs,
            ):
                outputs.append(token.item())
                yield tokenizer.decode(
                    outputs, skip_special_tokens=True, errors="ignore"
                )

        return stream_generator()

    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
```
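Note that each `yield` in `stream_generator` decodes the whole accumulated token list, so the generator produces the full response so far rather than per-token deltas. A caller that wants incremental output typically prints only the newly decoded tail; a small hedged consumer sketch (assuming `model` and `tokenizer` were loaded as in the earlier chat example):

```python
printed = ""
for partial in model.chat_stream(tokenizer, "Tell me a joke.", history=[]):
    # `partial` is the whole response so far; emit only the new suffix.
    print(partial[len(printed):], end="", flush=True)
    printed = partial
print()
```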
@@ -1074,7 +910,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
```python
            pad_token_id=generation_config.pad_token_id,
            eos_token_id=generation_config.eos_token_id,
            output_scores=generation_config.output_scores,
            return_dict_in_generate=generation_config.return_dict_in_generate,
            synced_gpus=synced_gpus,
            streamer=streamer,
            **model_kwargs,
```
@@ -1092,7 +927,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
```python
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        synced_gpus: bool = False,
        streamer: Optional["BaseStreamer"] = None,
        **model_kwargs,
```
@@ -1148,23 +982,9 @@ class QWenLMHeadModel(QWenPreTrainedModel):
```python
            if output_hidden_states is not None
            else self.generation_config.output_hidden_states
        )
        return_dict_in_generate = (
            return_dict_in_generate
            if return_dict_in_generate is not None
            else self.generation_config.return_dict_in_generate
        )

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        decoder_attentions = (
            () if (return_dict_in_generate and output_attentions) else None
        )
        cross_attentions = (
            () if (return_dict_in_generate and output_attentions) else None
        )
        decoder_hidden_states = (
            () if (return_dict_in_generate and output_hidden_states) else None
        )
        scores = None

        # keep track of which sequences are already finished
        unfinished_sequences = torch.ones(
```
@@ -1190,16 +1010,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
```python
            next_token_scores = logits_processor(input_ids, next_token_logits)
            next_token_scores = logits_warper(input_ids, next_token_scores)

            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_scores:
                    scores += (next_token_scores,)
                if output_attentions:
                    decoder_attentions += (outputs.attentions,)

                if output_hidden_states:
                    decoder_hidden_states += (outputs.hidden_states,)

            # sample
            probs = nn.functional.softmax(next_token_scores, dim=-1)
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
```
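The last two lines are the actual sampling step: the processed logits become a categorical distribution and one token is drawn per sequence in the batch. A standalone illustration with a toy vocabulary:

```python
import torch

torch.manual_seed(0)
next_token_scores = torch.randn(2, 10)  # (batch=2, vocab=10) stand-in logits
probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
print(next_tokens.shape)                # torch.Size([2]) -> one sampled token id per sequence
```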