Refine model of qwen.

This commit is contained in:
Colin 2024-01-20 20:47:26 +08:00
parent 4d493014ba
commit 04f9fe002f
2 changed files with 9 additions and 41 deletions

View File

@ -54,8 +54,8 @@ print(model)
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True) tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
model = model.from_pretrained(model_dir).cuda() model = model.from_pretrained(model_dir).cuda()
# model = model.eval() model = model.eval()
model = model.train() # control by @torch.no_grad() # model = model.train() # control by @torch.no_grad()
# 可指定不同的生成长度、top_p等相关超参 # 可指定不同的生成长度、top_p等相关超参
# model.generation_config = GenerationConfig.from_pretrained( # model.generation_config = GenerationConfig.from_pretrained(

View File

@ -40,7 +40,6 @@ from safetensors import safe_open
from safetensors.torch import load_file as safe_load_file from safetensors.torch import load_file as safe_load_file
from safetensors.torch import save_file as safe_save_file from safetensors.torch import save_file as safe_save_file
import sys import sys
sys.path.append("..") sys.path.append("..")
@ -235,45 +234,16 @@ class QWenModel(QWenPreTrainedModel):
eps=config.layer_norm_epsilon, eps=config.layer_norm_epsilon,
) )
def get_ntk_alpha(self, true_seq_len):
context_value = math.log(true_seq_len / self.seq_length, 2) + 1
ntk_alpha = 2 ** math.ceil(context_value) - 1
ntk_alpha = max(ntk_alpha, 1)
return ntk_alpha
def forward( def forward(
self, self,
input_ids: Optional[torch.LongTensor] = None, input_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
): ):
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size() input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1]) input_ids = input_ids.view(-1, input_shape[-1])
batch_size = input_ids.shape[0] batch_size = input_ids.shape[0]
elif inputs_embeds is not None: hidden_states = self.wte(input_ids)
input_shape = inputs_embeds.size()[:-1]
batch_size = inputs_embeds.shape[0]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if inputs_embeds is None:
inputs_embeds = self.wte(input_ids)
hidden_states = inputs_embeds
kv_seq_len = hidden_states.size()[1] kv_seq_len = hidden_states.size()[1]
rotary_pos_emb_list = [self.rotary_emb(kv_seq_len, ntk_alpha=1.0)]
if self.training or not self.use_dynamic_ntk:
ntk_alpha_list = [1.0]
elif kv_seq_len != hidden_states.size()[1]:
ntk_alpha_list = self.rotary_emb._ntk_alpha_cached_list
else:
ntk_alpha_list = []
ntk_alpha = self.get_ntk_alpha(kv_seq_len)
ntk_alpha_list.append(ntk_alpha)
self.rotary_emb._ntk_alpha_cached_list = ntk_alpha_list
rotary_pos_emb_list = [self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha) for ntk_alpha in ntk_alpha_list]
hidden_states = self.drop(hidden_states) hidden_states = self.drop(hidden_states)
output_shape = input_shape + (hidden_states.size(-1),) output_shape = input_shape + (hidden_states.size(-1),)
@ -296,19 +266,17 @@ class QWenLMHeadModel(nn.Module):
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.generation_config = GenerationConfig.from_model_config(config) self.generation_config = GenerationConfig.from_model_config(config)
def prepare_inputs_for_generation(self, input_ids, inputs_embeds=None, **kwargs): def prepare_inputs_for_generation(self, input_ids, **kwargs):
model_inputs = {"input_ids": input_ids} model_inputs = {"input_ids": input_ids}
return model_inputs return model_inputs
def forward( def forward(
self, self,
input_ids: Optional[torch.LongTensor] = None, input_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None, labels: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]: ) -> Union[Tuple, CausalLMOutputWithPast]:
transformer_outputs = self.transformer( transformer_outputs = self.transformer(
input_ids, input_ids,
inputs_embeds=inputs_embeds,
) )
hidden_states = transformer_outputs[0] hidden_states = transformer_outputs[0]
@ -520,7 +488,7 @@ class RotaryEmbedding(torch.nn.Module):
self._rotary_pos_emb_cache = None self._rotary_pos_emb_cache = None
self._seq_len_cached = 0 self._seq_len_cached = 0
self._ntk_alpha_cached = 1.0 self._ntk_alpha_cached = 1.0
self._ntk_alpha_cached_list = [1.0] # self._ntk_alpha_cached_list = [1.0]
def update_rotary_pos_emb_cache(self, seqlen, ntk_alpha=1.0): def update_rotary_pos_emb_cache(self, seqlen, ntk_alpha=1.0):
if seqlen > self._seq_len_cached or ntk_alpha != self._ntk_alpha_cached: if seqlen > self._seq_len_cached or ntk_alpha != self._ntk_alpha_cached: