Refine model.
commit 185caa12e9
parent 84938e565e
@@ -392,19 +392,15 @@ class GLMTransformer(torch.nn.Module):
     def forward(
         self,
         hidden_states,
-        rotary_pos_emb,
-        use_cache: Optional[bool] = True,
+        rotary_pos_emb
     ):
         kv_caches = [None for _ in range(self.num_layers)]
-        presents = () if use_cache else None

         for index in range(self.num_layers):
             layer = self.layers[index]
             hidden_states, kv_cache = layer(
                 hidden_states, rotary_pos_emb, kv_cache=kv_caches[index]
             )
-            if use_cache:
-                presents = presents + (kv_cache,)
         hidden_states = self.final_layernorm(hidden_states)
         return hidden_states

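For context on what this hunk leaves behind: with use_cache gone, the stack simply threads hidden_states through each layer and discards the kv_cache each layer returns. Below is a minimal runnable sketch of that control flow; ToyLayer and ToyTransformer are illustrative stand-ins, not the repository's classes.

import torch

class ToyLayer(torch.nn.Module):
    # Stands in for a GLM block: returns (hidden_states, kv_cache).
    def __init__(self, dim):
        super().__init__()
        self.proj = torch.nn.Linear(dim, dim)

    def forward(self, hidden_states, rotary_pos_emb, kv_cache=None):
        return self.proj(hidden_states), kv_cache

class ToyTransformer(torch.nn.Module):
    def __init__(self, dim=16, num_layers=2):
        super().__init__()
        self.num_layers = num_layers
        self.layers = torch.nn.ModuleList(ToyLayer(dim) for _ in range(num_layers))
        self.final_layernorm = torch.nn.LayerNorm(dim)

    def forward(self, hidden_states, rotary_pos_emb):
        kv_caches = [None for _ in range(self.num_layers)]
        for index in range(self.num_layers):
            # kv_cache comes back from each layer but, as in the diff,
            # is no longer accumulated into a `presents` tuple.
            hidden_states, kv_cache = self.layers[index](
                hidden_states, rotary_pos_emb, kv_cache=kv_caches[index]
            )
        return self.final_layernorm(hidden_states)

out = ToyTransformer()(torch.randn(4, 3, 16), rotary_pos_emb=None)
print(out.shape)  # torch.Size([4, 3, 16])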
@@ -472,7 +468,6 @@ class ChatGLMModel(nn.Module):
         self,
         input_ids,
         position_ids: Optional[torch.Tensor] = None,
-        use_cache: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_last_logit: Optional[bool] = False,
     ):
@@ -481,7 +476,6 @@ class ChatGLMModel(nn.Module):
             if output_hidden_states is not None
             else self.config.output_hidden_states
         )
-        use_cache = use_cache if use_cache is not None else self.config.use_cache
         batch_size, seq_length = input_ids.shape
         inputs_embeds = self.embedding(input_ids)

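The deleted line here is the usual per-call-override-else-config-default idiom, the same pattern still used for output_hidden_states above it. A small self-contained sketch of the idiom (the Cfg class is illustrative):

class Cfg:
    use_cache = True

def resolve(use_cache=None, config=Cfg()):
    # An explicit per-call value wins; otherwise fall back to the model config.
    return use_cache if use_cache is not None else config.use_cache

print(resolve())       # True  (config default)
print(resolve(False))  # False (explicit override)

Once no caller can pass use_cache, this fallback is dead code, which is why the hunk removes it rather than leaving a silently unused flag.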
@@ -493,8 +487,7 @@ class ChatGLMModel(nn.Module):
         rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
         hidden_states = self.encoder(
             inputs_embeds,
-            rotary_pos_emb=rotary_pos_emb,
-            use_cache=use_cache,
+            rotary_pos_emb=rotary_pos_emb
         )
         if return_last_logit:
             hidden_states = hidden_states[-1:]
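The transpose(0, 1).contiguous() just before the encoder call swaps the two leading axes of the rotary embedding and makes the result contiguous in memory. A standalone illustration; the exact axis meanings ([batch, seq] vs. [seq, batch]) and sizes are assumptions for the example, not taken from the source:

import torch

seq_len, batch, rot_dim = 5, 2, 8
rotary_pos_emb = torch.randn(batch, seq_len, rot_dim)          # assumed [batch, seq, dim]
rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()   # -> [seq, batch, dim]
print(rotary_pos_emb.shape)           # torch.Size([5, 2, 8])
print(rotary_pos_emb.is_contiguous()) # True: safe for .view()-style reshapes downstream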
@@ -683,8 +676,7 @@ class ChatGLMForConditionalGeneration(nn.Module):
             input_ids,
             pad_token_id=generation_config.pad_token_id,
             eos_token_id=generation_config.eos_token_id,
-            output_hidden_states=generation_config.output_hidden_states,
-            use_cache=generation_config.use_cache,
+            output_hidden_states=generation_config.output_hidden_states
         )

         outputs = outputs.tolist()[0][len(inputs["input_ids"][0]) : -1]
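The slice on the last context line trims the generated sequence: it drops the prompt tokens at the front and the final token (typically EOS) at the back, keeping only the newly generated ids. A toy illustration with made-up token ids:

import torch

inputs = {"input_ids": torch.tensor([[11, 12, 13]])}    # prompt of length 3
outputs = torch.tensor([[11, 12, 13, 21, 22, 23, 2]])   # prompt + reply + EOS (2)
new_tokens = outputs.tolist()[0][len(inputs["input_ids"][0]) : -1]
print(new_tokens)  # [21, 22, 23]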
@@ -697,8 +689,7 @@ class ChatGLMForConditionalGeneration(nn.Module):
         input_ids: torch.LongTensor,
         pad_token_id: Optional[int] = None,
         eos_token_id: Optional[Union[int, List[int]]] = None,
-        output_hidden_states: Optional[bool] = None,
-        use_cache: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None
     ):
         if isinstance(eos_token_id, int):
             eos_token_id = [eos_token_id]
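Normalizing eos_token_id to a list, as the two context lines at the end of this hunk do, lets later stopping-criterion code treat single and multiple stop tokens uniformly. A minimal sketch of the same idiom in isolation:

from typing import List, Optional, Union

def normalize_eos(eos_token_id: Optional[Union[int, List[int]]]) -> Optional[List[int]]:
    # Wrap a bare int so downstream code can always iterate.
    if isinstance(eos_token_id, int):
        eos_token_id = [eos_token_id]
    return eos_token_id

print(normalize_eos(2))       # [2]
print(normalize_eos([2, 3]))  # [2, 3]
print(normalize_eos(None))    # None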
@@ -717,12 +708,10 @@ class ChatGLMForConditionalGeneration(nn.Module):
             .unsqueeze(0)
             .repeat(batch_size, 1)
         )
-        use_cache = use_cache if use_cache is not None else self.config.use_cache
         model_inputs = {
            "input_ids": input_ids_in,
            "position_ids": position_ids_in,
-           "return_last_logit": True,
+           "return_last_logit": True
-           "use_cache": use_cache,
        }

        logits = self.transformer(
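The .unsqueeze(0).repeat(batch_size, 1) chain at the top of this hunk broadcasts a single row of positions across the batch before it is packed into model_inputs. A standalone sketch of how such position ids are typically built; the torch.arange source and the concrete sizes are assumptions for the example:

import torch

batch_size, seq_length = 2, 4
position_ids_in = (
    torch.arange(seq_length, dtype=torch.long)
    .unsqueeze(0)           # [1, seq]
    .repeat(batch_size, 1)  # [batch, seq]
)
print(position_ids_in)
# tensor([[0, 1, 2, 3],
#         [0, 1, 2, 3]])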