Refine model.
parent 84938e565e
commit 185caa12e9
@@ -392,19 +392,15 @@ class GLMTransformer(torch.nn.Module):
     def forward(
         self,
         hidden_states,
-        rotary_pos_emb,
-        use_cache: Optional[bool] = True,
+        rotary_pos_emb
     ):
         kv_caches = [None for _ in range(self.num_layers)]
-        presents = () if use_cache else None
 
         for index in range(self.num_layers):
             layer = self.layers[index]
             hidden_states, kv_cache = layer(
                 hidden_states, rotary_pos_emb, kv_cache=kv_caches[index]
             )
-            if use_cache:
-                presents = presents + (kv_cache,)
         hidden_states = self.final_layernorm(hidden_states)
         return hidden_states
 
@@ -472,7 +468,6 @@ class ChatGLMModel(nn.Module):
         self,
         input_ids,
         position_ids: Optional[torch.Tensor] = None,
-        use_cache: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_last_logit: Optional[bool] = False,
     ):
@@ -481,7 +476,6 @@ class ChatGLMModel(nn.Module):
             if output_hidden_states is not None
             else self.config.output_hidden_states
         )
-        use_cache = use_cache if use_cache is not None else self.config.use_cache
         batch_size, seq_length = input_ids.shape
         inputs_embeds = self.embedding(input_ids)
 
@@ -493,8 +487,7 @@ class ChatGLMModel(nn.Module):
         rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
         hidden_states = self.encoder(
             inputs_embeds,
-            rotary_pos_emb=rotary_pos_emb,
-            use_cache=use_cache,
+            rotary_pos_emb=rotary_pos_emb
         )
         if return_last_logit:
             hidden_states = hidden_states[-1:]
@@ -683,8 +676,7 @@ class ChatGLMForConditionalGeneration(nn.Module):
             input_ids,
             pad_token_id=generation_config.pad_token_id,
             eos_token_id=generation_config.eos_token_id,
-            output_hidden_states=generation_config.output_hidden_states,
-            use_cache=generation_config.use_cache,
+            output_hidden_states=generation_config.output_hidden_states
         )
 
         outputs = outputs.tolist()[0][len(inputs["input_ids"][0]) : -1]
@@ -697,8 +689,7 @@ class ChatGLMForConditionalGeneration(nn.Module):
         input_ids: torch.LongTensor,
         pad_token_id: Optional[int] = None,
         eos_token_id: Optional[Union[int, List[int]]] = None,
-        output_hidden_states: Optional[bool] = None,
-        use_cache: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None
     ):
         if isinstance(eos_token_id, int):
             eos_token_id = [eos_token_id]
@@ -717,12 +708,10 @@ class ChatGLMForConditionalGeneration(nn.Module):
                 .unsqueeze(0)
                 .repeat(batch_size, 1)
             )
-            use_cache = use_cache if use_cache is not None else self.config.use_cache
             model_inputs = {
                 "input_ids": input_ids_in,
                 "position_ids": position_ids_in,
-                "return_last_logit": True,
-                "use_cache": use_cache,
+                "return_last_logit": True
            }
 
             logits = self.transformer(
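For reference, a minimal sketch of GLMTransformer.forward as it reads after this commit, assembled from the first hunk above. Note that the kv_caches list of None placeholders survives the refactor and each layer still receives kv_cache=None; only the use_cache flag and the presents accumulator are removed.

    def forward(
        self,
        hidden_states,
        rotary_pos_emb
    ):
        # Placeholder caches: every layer is called with kv_cache=None.
        kv_caches = [None for _ in range(self.num_layers)]

        for index in range(self.num_layers):
            layer = self.layers[index]
            # The per-layer kv_cache return value is no longer accumulated.
            hidden_states, kv_cache = layer(
                hidden_states, rotary_pos_emb, kv_cache=kv_caches[index]
            )
        hidden_states = self.final_layernorm(hidden_states)
        return hidden_states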