Refine model of qwen for long sequence in eval.
This commit is contained in:
parent
45c2f532ff
commit
f96bcc799c
15
qwen/demo.py
15
qwen/demo.py
|
@ -52,11 +52,10 @@ print(model)
|
|||
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
|
||||
model = model.from_pretrained(
|
||||
model_dir, config=config, device_map="auto", trust_remote_code=True
|
||||
).train()
|
||||
# model.train()
|
||||
# model.zero_grad()
|
||||
model = model.from_pretrained(model_dir, config=config, device_map="auto", trust_remote_code=True)
|
||||
|
||||
# model = model.eval()
|
||||
model = model.train() # control by @torch.no_grad()
|
||||
|
||||
# 可指定不同的生成长度、top_p等相关超参
|
||||
# model.generation_config = GenerationConfig.from_pretrained(
|
||||
|
@ -74,16 +73,14 @@ print(decode_tokens)
|
|||
# 日本的首都东京。<|im_end|><|endoftext|>
|
||||
|
||||
|
||||
|
||||
|
||||
# # 第一轮对话
|
||||
# response, history, decode_tokens = model.chat(tokenizer, "你好", "", history=None)
|
||||
# print(decode_tokens)
|
||||
# # 你好!很高兴为你提供帮助。
|
||||
|
||||
# 第二轮对话
|
||||
# response, history = model.chat(tokenizer, "给我讲一个年轻人奋斗创业最终取得成功的故事。", history=None)
|
||||
# print(response)
|
||||
response, history, decode_tokens = model.chat(tokenizer, "给我讲一个年轻人奋斗创业最终取得成功的故事。", "", history=None)
|
||||
print(response)
|
||||
|
||||
|
||||
# <|im_start|>system
|
||||
|
|
|
@ -41,8 +41,10 @@ import sys
|
|||
|
||||
sys.path.append("..")
|
||||
from tools import show
|
||||
from tools import mem_tracker
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
# tracker = mem_tracker.MemTracker()
|
||||
# tracker.track()
|
||||
|
||||
|
||||
class QWenAttention(nn.Module):
|
||||
|
@ -110,8 +112,6 @@ class QWenAttention(nn.Module):
|
|||
query = apply_rotary_pos_emb(query, q_pos_emb)
|
||||
key = apply_rotary_pos_emb(key, k_pos_emb)
|
||||
|
||||
present = (key, value)
|
||||
|
||||
key_size = key.size(1)
|
||||
if key_size > self.seq_length and not self.training:
|
||||
seq_start = key.size(1) - query.size(1)
|
||||
|
@ -148,8 +148,8 @@ class QWenAttention(nn.Module):
|
|||
attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask).transpose(1, 2)
|
||||
context_layer = self._merge_heads(attn_output, self.num_heads, self.head_dim)
|
||||
attn_output = self.c_proj(context_layer)
|
||||
outputs = (attn_output, present)
|
||||
return outputs
|
||||
|
||||
return attn_output
|
||||
|
||||
|
||||
class QWenMLP(nn.Module):
|
||||
|
@ -199,7 +199,6 @@ class QWenBlock(nn.Module):
|
|||
attention_mask=attention_mask,
|
||||
)
|
||||
attn_output = attn_outputs[0]
|
||||
outputs = attn_outputs[1:]
|
||||
residual = hidden_states
|
||||
layernorm_input = attn_output + residual
|
||||
|
||||
|
@ -207,8 +206,7 @@ class QWenBlock(nn.Module):
|
|||
residual = layernorm_input
|
||||
mlp_output = self.mlp(layernorm_output)
|
||||
hidden_states = residual + mlp_output
|
||||
outputs = (hidden_states,) + outputs
|
||||
return outputs
|
||||
return hidden_states
|
||||
|
||||
|
||||
class QWenPreTrainedModel(PreTrainedModel):
|
||||
|
@ -312,16 +310,13 @@ class QWenModel(QWenPreTrainedModel):
|
|||
hidden_states = self.drop(hidden_states)
|
||||
output_shape = input_shape + (hidden_states.size(-1),)
|
||||
|
||||
presents = ()
|
||||
all_hidden_states = None
|
||||
for i, block in enumerate(self.h):
|
||||
outputs = block(
|
||||
for block in self.h:
|
||||
hidden_states = block(
|
||||
hidden_states,
|
||||
rotary_pos_emb_list=rotary_pos_emb_list,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
hidden_states = outputs[0]
|
||||
presents = presents + (outputs[1],)
|
||||
|
||||
hidden_states = self.ln_f(hidden_states)
|
||||
hidden_states = hidden_states.view(output_shape)
|
||||
|
@ -392,6 +387,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
|||
attentions=transformer_outputs.attentions,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def chat(
|
||||
self,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
|
@ -454,15 +450,9 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
|||
# 2. Set generation parameters if not already defined
|
||||
|
||||
if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
|
||||
if model_kwargs.get("attention_mask", None) is None:
|
||||
logger.warning(
|
||||
"The attention mask and the pad token id were not set. As a consequence, you may observe "
|
||||
"unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
|
||||
)
|
||||
eos_token_id = generation_config.eos_token_id
|
||||
if isinstance(eos_token_id, list):
|
||||
eos_token_id = eos_token_id[0]
|
||||
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
|
||||
generation_config.pad_token_id = eos_token_id
|
||||
|
||||
# 3. Define model inputs
|
||||
|
@ -571,7 +561,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
|||
|
||||
if this_peer_finished:
|
||||
break
|
||||
|
||||
return input_ids
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue