Refine the Qwen model.
This commit is contained in:
parent
4d493014ba
commit
40ae899515
|
@ -54,8 +54,8 @@ print(model)
|
||||||
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
|
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
|
||||||
model = model.from_pretrained(model_dir).cuda()
|
model = model.from_pretrained(model_dir).cuda()
|
||||||
|
|
||||||
# model = model.eval()
|
model = model.eval()
|
||||||
model = model.train() # control by @torch.no_grad()
|
# model = model.train() # control by @torch.no_grad()
|
||||||
|
|
||||||
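# Different generation lengths, top_p, and other related hyperparameters can be specified
|
# Different generation lengths, top_p, and other related hyperparameters can be specified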
|
||||||
# model.generation_config = GenerationConfig.from_pretrained(
|
# model.generation_config = GenerationConfig.from_pretrained(
|
||||||
|
|
|
@ -40,7 +40,6 @@ from safetensors import safe_open
|
||||||
from safetensors.torch import load_file as safe_load_file
|
from safetensors.torch import load_file as safe_load_file
|
||||||
from safetensors.torch import save_file as safe_save_file
|
from safetensors.torch import save_file as safe_save_file
|
||||||
|
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
sys.path.append("..")
|
sys.path.append("..")
|
||||||
|
@ -235,45 +234,16 @@ class QWenModel(QWenPreTrainedModel):
|
||||||
eps=config.layer_norm_epsilon,
|
eps=config.layer_norm_epsilon,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_ntk_alpha(self, true_seq_len):
|
|
||||||
context_value = math.log(true_seq_len / self.seq_length, 2) + 1
|
|
||||||
ntk_alpha = 2 ** math.ceil(context_value) - 1
|
|
||||||
ntk_alpha = max(ntk_alpha, 1)
|
|
||||||
return ntk_alpha
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: Optional[torch.LongTensor] = None,
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
|
||||||
):
|
):
|
||||||
if input_ids is not None and inputs_embeds is not None:
|
input_shape = input_ids.size()
|
||||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
input_ids = input_ids.view(-1, input_shape[-1])
|
||||||
elif input_ids is not None:
|
batch_size = input_ids.shape[0]
|
||||||
input_shape = input_ids.size()
|
hidden_states = self.wte(input_ids)
|
||||||
input_ids = input_ids.view(-1, input_shape[-1])
|
|
||||||
batch_size = input_ids.shape[0]
|
|
||||||
elif inputs_embeds is not None:
|
|
||||||
input_shape = inputs_embeds.size()[:-1]
|
|
||||||
batch_size = inputs_embeds.shape[0]
|
|
||||||
else:
|
|
||||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
|
||||||
|
|
||||||
if inputs_embeds is None:
|
|
||||||
inputs_embeds = self.wte(input_ids)
|
|
||||||
hidden_states = inputs_embeds
|
|
||||||
|
|
||||||
kv_seq_len = hidden_states.size()[1]
|
kv_seq_len = hidden_states.size()[1]
|
||||||
|
rotary_pos_emb_list = [self.rotary_emb(kv_seq_len, ntk_alpha=1.0)]
|
||||||
if self.training or not self.use_dynamic_ntk:
|
|
||||||
ntk_alpha_list = [1.0]
|
|
||||||
elif kv_seq_len != hidden_states.size()[1]:
|
|
||||||
ntk_alpha_list = self.rotary_emb._ntk_alpha_cached_list
|
|
||||||
else:
|
|
||||||
ntk_alpha_list = []
|
|
||||||
ntk_alpha = self.get_ntk_alpha(kv_seq_len)
|
|
||||||
ntk_alpha_list.append(ntk_alpha)
|
|
||||||
self.rotary_emb._ntk_alpha_cached_list = ntk_alpha_list
|
|
||||||
rotary_pos_emb_list = [self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha) for ntk_alpha in ntk_alpha_list]
|
|
||||||
|
|
||||||
hidden_states = self.drop(hidden_states)
|
hidden_states = self.drop(hidden_states)
|
||||||
output_shape = input_shape + (hidden_states.size(-1),)
|
output_shape = input_shape + (hidden_states.size(-1),)
|
||||||
|
@ -296,19 +266,17 @@ class QWenLMHeadModel(nn.Module):
|
||||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||||
self.generation_config = GenerationConfig.from_model_config(config)
|
self.generation_config = GenerationConfig.from_model_config(config)
|
||||||
|
|
||||||
def prepare_inputs_for_generation(self, input_ids, inputs_embeds=None, **kwargs):
|
def prepare_inputs_for_generation(self, input_ids, **kwargs):
|
||||||
model_inputs = {"input_ids": input_ids}
|
model_inputs = {"input_ids": input_ids}
|
||||||
return model_inputs
|
return model_inputs
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: Optional[torch.LongTensor] = None,
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
|
||||||
labels: Optional[torch.LongTensor] = None,
|
labels: Optional[torch.LongTensor] = None,
|
||||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||||
transformer_outputs = self.transformer(
|
transformer_outputs = self.transformer(
|
||||||
input_ids,
|
input_ids,
|
||||||
inputs_embeds=inputs_embeds,
|
|
||||||
)
|
)
|
||||||
hidden_states = transformer_outputs[0]
|
hidden_states = transformer_outputs[0]
|
||||||
|
|
||||||
|
@ -377,16 +345,13 @@ class QWenLMHeadModel(nn.Module):
|
||||||
def _load_pretrained_model(cls, resolved_archive_file):
|
def _load_pretrained_model(cls, resolved_archive_file):
|
||||||
start_prefix = ""
|
start_prefix = ""
|
||||||
model_to_load = cls
|
model_to_load = cls
|
||||||
error_msgs = []
|
|
||||||
if len(resolved_archive_file) > 1:
|
if len(resolved_archive_file) > 1:
|
||||||
resolved_archive_file = tqdm_lib.tqdm(resolved_archive_file, desc="Loading checkpoint shards")
|
resolved_archive_file = tqdm_lib.tqdm(resolved_archive_file, desc="Loading checkpoint shards")
|
||||||
for shard_file in resolved_archive_file:
|
for shard_file in resolved_archive_file:
|
||||||
state_dict = safe_load_file(shard_file)
|
state_dict = safe_load_file(shard_file)
|
||||||
|
cls._load_state_dict_into_model(model_to_load, state_dict, start_prefix)
|
||||||
error_msgs += cls._load_state_dict_into_model(model_to_load, state_dict, start_prefix)
|
|
||||||
del state_dict # force memory release
|
del state_dict # force memory release
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
print(f"All model checkpoint weights were used when initializing {cls.__class__.__name__}.\n")
|
print(f"All model checkpoint weights were used when initializing {cls.__class__.__name__}.\n")
|
||||||
return cls
|
return cls
|
||||||
|
|
||||||
|
@ -520,7 +485,6 @@ class RotaryEmbedding(torch.nn.Module):
|
||||||
self._rotary_pos_emb_cache = None
|
self._rotary_pos_emb_cache = None
|
||||||
self._seq_len_cached = 0
|
self._seq_len_cached = 0
|
||||||
self._ntk_alpha_cached = 1.0
|
self._ntk_alpha_cached = 1.0
|
||||||
self._ntk_alpha_cached_list = [1.0]
|
|
||||||
|
|
||||||
def update_rotary_pos_emb_cache(self, seqlen, ntk_alpha=1.0):
|
def update_rotary_pos_emb_cache(self, seqlen, ntk_alpha=1.0):
|
||||||
if seqlen > self._seq_len_cached or ntk_alpha != self._ntk_alpha_cached:
|
if seqlen > self._seq_len_cached or ntk_alpha != self._ntk_alpha_cached:
|
||||||
|
|
Loading…
Reference in New Issue