Refine qwen model for long sequences in eval.

Colin 2024-01-19 14:54:48 +08:00
parent 45c2f532ff
commit f96bcc799c
2 changed files with 16 additions and 30 deletions

Changed file 1 of 2 (the eval/chat script):

@@ -52,11 +52,10 @@ print(model)
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
-model = model.from_pretrained(
-    model_dir, config=config, device_map="auto", trust_remote_code=True
-).train()
-# model.train()
-# model.zero_grad()
+model = model.from_pretrained(model_dir, config=config, device_map="auto", trust_remote_code=True)
+# model = model.eval()
+model = model.train()  # gradient tracking is controlled by @torch.no_grad()
# You can specify different generation lengths, top_p, and other related hyperparameters
# model.generation_config = GenerationConfig.from_pretrained(
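Note on the loading pattern above: keeping the model in train() while generation runs under a @torch.no_grad() decorator works because train()/eval() only switch layer behaviour (dropout, norm statistics), while autograd recording is governed separately by no_grad. A minimal, self-contained sketch of that separation (the tiny Linear module is illustrative only, not this repo's model):

import torch
import torch.nn as nn

model = nn.Linear(4, 4).train()  # stays in training mode, like the script above

@torch.no_grad()
def generate_step(x):
    # No autograd graph is recorded here, even though the module is in train mode.
    return model(x)

out = generate_step(torch.randn(1, 4))
print(out.requires_grad)  # False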
@@ -74,16 +73,14 @@ print(decode_tokens)
# 日本的首都东京。<|im_end|><|endoftext|>
# # First round of dialogue
# response, history, decode_tokens = model.chat(tokenizer, "你好", "", history=None)
# print(decode_tokens)
# # 你好!很高兴为你提供帮助。
# Second round of dialogue
-# response, history = model.chat(tokenizer, "给我讲一个年轻人奋斗创业最终取得成功的故事。", history=None)
-# print(response)
+response, history, decode_tokens = model.chat(tokenizer, "给我讲一个年轻人奋斗创业最终取得成功的故事。", "", history=None)
+print(response)
# <|im_start|>system
@@ -93,4 +90,4 @@ print(decode_tokens)
# <|im_start|>assistant
# 莎士比亚是头一个使用“你好”这个词的文学家,他在《哈姆雷特》中写道:“你是谁?你在哪儿?
# ”他的这一段话,通常被认为是最早的使用“你好”这个词的文学记载。这句话在英国语中非常常见,
# 特别是在正式或礼貌的情况下。<|im_end|><|endoftext|>

Changed file 2 of 2 (the QWen model definition):

@@ -41,8 +41,10 @@ import sys
sys.path.append("..")
from tools import show
+from tools import mem_tracker

-logger = logging.get_logger(__name__)
+# tracker = mem_tracker.MemTracker()
+# tracker.track()

class QWenAttention(nn.Module):
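The commented-out MemTracker calls come from this repo's own tools.mem_tracker module; only MemTracker() and track() are visible here, so its wider API is unknown. As an assumption-labelled alternative, the same kind of check can be sketched with PyTorch's built-in CUDA counters:

import torch

# Illustrative only: measure how much GPU memory one forward pass allocates.
# `model` and `input_ids` are placeholders for the objects built elsewhere.
def report_forward_memory(model, input_ids):
    torch.cuda.reset_peak_memory_stats()
    before = torch.cuda.memory_allocated()
    with torch.no_grad():
        model(input_ids)
    peak = torch.cuda.max_memory_allocated()
    print(f"peak extra memory: {(peak - before) / 2**20:.1f} MiB")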
@@ -110,8 +112,6 @@ class QWenAttention(nn.Module):
        query = apply_rotary_pos_emb(query, q_pos_emb)
        key = apply_rotary_pos_emb(key, k_pos_emb)
-        present = (key, value)
        key_size = key.size(1)
        if key_size > self.seq_length and not self.training:
            seq_start = key.size(1) - query.size(1)
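The surrounding branch (key_size > self.seq_length and not self.training) is where upstream Qwen applies log-n attention scaling, rescaling queries that sit beyond the trained context length at eval time; whether this fork keeps exactly that form is an assumption. A standalone sketch of the scaling:

import math
import torch

seq_length = 2048   # trained context length
ntoks = 4096        # current key length at eval time, longer than seq_length

# Scale factor log_{seq_length}(i) for positions past the trained context, 1.0 otherwise.
logn_list = [math.log(i, seq_length) if i > seq_length else 1.0 for i in range(1, ntoks + 1)]
logn_tensor = torch.tensor(logn_list).view(1, ntoks, 1, 1)

query = torch.randn(1, ntoks, 16, 64)            # (batch, seq, heads, head_dim)
query = query * logn_tensor.expand_as(query)     # only long-range queries are scaled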
@@ -148,8 +148,8 @@ class QWenAttention(nn.Module):
        attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask).transpose(1, 2)
        context_layer = self._merge_heads(attn_output, self.num_heads, self.head_dim)
        attn_output = self.c_proj(context_layer)
-        outputs = (attn_output, present)
-        return outputs
+        return attn_output

class QWenMLP(nn.Module):
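The attention core kept by this hunk is F.scaled_dot_product_attention followed by a transpose and head merge; a minimal sketch of that pattern with illustrative shapes (the plain causal mask below is an assumption, not necessarily the mask this model builds):

import torch
import torch.nn.functional as F

batch, heads, seq, head_dim = 1, 16, 8, 64
query = torch.randn(batch, heads, seq, head_dim)
key = torch.randn(batch, heads, seq, head_dim)
value = torch.randn(batch, heads, seq, head_dim)
causal_mask = torch.ones(seq, seq).tril().bool()  # True = may attend

attn = F.scaled_dot_product_attention(query, key, value, attn_mask=causal_mask)
attn = attn.transpose(1, 2)                                 # (batch, seq, heads, head_dim)
context_layer = attn.reshape(batch, seq, heads * head_dim)  # merge heads, then project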
@@ -199,7 +199,6 @@ class QWenBlock(nn.Module):
            attention_mask=attention_mask,
        )
        attn_output = attn_outputs[0]
-        outputs = attn_outputs[1:]
        residual = hidden_states
        layernorm_input = attn_output + residual
@@ -207,8 +206,7 @@ class QWenBlock(nn.Module):
        residual = layernorm_input
        mlp_output = self.mlp(layernorm_output)
        hidden_states = residual + mlp_output
-        outputs = (hidden_states,) + outputs
-        return outputs
+        return hidden_states
class QWenPreTrainedModel(PreTrainedModel): class QWenPreTrainedModel(PreTrainedModel):
@@ -312,16 +310,13 @@ class QWenModel(QWenPreTrainedModel):
        hidden_states = self.drop(hidden_states)
        output_shape = input_shape + (hidden_states.size(-1),)
-        presents = ()
        all_hidden_states = None
-        for i, block in enumerate(self.h):
-            outputs = block(
+        for block in self.h:
+            hidden_states = block(
                hidden_states,
                rotary_pos_emb_list=rotary_pos_emb_list,
                attention_mask=attention_mask,
            )
-            hidden_states = outputs[0]
-            presents = presents + (outputs[1],)

        hidden_states = self.ln_f(hidden_states)
        hidden_states = hidden_states.view(output_shape)
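After this change each block consumes and returns a plain hidden-states tensor, so the transformer body reduces to a simple loop with no presents/outputs tuples to unpack. A toy sketch of that contract (ToyBlock is illustrative only, not the QWen block):

import torch
import torch.nn as nn

class ToyBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.mlp = nn.Linear(dim, dim)

    def forward(self, hidden_states):
        # tensor in, tensor out: nothing else for the caller to accumulate
        return hidden_states + self.mlp(self.norm(hidden_states))

blocks = nn.ModuleList([ToyBlock(64) for _ in range(4)])
hidden_states = torch.randn(1, 8, 64)
for block in blocks:
    hidden_states = block(hidden_states)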
@@ -392,6 +387,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
            attentions=transformer_outputs.attentions,
        )

+    @torch.no_grad()
    def chat(
        self,
        tokenizer: PreTrainedTokenizer,
@@ -454,15 +450,9 @@ class QWenLMHeadModel(QWenPreTrainedModel):
        # 2. Set generation parameters if not already defined
        if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
-            if model_kwargs.get("attention_mask", None) is None:
-                logger.warning(
-                    "The attention mask and the pad token id were not set. As a consequence, you may observe "
-                    "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
-                )
            eos_token_id = generation_config.eos_token_id
            if isinstance(eos_token_id, list):
                eos_token_id = eos_token_id[0]
-            logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
            generation_config.pad_token_id = eos_token_id

        # 3. Define model inputs
@@ -571,7 +561,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
            if this_peer_finished:
                break

        return input_ids
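Since the attention layers no longer return a (key, value) cache, a decoding loop like the one ending here re-runs the model over the full sequence at every step. A cache-free greedy-decoding sketch of that behaviour (model, input_ids and eos_id are placeholders; the real method also handles stop words and sampling):

import torch

@torch.no_grad()
def greedy_decode(model, input_ids, max_new_tokens, eos_id):
    for _ in range(max_new_tokens):
        logits = model(input_ids).logits                   # full-sequence forward, no KV cache
        next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_id], dim=-1)
        if next_id.item() == eos_id:                       # assumes batch size 1
            break
    return input_ids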