From 0bc7bc90b1437ab13d5befb1888791c5c021e62e Mon Sep 17 00:00:00 2001
From: Colin
Date: Mon, 25 Dec 2023 17:26:19 +0800
Subject: [PATCH] Update code.

---
 Readme.md                   |  51 ++++++++----------
 chatglm/modeling_chatglm.py | 102 +++++-------------------------------
 demo.py                     |   2 +-
 3 files changed, 36 insertions(+), 119 deletions(-)

diff --git a/Readme.md b/Readme.md
index 7223c67..4ab6753 100644
--- a/Readme.md
+++ b/Readme.md
@@ -5,24 +5,24 @@
 input_ids = tokenizer.build_chat_input(query, history=history, role=role)
 
-for
-    input_ids -> [1, 6] 1:batch_num 6:sequence_length
-    inputs_embeds -> [6, 1, 4096] 4096:hidden_size
-    rotary_pos_emb -> [6, 1, 32, 2] 32:positional-encoding dim 2:cos+sin
+for:
+    input_ids -> [1, 6] 1:batch_num 6:sequence_length
+    inputs_embeds -> [6, 1, 4096] 4096:hidden_size
+    rotary_pos_emb -> [6, 1, 32, 2] 32:positional-encoding dim 2:cos+sin
 
-    hidden_states = inputs_embeds
-    for layers : GLMBlock(hidden_states, rotary_pos_emb)
-    hidden_states = RMSNorm(hidden_states)
-    hidden_states = hidden_states[-1:]  keep only the last sequence position
-    lm_logits = self.output_layer(hidden_states)
-    lm_logits = lm_logits.transpose(0, 1).contiguous() -> [1, 1, 65024]
+    hidden_states = inputs_embeds
+    for layers : GLMBlock(hidden_states, rotary_pos_emb)
+    hidden_states = RMSNorm(hidden_states)
+    hidden_states = hidden_states[-1:]  keep only the last sequence position
+    lm_logits = self.output_layer(hidden_states)
+    lm_logits = lm_logits.transpose(0, 1).contiguous() -> [1, 1, 65024]
 
-    probs = softmax(lm_logits) -> [1, 65024]
-    next_tokens = torch.multinomial(probs, num_samples=1)  sampling -> [1] 1:batch_num
+    probs = softmax(lm_logits) -> [1, 65024]
+    next_tokens = torch.multinomial(probs, num_samples=1)  sampling -> [1] 1:batch_num
 
-    if next_tokens == eos_token_id: generation is finished, exit the loop
+    if next_tokens == eos_token_id: generation is finished, exit the loop
 
-    input_ids = torch.cat([input_ids, next_tokens]) -> [1, 7] 1:batch_num
+    input_ids = torch.cat([input_ids, next_tokens]) -> [1, 7] 1:batch_num
 
 response = tokenizer.decode(outputs)
 
@@ -40,19 +40,6 @@
 Linear(hidden_states) no bias -> [6, 1, 27392]
 silu (x) = [6, 1, 13696] * sigmoid([6, 1, 13696])
 Linear(intermediate_parallel) no bias -> [6, 1, 4096]
-## core_attention
-
-query_layer=query_layer.permute(1, 2, 0, 3) -> [1, 32, 6, 128]
-key_layer=key_layer.permute(1, 2, 0, 3) -> [1, 32, 6, 128]
-value_layer=value_layer.permute(1, 2, 0, 3) -> [1, 32, 6, 128]
-context_layer = scaled_dot_product_attention(query_layer, key_layer, value_layer) -> [1, 32, 6, 128]
-    softmax(QK^T/sqrt(in_dim))V
-    att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
-    att = F.softmax(att, dim=-1)
-    y = att @ v -> (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
-context_layer = context_layer.permute(2, 0, 1, 3)
-context_layer = context_layer.reshape() -> [6, 1, 4096]
-
 ## self_attention
 
 hidden_states: [s, b, h]
@@ -74,7 +61,15 @@
 value_layer = value_layer.unsqueeze(-2) -> [6, 1, 2, 1, 128]
 value_layer = value_layer.expand -> [6, 1, 2, 16, 128]
 value_layer = value_layer.contiguous().view -> [6, 1, 32, 128]
-context_layer = self.core_attention(query_layer, key_layer, value_layer) -> [6, 1, 4096]
+query_layer permute(1, 2, 0, 3) -> [1, 32, 6, 128]
+key_layer permute(1, 2, 0, 3) -> [1, 32, 6, 128]
+value_layer permute(1, 2, 0, 3) -> [1, 32, 6, 128]
+context_layer = scaled_dot_product_attention(query_layer, key_layer, value_layer) -> [1, 32, 6, 128]
+    att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
+    att = F.softmax(att, dim=-1)
+    y = att @ v -> (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
+context_layer = context_layer.permute(2, 0, 1, 3).reshape() -> [6, 1, 4096]
+
 return Linear(context_layer) -> [6, 1, 4096]
 
 ## GLMBlock
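For readers who want to check the shapes noted in the Readme hunks above, the grouped-query attention path can be reproduced in isolation. The following is a minimal sketch, not code from this repository: the dimensions (sequence length 6, batch 1, 32 query heads, 2 key/value groups, head dim 128) are taken from the shape annotations above, and the random tensors merely stand in for the projected query/key/value.

```python
import torch
import torch.nn.functional as F

s, b, nh, ng, hd = 6, 1, 32, 2, 128  # seq_len, batch, query heads, kv groups, head dim

query_layer = torch.randn(s, b, nh, hd)  # [6, 1, 32, 128]
key_layer = torch.randn(s, b, ng, hd)    # [6, 1, 2, 128]
value_layer = torch.randn(s, b, ng, hd)  # [6, 1, 2, 128]

# Broadcast the 2 key/value groups across the 32 query heads (grouped-query attention).
key_layer = key_layer.unsqueeze(-2).expand(-1, -1, -1, nh // ng, -1)      # [6, 1, 2, 16, 128]
key_layer = key_layer.contiguous().view(s, b, nh, hd)                     # [6, 1, 32, 128]
value_layer = value_layer.unsqueeze(-2).expand(-1, -1, -1, nh // ng, -1)  # [6, 1, 2, 16, 128]
value_layer = value_layer.contiguous().view(s, b, nh, hd)                 # [6, 1, 32, 128]

# [s, b, nh, hd] -> [b, nh, s, hd], the layout scaled_dot_product_attention expects.
q = query_layer.permute(1, 2, 0, 3)  # [1, 32, 6, 128]
k = key_layer.permute(1, 2, 0, 3)    # [1, 32, 6, 128]
v = value_layer.permute(1, 2, 0, 3)  # [1, 32, 6, 128]

# Attention mask omitted, as in the pseudocode above.
context_layer = F.scaled_dot_product_attention(q, k, v)  # [1, 32, 6, 128]

# Back to [s, b, h] with h = nh * hd = 4096, ready for the output Linear.
context_layer = context_layer.permute(2, 0, 1, 3).reshape(s, b, nh * hd)  # [6, 1, 4096]
print(context_layer.shape)  # torch.Size([6, 1, 4096])
```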
diff --git a/chatglm/modeling_chatglm.py b/chatglm/modeling_chatglm.py
index c28edd1..ca8af20 100644
--- a/chatglm/modeling_chatglm.py
+++ b/chatglm/modeling_chatglm.py
@@ -22,8 +22,6 @@ from transformers.generation import GenerationConfig
 from chatglm import ChatGLMConfig
 from tools import show
 
-WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
-
 
 class RotaryEmbedding(nn.Module):
     def __init__(self, dim: int, original_impl=False, device=None, dtype=None):
@@ -487,60 +485,26 @@ class ChatGLMForConditionalGeneration(nn.Module):
         self.generation_config = GenerationConfig.from_model_config(config)
 
     def from_pretrained(
-        cls,
-        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
-        config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
-        **kwargs,
+        cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]]
     ):
-        state_dict = kwargs.pop("state_dict", None)
-        _ = kwargs.pop("mirror", None)
-        load_in_8bit = kwargs.pop("load_in_8bit", False)
-        load_in_4bit = kwargs.pop("load_in_4bit", False)
-        subfolder = kwargs.pop("subfolder", "")
+        load_in_8bit = False
+        load_in_4bit = False
 
-        # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
-        # index of the files.
-        sharded_metadata = None
-
-        # if pretrained_model_name_or_path is not None:
         pretrained_model_name_or_path = str(pretrained_model_name_or_path)
-        archive_file = os.path.join(
-            pretrained_model_name_or_path,
-            subfolder,
-            WEIGHTS_INDEX_NAME,
+        resolved_archive_file = os.path.join(
+            pretrained_model_name_or_path, "pytorch_model.bin.index.json"
         )
-        print(f"loading weights file {archive_file}")
-        resolved_archive_file = archive_file
-
+        print(f"loading weights file {resolved_archive_file}")
         with open(resolved_archive_file, "r") as f:
             index = json.loads(f.read())
 
         shard_filenames = sorted(set(index["weight_map"].values()))
-        sharded_metadata = index["metadata"]
-        sharded_metadata["all_checkpoint_keys"] = list(index["weight_map"].keys())
-        sharded_metadata["weight_map"] = index["weight_map"].copy()
         resolved_archive_file = [
-            os.path.join(pretrained_model_name_or_path, subfolder, f)
-            for f in shard_filenames
+            os.path.join(pretrained_model_name_or_path, f) for f in shard_filenames
         ]
-
-        loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
-        config.name_or_path = pretrained_model_name_or_path
-        config = copy.deepcopy(config)
-        # make sure we use the model's config since the __init__ call might have copied it
-        config = cls.config
-        # restore default dtype
-        model = cls._load_pretrained_model(
-            state_dict,
-            loaded_state_dict_keys,
-            resolved_archive_file,
-            pretrained_model_name_or_path,
-        )
+        model = cls._load_pretrained_model(resolved_archive_file)
         model.is_loaded_in_4bit = load_in_4bit
         model.is_loaded_in_8bit = load_in_8bit
-        # Set model in evaluation mode to deactivate DropOut modules by default
-        model.eval()
+        model.eval()  # Set model in evaluation mode to deactivate DropOut modules by default
         return model
 
     def _load_state_dict_into_model(self, model_to_load, state_dict, start_prefix):
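Summarising the simplified loading path above: from_pretrained now only reads pytorch_model.bin.index.json, collects the shard filenames from weight_map, and hands the list of shard paths to _load_pretrained_model, which loads each shard on CPU and copies it into the module. A rough standalone sketch of that flow is given below; it is an approximation rather than the patched code itself — in particular it uses the stock load_state_dict(strict=False) where the model class uses its own _load_state_dict_into_model helper, and the load_sharded_checkpoint name and model argument are illustrative.

```python
import json
import os

import torch


def load_sharded_checkpoint(model, checkpoint_dir: str):
    # The index file maps every parameter name to the shard file that stores it.
    index_path = os.path.join(checkpoint_dir, "pytorch_model.bin.index.json")
    with open(index_path, "r") as f:
        index = json.load(f)

    shard_filenames = sorted(set(index["weight_map"].values()))
    for shard_file in shard_filenames:
        # Each shard is an ordinary state dict covering a subset of the weights.
        state_dict = torch.load(os.path.join(checkpoint_dir, shard_file), map_location="cpu")
        model.load_state_dict(state_dict, strict=False)
        del state_dict  # free the shard before loading the next one

    return model.eval()  # evaluation mode, as in from_pretrained above
```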
@@ -564,40 +528,7 @@ class ChatGLMForConditionalGeneration(nn.Module):
         del state_dict
         return error_msgs
 
-    def _load_pretrained_model(
-        cls,
-        state_dict,
-        loaded_keys,
-        resolved_archive_file,
-        pretrained_model_name_or_path,
-    ):
-        # Retrieve missing & unexpected_keys
-        model_state_dict = cls.state_dict()
-        expected_keys = list(model_state_dict.keys())
-        loaded_keys = [key for key in loaded_keys]
-
-        unexpected_keys = set(loaded_keys) - set(expected_keys)
-        model_buffers = {n for n, _ in cls.named_buffers()}
-        unexpected_keys = list(unexpected_keys - model_buffers)
-
-        ptrs = collections.defaultdict(list)
-        for name, tensor in cls.state_dict().items():
-            ptrs[(tensor.device, storage_ptr(tensor), storage_size(tensor))].append(
-                name
-            )
-        # retrieve unintialized modules and initialize before maybe overriding that with the pretrained weights.
-
-        _loaded_keys = loaded_keys
-        for module_name, module in cls.named_modules():
-            loaded_keys = [
-                k.replace(f"{module_name}.", "")
-                for k in _loaded_keys
-                if k.startswith(f"{module_name}.")
-            ]
-            if len(set(module.state_dict().keys()) - set(loaded_keys)) == 0:
-                module._is_hf_initialized = True
-
-        # Make sure we are able to load base models as well as derived models (with heads)
+    def _load_pretrained_model(cls, resolved_archive_file):
         start_prefix = ""
         model_to_load = cls
         error_msgs = []
             resolved_archive_file, desc="Loading checkpoint shards"
         )
         for shard_file in resolved_archive_file:
-            # Skip the load for shards that only contain disk-offloaded weights when using safetensors for the offload.
             state_dict = torch.load(shard_file, map_location="cpu")
-            # Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
-            # matching the weights in the model.
+
             error_msgs += cls._load_state_dict_into_model(
                 model_to_load, state_dict, start_prefix
             )
-            # force memory release
-            del state_dict
+            del state_dict  # force memory release
             gc.collect()
 
         print(
             f"All model checkpoint weights were used when initializing {cls.__class__.__name__}.\n"
         )
-        print(
-            f"All the weights of {cls.__class__.__name__} were initialized from the model checkpoint at"
-            f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint"
-            f" was trained on, you can already use {cls.__class__.__name__} for predictions without further"
-            " training."
-        )
         return cls
 
     @torch.inference_mode()
diff --git a/demo.py b/demo.py
index ad1215c..23a4b3a 100644
--- a/demo.py
+++ b/demo.py
@@ -37,7 +37,7 @@
 init_kwargs["name_or_path"] = pretrained_model_name_or_path
 
 tokenizer = ChatGLMTokenizer(*init_inputs, **init_kwargs)
-glm = glm.from_pretrained(pretrained_model_name_or_path, config=config).half().cuda()
+glm = glm.from_pretrained(pretrained_model_name_or_path).half().cuda()
 glm = glm.eval()
 query = "colin"
 response, history = glm.chat(tokenizer, query, history=[])
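Taken together, demo.py builds the tokenizer, loads the model through the simplified from_pretrained, and calls chat(), which internally runs the sampling loop documented in the Readme hunk above. A minimal sketch of that loop is shown below for orientation; forward_last_logits is a hypothetical stand-in for one model forward pass returning logits for the last position only, and is not part of the patch.

```python
import torch


def sample_until_eos(forward_last_logits, input_ids: torch.Tensor, eos_token_id: int,
                     max_new_tokens: int = 256) -> torch.Tensor:
    # input_ids: [1, seq_len], e.g. from tokenizer.build_chat_input(query, history=history, role=role)
    for _ in range(max_new_tokens):
        lm_logits = forward_last_logits(input_ids)              # [1, 1, 65024]
        probs = torch.softmax(lm_logits[:, -1, :], dim=-1)      # [1, 65024]
        next_tokens = torch.multinomial(probs, num_samples=1)   # [1, 1], sampled token id
        if next_tokens.item() == eos_token_id:
            break                                               # generation finished
        input_ids = torch.cat([input_ids, next_tokens], dim=-1) # e.g. [1, 7] after the first step
    return input_ids
```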