Compare commits

...

3 Commits

Author SHA1 Message Date
Colin 4d493014ba Refine model of qwen. 2024-01-20 20:20:18 +08:00
Colin 12dcbec718 PreTrainedModel to mm.Module 2024-01-20 20:06:59 +08:00
Colin 0458e7303c Remove attention_mask 2024-01-20 18:08:20 +08:00
4 changed files with 127 additions and 184 deletions

View File

@ -33,5 +33,14 @@
"use_dynamic_ntk": true,
"use_flash_attn": "auto",
"use_logn_attn": true,
"vocab_size": 151936
"vocab_size": 151936,
"chat_format": "chatml",
"eos_token_id": 151643,
"pad_token_id": 151643,
"max_window_size": 6144,
"max_new_tokens": 512,
"do_sample": true,
"top_k": 0,
"top_p": 0.8,
"repetition_penalty": 1.1
}

View File

@ -52,7 +52,7 @@ print(model)
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
model = model.from_pretrained(model_dir, config=config, device_map="auto", trust_remote_code=True)
model = model.from_pretrained(model_dir).cuda()
# model = model.eval()
model = model.train() # control by @torch.no_grad()
@ -80,7 +80,7 @@ print(decode_tokens)
# 第二轮对话
response, history, decode_tokens = model.chat(tokenizer, "给我讲一个年轻人奋斗创业最终取得成功的故事。", "", history=None)
print(response)
print(decode_tokens)
# <|im_start|>system

View File

@ -1,12 +0,0 @@
{
"chat_format": "chatml",
"eos_token_id": 151643,
"pad_token_id": 151643,
"max_window_size": 6144,
"max_new_tokens": 512,
"do_sample": true,
"top_k": 0,
"top_p": 0.8,
"repetition_penalty": 1.1,
"transformers_version": "4.31.0"
}

View File

