Update code.
parent fa7078b72d
commit 0bc7bc90b1

Readme.md (51 lines changed)
@@ -5,24 +5,24 @@
 
 input_ids = tokenizer.build_chat_input(query, history=history, role=role)
 
-for
-  input_ids -> [1, 6]  1:batch_num  6:sequence_length
-  inputs_embeds -> [6, 1, 4096]  4096:hidden_size
-  rotary_pos_emb -> [6, 1, 32, 2]  32:positional encoding dims  2:cos+sin
+for:
+    input_ids -> [1, 6]  1:batch_num  6:sequence_length
+    inputs_embeds -> [6, 1, 4096]  4096:hidden_size
+    rotary_pos_emb -> [6, 1, 32, 2]  32:positional encoding dims  2:cos+sin
 
-  hidden_states = inputs_embeds
-  for layers : GLMBlock(hidden_states, rotary_pos_emb)
-  hidden_states = RMSNorm(hidden_states)
-  hidden_states = hidden_states[-1:]  keep only the last sequence position
-  lm_logits = self.output_layer(hidden_states)
-  lm_logits = lm_logits.transpose(0, 1).contiguous()  -> [1, 1, 65024]
+    hidden_states = inputs_embeds
+    for layers : GLMBlock(hidden_states, rotary_pos_emb)
+    hidden_states = RMSNorm(hidden_states)
+    hidden_states = hidden_states[-1:]  keep only the last sequence position
+    lm_logits = self.output_layer(hidden_states)
+    lm_logits = lm_logits.transpose(0, 1).contiguous()  -> [1, 1, 65024]
 
-  probs = softmax(lm_logits) -> [1, 65024]
-  next_tokens = torch.multinomial(probs, num_samples=1)  sampling  -> [1]  1:batch_num
+    probs = softmax(lm_logits) -> [1, 65024]
+    next_tokens = torch.multinomial(probs, num_samples=1)  sampling  -> [1]  1:batch_num
 
-  if next_tokens == eos_token_id  inference is done, break out of the loop
+    if next_tokens == eos_token_id  inference is done, break out of the loop
 
-  input_ids = torch.cat([input_ids, next_tokens])  -> [1, 7]  1:batch_num
+    input_ids = torch.cat([input_ids, next_tokens])  -> [1, 7]  1:batch_num
 
 response = tokenizer.decode(outputs)
 
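Read as plain Python, the loop in the notes above is a standard sample-and-append decode. A minimal sketch, assuming a `model` callable that returns logits over the 65024-token vocabulary and a `tokenizer.build_chat_input` that yields the prompt ids as a `[1, seq_len]` tensor; the helper name and `max_new_tokens` argument are illustrative, not part of the actual code:

```python
import torch
import torch.nn.functional as F

# Hedged sketch of the decode loop described in the notes above.
def sample_loop(model, tokenizer, query, history, role, eos_token_id, max_new_tokens=512):
    input_ids = tokenizer.build_chat_input(query, history=history, role=role)  # [1, seq_len]
    for _ in range(max_new_tokens):
        lm_logits = model(input_ids)                            # [1, seq_len, 65024]
        probs = F.softmax(lm_logits[:, -1], dim=-1)             # keep only the last position -> [1, 65024]
        next_tokens = torch.multinomial(probs, num_samples=1)   # sample one token id -> [1, 1]
        if next_tokens.item() == eos_token_id:                  # EOS: inference is done
            break
        input_ids = torch.cat([input_ids, next_tokens], dim=-1) # grow the sequence: [1, seq_len + 1]
    return tokenizer.decode(input_ids[0].tolist())
```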
@@ -40,19 +40,6 @@ Linear(hidden_states)  no bias  ->  [6, 1, 27392]
 silu(x) = [6, 1, 13696] * sigmoid([6, 1, 13696])
 Linear(intermediate_parallel)  no bias  ->  [6, 1, 4096]
 
-## core_attention
-
-query_layer = query_layer.permute(1, 2, 0, 3)  ->  [1, 32, 6, 128]
-key_layer = key_layer.permute(1, 2, 0, 3)  ->  [1, 32, 6, 128]
-value_layer = value_layer.permute(1, 2, 0, 3)  ->  [1, 32, 6, 128]
-context_layer = scaled_dot_product_attention(query_layer, key_layer, value_layer)  ->  [1, 32, 6, 128]
-    softmax(QK^T/sqrt(in_dim))V
-    att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
-    att = F.softmax(att, dim=-1)
-    y = att @ v  ->  (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
-context_layer = context_layer.permute(2, 0, 1, 3)
-context_layer = context_layer.reshape()  ->  [6, 1, 4096]
-
 ## self_attention
 
 hidden_states: [s, b, h]
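The MLP shapes in the context lines of this hunk (4096 -> 27392 -> 4096 with a silu gate) describe a SwiGLU-style feed-forward. A minimal sketch using the dimensions from the notes; the attribute names are illustrative, not necessarily the model's own:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Sketch of the gated MLP implied by the shapes above: 4096 -> 2*13696 -> 4096.
class GatedMLP(nn.Module):
    def __init__(self, hidden_size=4096, ffn_hidden_size=13696):
        super().__init__()
        self.up_proj = nn.Linear(hidden_size, 2 * ffn_hidden_size, bias=False)  # -> [s, b, 27392]
        self.down_proj = nn.Linear(ffn_hidden_size, hidden_size, bias=False)    # -> [s, b, 4096]

    def forward(self, hidden_states):
        gate, up = self.up_proj(hidden_states).chunk(2, dim=-1)  # two halves of size 13696
        return self.down_proj(F.silu(gate) * up)                 # silu(gate) = gate * sigmoid(gate)
```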
@@ -74,7 +61,15 @@ value_layer = value_layer.unsqueeze(-2)  ->  [6, 1, 2, 1, 128]
 value_layer = value_layer.expand  ->  [6, 1, 2, 16, 128]
 value_layer = value_layer.contiguous().view  ->  [6, 1, 32, 128]
 
-context_layer = self.core_attention(query_layer, key_layer, value_layer)  ->  [6, 1, 4096]
+query_layer permute(1, 2, 0, 3)  ->  [1, 32, 6, 128]
+key_layer permute(1, 2, 0, 3)  ->  [1, 32, 6, 128]
+value_layer permute(1, 2, 0, 3)  ->  [1, 32, 6, 128]
+context_layer = scaled_dot_product_attention(query_layer, key_layer, value_layer)  ->  [1, 32, 6, 128]
+    att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
+    att = F.softmax(att, dim=-1)
+    y = att @ v  ->  (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
+context_layer = context_layer.permute(2, 0, 1, 3).reshape()  ->  [6, 1, 4096]
 
 return Linear(context_layer)  ->  [6, 1, 4096]
 
 ## GLMBlock
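The lines folded into self_attention above quote both the fused call and the manual formula; with no mask and no dropout the two agree, and the earlier unsqueeze/expand/view lines are the multi-query broadcast from 2 shared KV heads to 32 query heads. A small sketch with the shapes from the notes (random tensors, for illustration only):

```python
import math
import torch
import torch.nn.functional as F

# Shapes from the notes: B=1 batch, nh=32 heads, T=6 tokens, hs=128 head dim.
q = torch.randn(1, 32, 6, 128)
k = torch.randn(1, 32, 6, 128)
v = torch.randn(1, 32, 6, 128)

# Manual attention, as quoted in the added lines.
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))  # [1, 32, 6, 6]
att = F.softmax(att, dim=-1)
y_manual = att @ v                                               # [1, 32, 6, 128]

# Fused kernel; identical when no mask and no dropout are applied.
y_fused = F.scaled_dot_product_attention(q, k, v)
print(torch.allclose(y_manual, y_fused, atol=1e-5))

# The unsqueeze/expand/view lines above are the multi-query broadcast:
# 2 shared KV heads are repeated 16x to match the 32 query heads.
k_small = torch.randn(6, 1, 2, 128)                      # [seq, batch, kv_heads, head_dim]
k_big = k_small.unsqueeze(-2).expand(6, 1, 2, 16, 128)   # [6, 1, 2, 16, 128]
k_big = k_big.contiguous().view(6, 1, 32, 128)           # [6, 1, 32, 128]
```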
@@ -22,8 +22,6 @@ from transformers.generation import GenerationConfig
 from chatglm import ChatGLMConfig
 from tools import show
 
-WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
-
 
 class RotaryEmbedding(nn.Module):
     def __init__(self, dim: int, original_impl=False, device=None, dtype=None):
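For context, the RotaryEmbedding class whose signature appears above is what produces the rotary_pos_emb tensor of shape [6, 1, 32, 2] quoted in the Readme notes (32 frequency channels, each a cos/sin pair). A generic sketch of such a cache, assuming 64 rotary dimensions and the usual 10000 base; this illustrates the idea rather than the file's exact implementation:

```python
import torch

# Generic RoPE cache: with 64 rotary dimensions there are 64/2 = 32 frequency
# channels, each holding a (cos, sin) pair -> [seq_len, 32, 2].
def rotary_cache(seq_len: int, dim: int = 64, base: float = 10000.0):
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))  # [32]
    positions = torch.arange(seq_len).float()                           # [seq_len]
    freqs = torch.outer(positions, inv_freq)                            # [seq_len, 32]
    return torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1)    # [seq_len, 32, 2]

cache = rotary_cache(6)       # [6, 32, 2]
cache = cache.unsqueeze(1)    # [6, 1, 32, 2], the shape quoted in the notes
```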
@@ -487,60 +485,26 @@ class ChatGLMForConditionalGeneration(nn.Module):
         self.generation_config = GenerationConfig.from_model_config(config)
 
     def from_pretrained(
-        cls,
-        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
-        config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
-        **kwargs,
+        cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]]
     ):
-        state_dict = kwargs.pop("state_dict", None)
-        _ = kwargs.pop("mirror", None)
-        load_in_8bit = kwargs.pop("load_in_8bit", False)
-        load_in_4bit = kwargs.pop("load_in_4bit", False)
-        subfolder = kwargs.pop("subfolder", "")
+        load_in_8bit = False
+        load_in_4bit = False
 
-        # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
-        # index of the files.
-        sharded_metadata = None
-
-        # if pretrained_model_name_or_path is not None:
         pretrained_model_name_or_path = str(pretrained_model_name_or_path)
-        archive_file = os.path.join(
-            pretrained_model_name_or_path,
-            subfolder,
-            WEIGHTS_INDEX_NAME,
+        resolved_archive_file = os.path.join(
+            pretrained_model_name_or_path, "pytorch_model.bin.index.json"
         )
-        print(f"loading weights file {archive_file}")
-        resolved_archive_file = archive_file
 
+        print(f"loading weights file {resolved_archive_file}")
         with open(resolved_archive_file, "r") as f:
             index = json.loads(f.read())
 
         shard_filenames = sorted(set(index["weight_map"].values()))
-        sharded_metadata = index["metadata"]
-        sharded_metadata["all_checkpoint_keys"] = list(index["weight_map"].keys())
-        sharded_metadata["weight_map"] = index["weight_map"].copy()
 
         resolved_archive_file = [
-            os.path.join(pretrained_model_name_or_path, subfolder, f)
-            for f in shard_filenames
+            os.path.join(pretrained_model_name_or_path, f) for f in shard_filenames
         ]
 
-        loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
-        config.name_or_path = pretrained_model_name_or_path
-        config = copy.deepcopy(config)
-        # make sure we use the model's config since the __init__ call might have copied it
-        config = cls.config
-        # restore default dtype
-        model = cls._load_pretrained_model(
-            state_dict,
-            loaded_state_dict_keys,
-            resolved_archive_file,
-            pretrained_model_name_or_path,
-        )
+        model = cls._load_pretrained_model(resolved_archive_file)
         model.is_loaded_in_4bit = load_in_4bit
         model.is_loaded_in_8bit = load_in_8bit
-        # Set model in evaluation mode to deactivate DropOut modules by default
-        model.eval()
+        model.eval()  # Set model in evaluation mode to deactivate DropOut modules by default
         return model
 
     def _load_state_dict_into_model(self, model_to_load, state_dict, start_prefix):
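The simplified from_pretrained above resolves shard paths from pytorch_model.bin.index.json, whose weight_map maps each parameter name to the shard file that stores it. A standalone sketch of that resolution step, following the usual Hugging Face sharded-checkpoint convention; directory and key names are illustrative:

```python
import json
import os

# Sketch: resolve shard paths from pytorch_model.bin.index.json, mirroring the
# simplified from_pretrained above.
def resolve_shards(model_dir: str):
    with open(os.path.join(model_dir, "pytorch_model.bin.index.json"), "r") as f:
        index = json.load(f)
    # weight_map maps each parameter name to the .bin shard that stores it, e.g.
    # {"transformer.embedding.word_embeddings.weight": "pytorch_model-00001-of-00007.bin", ...}
    shard_filenames = sorted(set(index["weight_map"].values()))
    return [os.path.join(model_dir, f) for f in shard_filenames]

shard_files = resolve_shards("./chatglm-6b")  # hypothetical local checkpoint directory
```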
@@ -564,40 +528,7 @@ class ChatGLMForConditionalGeneration(nn.Module):
         del state_dict
         return error_msgs
 
-    def _load_pretrained_model(
-        cls,
-        state_dict,
-        loaded_keys,
-        resolved_archive_file,
-        pretrained_model_name_or_path,
-    ):
-        # Retrieve missing & unexpected_keys
-        model_state_dict = cls.state_dict()
-        expected_keys = list(model_state_dict.keys())
-        loaded_keys = [key for key in loaded_keys]
-
-        unexpected_keys = set(loaded_keys) - set(expected_keys)
-        model_buffers = {n for n, _ in cls.named_buffers()}
-        unexpected_keys = list(unexpected_keys - model_buffers)
-
-        ptrs = collections.defaultdict(list)
-        for name, tensor in cls.state_dict().items():
-            ptrs[(tensor.device, storage_ptr(tensor), storage_size(tensor))].append(
-                name
-            )
-        # retrieve unintialized modules and initialize before maybe overriding that with the pretrained weights.
-
-        _loaded_keys = loaded_keys
-        for module_name, module in cls.named_modules():
-            loaded_keys = [
-                k.replace(f"{module_name}.", "")
-                for k in _loaded_keys
-                if k.startswith(f"{module_name}.")
-            ]
-            if len(set(module.state_dict().keys()) - set(loaded_keys)) == 0:
-                module._is_hf_initialized = True
-
-        # Make sure we are able to load base models as well as derived models (with heads)
+    def _load_pretrained_model(cls, resolved_archive_file):
         start_prefix = ""
         model_to_load = cls
         error_msgs = []
@@ -606,26 +537,17 @@ class ChatGLMForConditionalGeneration(nn.Module):
                 resolved_archive_file, desc="Loading checkpoint shards"
             )
         for shard_file in resolved_archive_file:
-            # Skip the load for shards that only contain disk-offloaded weights when using safetensors for the offload.
             state_dict = torch.load(shard_file, map_location="cpu")
-            # Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
-            # matching the weights in the model.
 
             error_msgs += cls._load_state_dict_into_model(
                 model_to_load, state_dict, start_prefix
             )
-            # force memory release
-            del state_dict
+            del state_dict  # force memory release
             gc.collect()
 
         print(
             f"All model checkpoint weights were used when initializing {cls.__class__.__name__}.\n"
         )
-        print(
-            f"All the weights of {cls.__class__.__name__} were initialized from the model checkpoint at"
-            f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint"
-            f" was trained on, you can already use {cls.__class__.__name__} for predictions without further"
-            " training."
-        )
         return cls
 
     @torch.inference_mode()
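Taken together, the slimmed-down loader boils down to: read the index, load each shard on CPU, copy its tensors into the module, and free the shard before the next one. A condensed sketch of that loop, using plain load_state_dict(strict=False) and tqdm as stand-ins for the class's own _load_state_dict_into_model helper and progress wrapper:

```python
import gc
import torch
from tqdm.auto import tqdm

# Condensed sketch of the shard-loading loop shown in the diff above; `model`
# stands in for the ChatGLMForConditionalGeneration instance.
def load_sharded(model, shard_files):
    for shard_file in tqdm(shard_files, desc="Loading checkpoint shards"):
        state_dict = torch.load(shard_file, map_location="cpu")  # one shard at a time
        model.load_state_dict(state_dict, strict=False)          # each shard holds only part of the weights
        del state_dict                                            # force memory release
        gc.collect()
    model.eval()  # deactivate dropout for inference
    return model
```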
demo.py (2 lines changed)
@@ -37,7 +37,7 @@ init_kwargs["name_or_path"] = pretrained_model_name_or_path
 tokenizer = ChatGLMTokenizer(*init_inputs, **init_kwargs)
 
 
-glm = glm.from_pretrained(pretrained_model_name_or_path, config=config).half().cuda()
+glm = glm.from_pretrained(pretrained_model_name_or_path).half().cuda()
 glm = glm.eval()
 query = "colin"
 response, history = glm.chat(tokenizer, query, history=[])