Refine model of qwen.

This commit is contained in:
Colin 2024-01-07 17:50:58 +08:00
parent 3f8ea9db07
commit 82ac3e4863
1 changed files with 13 additions and 175 deletions

View File

@ -48,21 +48,7 @@ from qwen_generation_utils import (
logger = logging.get_logger(__name__)
QWen_PRETRAINED_MODEL_ARCHIVE_LIST = ["qwen-7b"]
_ERROR_BAD_CHAT_FORMAT = """\
We detect you are probably using the pretrained model (rather than chat model) for chatting, since the chat_format in generation_config is not "chatml".
If you are directly using the model downloaded from Huggingface, please make sure you are using our "Qwen/Qwen-7B-Chat" Huggingface model (rather than "Qwen/Qwen-7B") when you call model.chat().
我们检测到您可能在使用预训练模型而非chat模型进行多轮chat因为您当前在generation_config指定的chat_format并未设置为我们在对话中所支持的"chatml"格式
如果您在直接使用我们从Huggingface提供的模型请确保您在调用model.chat()使用的是"Qwen/Qwen-7B-Chat"模型而非"Qwen/Qwen-7B"预训练模型
"""
_SENTINEL = object()
_ERROR_STREAM_IN_CHAT = """\
Pass argument `stream` to model.chat() is buggy, deprecated, and marked for removal. Please use model.chat_stream(...) instead of model.chat(..., stream=True).
向model.chat()传入参数stream的用法可能存在Bug该用法已被废弃将在未来被移除请使用model.chat_stream(...)代替model.chat(..., stream=True)
"""
class QWenAttention(nn.Module):
def __init__(self, config):
@ -209,7 +195,6 @@ class QWenAttention(nn.Module):
attn_output = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask
).transpose(1, 2)
attn_weight = None
context_layer = self._merge_heads(attn_output, self.num_heads, self.head_dim)
@ -339,13 +324,7 @@ class QWenPreTrainedModel(PreTrainedModel):
),
)
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, QWenModel):
module.gradient_checkpointing = value
class QWenModel(QWenPreTrainedModel):
_keys_to_ignore_on_load_missing = ["attn.masked_bias"]
def __init__(self, config):
super().__init__(config)
@ -353,7 +332,6 @@ class QWenModel(QWenPreTrainedModel):
self.num_hidden_layers = config.num_hidden_layers
self.embed_dim = config.hidden_size
self.gradient_checkpointing = False
self.use_dynamic_ntk = config.use_dynamic_ntk
self.seq_length = config.seq_length
@ -381,12 +359,6 @@ class QWenModel(QWenPreTrainedModel):
self.post_init()
def get_input_embeddings(self):
return self.wte
def set_input_embeddings(self, new_embeddings):
self.wte = new_embeddings
def get_ntk_alpha(self, true_seq_len):
context_value = math.log(true_seq_len / self.seq_length, 2) + 1
ntk_alpha = 2 ** math.ceil(context_value) - 1
@ -398,8 +370,6 @@ class QWenModel(QWenPreTrainedModel):
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
@ -434,26 +404,8 @@ class QWenModel(QWenPreTrainedModel):
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
device = input_ids.device if input_ids is not None else inputs_embeds.device
if token_type_ids is not None:
token_type_ids = token_type_ids.view(-1, input_shape[-1])
if position_ids is not None:
position_ids = position_ids.view(-1, input_shape[-1])
if past_key_values is None:
past_length = 0
past_key_values = tuple([None] * len(self.h))
else:
past_length = past_key_values[0][0].size(-2)
if position_ids is None:
position_ids = torch.arange(
past_length,
input_shape[-1] + past_length,
dtype=torch.long,
device=device,
)
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
if attention_mask is not None:
if batch_size <= 0:
@ -504,13 +456,6 @@ class QWenModel(QWenPreTrainedModel):
hidden_states = self.drop(hidden_states)
output_shape = input_shape + (hidden_states.size(-1),)
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
presents = () if use_cache else None
all_self_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
@ -518,37 +463,17 @@ class QWenModel(QWenPreTrainedModel):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, use_cache, output_attentions)
return custom_forward
outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
rotary_pos_emb_list,
None,
attention_mask,
head_mask[i],
encoder_hidden_states,
encoder_attention_mask,
)
else:
outputs = block(
hidden_states,
layer_past=layer_past,
rotary_pos_emb_list=rotary_pos_emb_list,
attention_mask=attention_mask,
head_mask=head_mask[i],
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
)
outputs = block(
hidden_states,
layer_past=layer_past,
rotary_pos_emb_list=rotary_pos_emb_list,
attention_mask=attention_mask,
head_mask=head_mask[i],
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
)
hidden_states = outputs[0]
if use_cache is True:
@ -574,8 +499,6 @@ class QWenModel(QWenPreTrainedModel):
class QWenLMHeadModel(QWenPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.rotary_emb\.inv_freq"]
_keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.masked_bias"]
def __init__(self, config):
super().__init__(config)
@ -584,12 +507,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.post_init()
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
):
@ -620,8 +537,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
@ -637,8 +552,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
@ -694,8 +607,8 @@ class QWenLMHeadModel(QWenPreTrainedModel):
else self.generation_config
)
assert stream is _SENTINEL, _ERROR_STREAM_IN_CHAT
assert generation_config.chat_format == "chatml", _ERROR_BAD_CHAT_FORMAT
assert stream is _SENTINEL
assert generation_config.chat_format == "chatml"
if history is None:
history = []
else:
@ -746,81 +659,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
return response, history
def chat_stream(
self,
tokenizer: PreTrainedTokenizer,
query: str,
history: Optional[HistoryType],
system: str = "You are a helpful assistant.",
stop_words_ids: Optional[List[List[int]]] = None,
logits_processor: Optional[LogitsProcessorList] = None,
generation_config: Optional[GenerationConfig] = None,
**kwargs,
) -> Generator[str, Any, None]:
generation_config = (
generation_config
if generation_config is not None
else self.generation_config
)
assert generation_config.chat_format == "chatml", _ERROR_BAD_CHAT_FORMAT
if history is None:
history = []
if stop_words_ids is None:
stop_words_ids = []
max_window_size = kwargs.get("max_window_size", None)
if max_window_size is None:
max_window_size = generation_config.max_window_size
raw_text, context_tokens = make_context(
tokenizer,
query,
history=history,
system=system,
max_window_size=max_window_size,
chat_format=generation_config.chat_format,
)
stop_words_ids.extend(
get_stop_words_ids(generation_config.chat_format, tokenizer)
)
if stop_words_ids is not None:
stop_words_logits_processor = StopWordsLogitsProcessor(
stop_words_ids=stop_words_ids,
eos_token_id=generation_config.eos_token_id,
)
if logits_processor is None:
logits_processor = LogitsProcessorList([stop_words_logits_processor])
else:
logits_processor.append(stop_words_logits_processor)
input_ids = torch.tensor([context_tokens]).to(self.device)
from transformers_stream_generator.main import (
NewGenerationMixin,
StreamGenerationConfig,
)
self.__class__.generate_stream = NewGenerationMixin.generate
self.__class__.sample_stream = NewGenerationMixin.sample_stream
stream_config = StreamGenerationConfig(
**generation_config.to_dict(), do_stream=True
)
def stream_generator():
outputs = []
for token in self.generate_stream(
input_ids,
generation_config=stream_config,
logits_processor=logits_processor,
seed=-1,
**kwargs,
):
outputs.append(token.item())
yield tokenizer.decode(
outputs, skip_special_tokens=True, errors="ignore"
)
return stream_generator()
def generate(
self,
inputs: Optional[torch.Tensor] = None,