@ -1,11 +1,10 @@
# Copyright (c) Alibaba Cloud.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import copy
import math
import inspect
import os
import gc
from tqdm import auto as tqdm_lib
import json
from typing import TYPE_CHECKING, Optional, Tuple, Union, Callable, List, Any, Generator
import torch
@ -37,6 +36,11 @@ from qwen_generation_utils import (
StopWordsLogitsProcessor,
)
from safetensors import safe_open
from safetensors.torch import load_file as safe_load_file
from safetensors.torch import save_file as safe_save_file
import sys
sys.path.append("..")
@ -96,7 +100,6 @@ class QWenAttention(nn.Module):
self,
hidden_states: Optional[Tuple[torch.FloatTensor]],
rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
):
mixed_x_layer = self.c_attn(hidden_states)
query, key, value = mixed_x_layer.split(self.split_size, dim=2)
@ -120,32 +123,21 @@ class QWenAttention(nn.Module):
query = query * logn_tensor.expand_as(query)
key_size = key.size(1)
if query.size(1) == key_size:
causal_mask = torch.tril(torch.ones((key_size, key_size), dtype=torch.bool, device=query.device)).view(
1, 1, key_size, key_size
)
else:
causal_mask = None
query = query.permute(0, 2, 1, 3)
key = key.permute(0, 2, 1, 3)
value = value.permute(0, 2, 1, 3)
if attention_mask is not None:
attention_mask = attention_mask.expand(-1, -1, causal_mask.size(2), -1)
if causal_mask is not None:
attention_mask = attention_mask.masked_fill(~causal_mask, torch.finfo(query.dtype).min)
else:
attention_mask = causal_mask
# qk = query @ key.transpose(-2, -1)
# qk = qk[0]
# show.DumpTensorToImage(qk,"q_matmul_k_layer_"+str(self.index)+".png")
# print("layer:" + str(self.index) + " query.shape:"+ str(query.shape))
# print("layer:" + str(self.index) + " key.shape:"+ str(key.shape))
# print("layer:" + str(self.index) + " value.shape:"+ str(value.shape))
# print("\n")
# prePath = "../generated/query_matmul_key/img/"
# show.DumpTensorToImage(
# qk, prePath + "q_matmul_k_sequence_" + str(key_size) + "_layer_" + str(self.index) + ".png"
# )
attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask).transpose(1, 2)
attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=causal_mask).transpose(1, 2)
context_layer = self._merge_heads(attn_output, self.num_heads, self.head_dim)
attn_output = self.c_proj(context_layer)
@ -189,15 +181,10 @@ class QWenBlock(nn.Module):
self,
hidden_states: Optional[Tuple[torch.FloatTensor]],
rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
):
layernorm_output = self.ln_1(hidden_states)
attn_outputs = self.attn(
layernorm_output,
rotary_pos_emb_list,
attention_mask=attention_mask,
)
attn_outputs = self.attn(layernorm_output, rotary_pos_emb_list)
attn_output = attn_outputs[0]
residual = hidden_states
layernorm_input = attn_output + residual
@ -209,7 +196,7 @@ class QWenBlock(nn.Module):
return hidden_states
class QWenPreTrainedModel(PreTrainedModel):
class QWenPreTrainedModel(nn.Module):
config_class = QWenConfig
base_model_prefix = "transformer"
is_parallelizable = False
@ -217,7 +204,7 @@ class QWenPreTrainedModel(PreTrainedModel):
_no_split_modules = ["QWenBlock"]
def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
super().__init__()
class QWenModel(QWenPreTrainedModel):
@ -248,8 +235,6 @@ class QWenModel(QWenPreTrainedModel):
eps=config.layer_norm_epsilon,
)
self.post_init()
def get_ntk_alpha(self, true_seq_len):
context_value = math.log(true_seq_len / self.seq_length, 2) + 1
ntk_alpha = 2 ** math.ceil(context_value) - 1
@ -259,8 +244,6 @@ class QWenModel(QWenPreTrainedModel):
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
):
if input_ids is not None and inputs_embeds is not None:
@ -275,14 +258,6 @@ class QWenModel(QWenPreTrainedModel):
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if attention_mask is not None:
attention_mask = attention_mask.view(batch_size, -1)
attention_mask = attention_mask[:, None, None, :]
attention_mask = attention_mask.to(dtype=self.dtype)
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
if inputs_embeds is None:
inputs_embeds = self.wte(input_ids)
hidden_states = inputs_embeds
@ -295,13 +270,6 @@ class QWenModel(QWenPreTrainedModel):
ntk_alpha_list = self.rotary_emb._ntk_alpha_cached_list
else:
ntk_alpha_list = []
if attention_mask is not None and kv_seq_len > self.seq_length:
true_seq_lens = attention_mask.squeeze(1).squeeze(1).eq(0).sum(dim=-1, dtype=torch.int32)
for i in range(hidden_states.size()[0]):
true_seq_len = true_seq_lens[i].item()
ntk_alpha = self.get_ntk_alpha(true_seq_len)
ntk_alpha_list.append(ntk_alpha)
else:
ntk_alpha = self.get_ntk_alpha(kv_seq_len)
ntk_alpha_list.append(ntk_alpha)
self.rotary_emb._ntk_alpha_cached_list = ntk_alpha_list
@ -312,52 +280,34 @@ class QWenModel(QWenPreTrainedModel):
all_hidden_states = None
for block in self.h:
hidden_states = block(
hidden_states,
rotary_pos_emb_list=rotary_pos_emb_list,
attention_mask=attention_mask,
)
hidden_states = block(hidden_states, rotary_pos_emb_list=rotary_pos_emb_list)
hidden_states = self.ln_f(hidden_states)
hidden_states = hidden_states.view(output_shape)
return BaseModelOutputWithPast(last_hidden_state=hidden_states, hidden_states=all_hidden_states)
class QWenLMHeadModel(QWenPreTrainedModel):
class QWenLMHeadModel(nn.Module):
def __init__(self, config):
super().__init__(config)
super().__init__()
self.config = config
self.transformer = QWenModel(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.post_init()
self.generation_config = GenerationConfig.from_model_config(config)
def prepare_inputs_for_generation(self, input_ids, inputs_embeds=None, **kwargs):
if input_ids.size(0) == 1:
attention_mask = None
else:
attention_mask = kwargs.get("attention_mask", None)
model_inputs = {"input_ids": input_ids}
model_inputs.update(
{
"attention_mask": attention_mask,
}
)
return model_inputs
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
transformer_outputs = self.transformer(
input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
)
hidden_states = transformer_outputs[0]
@ -387,6 +337,59 @@ class QWenLMHeadModel(QWenPreTrainedModel):
attentions=transformer_outputs.attentions,
)
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]]):
load_in_8bit = False
load_in_4bit = False
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
resolved_archive_file = os.path.join(pretrained_model_name_or_path, "model.safetensors.index.json")
print(f"loading weights file {resolved_archive_file}")
with open(resolved_archive_file, "r") as f:
index = json.loads(f.read())
shard_filenames = sorted(set(index["weight_map"].values()))
resolved_archive_file = [os.path.join(pretrained_model_name_or_path, f) for f in shard_filenames]
model = cls._load_pretrained_model(resolved_archive_file)
model.is_loaded_in_4bit = load_in_4bit
model.is_loaded_in_8bit = load_in_8bit
return model
def _load_state_dict_into_model(self, model_to_load, state_dict, start_prefix):
metadata = getattr(state_dict, "_metadata", None)
state_dict = state_dict.copy()
if metadata is not None:
state_dict._metadata = metadata
error_msgs = []
def load(module: nn.Module, state_dict, prefix=""):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
if len([key for key in state_dict if key.startswith(prefix)]) > 0:
module._load_from_state_dict(*args)
for name, child in module._modules.items():
if child is not None:
load(child, state_dict, prefix + name + ".")
load(model_to_load, state_dict, prefix=start_prefix)
del state_dict
return error_msgs
def _load_pretrained_model(cls, resolved_archive_file):
start_prefix = ""
model_to_load = cls
error_msgs = []
if len(resolved_archive_file) > 1:
resolved_archive_file = tqdm_lib.tqdm(resolved_archive_file, desc="Loading checkpoint shards")
for shard_file in resolved_archive_file:
state_dict = safe_load_file(shard_file)
error_msgs += cls._load_state_dict_into_model(model_to_load, state_dict, start_prefix)
del state_dict # force memory release
gc.collect()
print(f"All model checkpoint weights were used when initializing {cls.__class__.__name__}.\n")
return cls
@torch.no_grad()
def chat(
self,
@ -406,18 +409,14 @@ class QWenLMHeadModel(QWenPreTrainedModel):
stop_words_ids = []
max_window_size = kwargs.get("max_window_size", None)
if max_window_size is None:
max_window_size = generation_config.max_window_size
raw_text, context_tokens = make_context(
tokenizer, query, query_assistant, history=history, system=system, max_window_size=max_window_size
)
raw_text, context_tokens = make_context(tokenizer, query, query_assistant, history=history, system=system)
stop_words_ids.extend([[tokenizer.im_end_id], [tokenizer.im_start_id]])
input_ids = torch.tensor([context_tokens]).to(self.device)
input_ids = torch.tensor([context_tokens]).to(next(self.parameters()).device)
outputs = self.generate(
input_ids,
stop_words_ids=stop_words_ids,
tokenizer=tokenizer,
**kwargs,
)
decoded, response, end_reason = decode_tokens(
@ -432,110 +431,53 @@ class QWenLMHeadModel(QWenPreTrainedModel):
def generate(
self,
inputs: Optional[torch.Tensor] = None,
input_ids: Optional[torch.Tensor] = None,
stop_words_ids=[],
tokenizer=None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
**kwargs,
) -> Union[GenerateOutput, torch.LongTensor]:
generation_config = self.generation_config
# 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
self._validate_model_class()
generation_config = copy.deepcopy(generation_config)
model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs
generation_config.validate()
self._validate_model_kwargs(model_kwargs.copy())
# 2. Set generation parameters if not already defined
if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
eos_token_id = generation_config.eos_token_id
if isinstance(eos_token_id, list):
eos_token_id = eos_token_id[0]
generation_config.pad_token_id = eos_token_id
# 3. Define model inputs
inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
inputs, generation_config.bos_token_id, model_kwargs
)
# 4. Define other model kwargs
accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
requires_attention_mask = "encoder_outputs" not in model_kwargs
if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask:
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
inputs_tensor,
generation_config.pad_token_id,
generation_config.eos_token_id,
)
# 5. Prepare `input_ids` which will be used for auto-regressive generation
input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
# 6. Prepare `max_length` depending on other stopping criteria.
input_ids_length = input_ids.shape[-1]
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
generation_config.max_length = generation_config.max_new_tokens + input_ids_length
self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
stop_words_logits_processor = StopWordsLogitsProcessor(
stop_words_ids=stop_words_ids,
eos_token_id=generation_config.eos_token_id,
)
logits_processor = LogitsProcessorList([stop_words_logits_processor])
logits_processor = self._get_logits_processor(
generation_config=generation_config,
input_ids_seq_length=input_ids_length,
encoder_input_ids=inputs_tensor,
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
logits_processor=logits_processor,
model_kwargs=model_kwargs,
negative_prompt_ids=None,
negative_prompt_attention_mask=None,
)
# 12. expand input_ids with `num_return_sequences` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
input_ids=input_ids,
expand_size=generation_config.num_return_sequences,
is_encoder_decoder=False,
**model_kwargs,
)
# 13. run sample
pad_token_id = generation_config.pad_token_id
eos_token_id_tensor = torch.tensor([generation_config.eos_token_id]).to(input_ids.device)
# init values
stopping_criteria = self._get_stopping_criteria(
generation_config=generation_config, stopping_criteria=StoppingCriteriaList()
)
logits_warper = self._get_logits_warper(generation_config)
# init attention / hidden states / scores tuples
scores = None
# keep track of which sequences are already finished
unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
this_peer_finished = False
# auto-regressive generation
while True:
# prepare model inputs
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
# forward pass to get next token
outputs = self(**model_inputs)
next_token_scores = outputs.logits[:, -1, :]
# pre-process distribution
next_token_scores = logits_processor(input_ids, next_token_scores)
next_token_scores = logits_warper(input_ids, next_token_scores)
# repetition_penalty
penalty = self.config.repetition_penalty
score = torch.gather(next_token_scores, 1, input_ids)
# if score < 0 then repetition penalty has to be multiplied to reduce the token probabilities
score = torch.where(score < 0, score * penalty, score / penalty)
next_token_scores = next_token_scores.scatter_(1, input_ids, score)
# top_p
top_p = self.config.top_p
filter_value = -float("Inf")
min_tokens_to_keep = 1
sorted_logits, sorted_indices = torch.sort(next_token_scores, descending=False)
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
# Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
# Keep at least min_tokens_to_keep
sorted_indices_to_remove[..., -min_tokens_to_keep:] = 0
# scatter sorted tensors to original indexing
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
next_token_scores = next_token_scores.masked_fill(indices_to_remove, filter_value)
# sample
probs = nn.functional.softmax(next_token_scores, dim=-1)
@ -545,20 +487,24 @@ class QWenLMHeadModel(QWenPreTrainedModel):
# update generated ids, model inputs, and length for next step
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
model_kwargs = self._update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False)
unfinished_sequences = unfinished_sequences.mul(
next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
)
# decoded, response, end_reason = decode_tokens(
# next_tokens,
# tokenizer,
# raw_text_len=0,
# context_length=0,
# errors="replace",
# )
# print(decoded)
# stop when each sentence is finished
if unfinished_sequences.max() == 0:
this_peer_finished = True
# stop if we exceed the maximum length
if stopping_criteria(input_ids, scores):
this_peer_finished = True
if this_peer_finished:
break
return input_ids