import collections
import math
import copy
import os
import gc
import json
import torch
import torch.utils.checkpoint
import torch.nn.functional as F
from torch import nn
from torch.nn.utils import skip_init
from typing import Optional, Tuple, Union, List, Dict, Any
from tqdm import auto as tqdm_lib
from safetensors.torch import storage_ptr, storage_size
from transformers.configuration_utils import PretrainedConfig
from transformers.generation import GenerationConfig

from chatglm import ChatGLMConfig

WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"


class RotaryEmbedding(nn.Module):
    def __init__(self, dim: int, original_impl=False, device=None, dtype=None):
        super().__init__()
        inv_freq = 1.0 / (
            10000 ** (torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim)
        )
        self.register_buffer("inv_freq", inv_freq)
        self.dim = dim
        self.original_impl = original_impl

    def forward(self, max_seq_len: int, base: int = 10000):
        dtype = self.inv_freq.dtype
        device = self.inv_freq.device
        # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
        theta = 1.0 / (
            base
            ** (
                torch.arange(0, self.dim, 2, dtype=torch.float, device=device)
                / self.dim
            )
        )
        # Create position indexes `[0, 1, ..., max_seq_len - 1]`
        seq_idx = torch.arange(max_seq_len, dtype=torch.float, device=device)
        # Calculate the product of position index and $\theta_i$
        idx_theta = torch.outer(seq_idx, theta).float()
        cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)
        # this is to mimic the behaviour of complex32, else we will get different results
        if dtype in (torch.float16, torch.bfloat16, torch.int8):
            cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half()
        return cache


class RMSNorm(torch.nn.Module):
    def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
        super().__init__()
        self.weight = torch.nn.Parameter(
            torch.empty(normalized_shape, device=device, dtype=dtype)
        )
        self.eps = eps

    def forward(self, hidden_states: torch.Tensor):
        input_dtype = hidden_states.dtype
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
        return (self.weight * hidden_states).to(input_dtype)
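
# --- Illustrative sketch (added; not part of the original module) ------------
# The hypothetical helper below, never called by the model code, shows the
# conventions of the two classes above: RotaryEmbedding(dim) precomputes a
# [seq_len, dim/2, 2] table of (cos, sin) pairs, and RMSNorm rescales the last
# dimension by the root mean square without subtracting the mean.
def _rope_and_rmsnorm_demo():
    rope = RotaryEmbedding(dim=64, original_impl=True, dtype=torch.float32)
    cache = rope(max_seq_len=8)
    # cache[p, i] == (cos(p * theta_i), sin(p * theta_i))
    assert cache.shape == (8, 32, 2)

    norm = RMSNorm(16, eps=1e-5, dtype=torch.float32)
    torch.nn.init.ones_(norm.weight)
    x = torch.randn(2, 3, 16)
    y = norm(x)
    # Each vector is divided by sqrt(mean(x^2) + eps); there is no re-centering.
    expected = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-5)
    assert torch.allclose(y, expected, atol=1e-5)
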

class CoreAttention(torch.nn.Module):
    def __init__(self, config: ChatGLMConfig, layer_number):
        super(CoreAttention, self).__init__()
        self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
        self.layer_number = max(1, layer_number)
        projection_size = config.kv_channels * config.num_attention_heads
        # Per attention head and per partition values.
        self.hidden_size_per_partition = projection_size
        self.hidden_size_per_attention_head = (
            projection_size // config.num_attention_heads
        )
        self.num_attention_heads_per_partition = config.num_attention_heads
        # Kept for compatibility; the scaled_dot_product_attention path below
        # does not use norm_factor, coeff or attention_dropout directly.
        coeff = None
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        coeff = self.layer_number
        self.norm_factor *= coeff
        self.coeff = coeff
        self.attention_dropout = torch.nn.Dropout(config.attention_dropout)

    def forward(self, query_layer, key_layer, value_layer, attention_mask):
        # [sq, b, np, hn] --> [b, np, sq, hn]
        query_layer, key_layer, value_layer = [
            k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]
        ]
        if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
            context_layer = torch.nn.functional.scaled_dot_product_attention(
                query_layer, key_layer, value_layer, is_causal=True
            )
        else:
            # Fallback when a mask is given or query/key lengths differ. The
            # boolean mask marks positions to ignore, while SDPA expects True
            # to mean "attend", hence the inversion.
            if attention_mask is not None:
                attention_mask = ~attention_mask
            context_layer = torch.nn.functional.scaled_dot_product_attention(
                query_layer, key_layer, value_layer, attention_mask
            )
        # [b, np, sq, hn] --> [sq, b, np, hn] --> [sq, b, hp]
        context_layer = context_layer.permute(2, 0, 1, 3)
        new_context_layer_shape = context_layer.size()[:-2] + (
            self.hidden_size_per_partition,
        )
        context_layer = context_layer.reshape(*new_context_layer_shape)
        return context_layer


class SelfAttention(torch.nn.Module):
    """Parallel self-attention layer abstract class.

    Self-attention layer takes input with size [s, b, h] and returns output of
    the same size.
    """

    def __init__(self, config: ChatGLMConfig, layer_number, device=None):
        super(SelfAttention, self).__init__()
        self.layer_number = max(1, layer_number)
        self.projection_size = config.kv_channels * config.num_attention_heads
        self.hidden_size_per_attention_head = (
            self.projection_size // config.num_attention_heads
        )
        self.num_attention_heads_per_partition = config.num_attention_heads
        self.multi_query_attention = config.multi_query_attention
        self.qkv_hidden_size = 3 * self.projection_size
        self.num_multi_query_groups_per_partition = config.multi_query_group_num
        # With multi-query attention, K and V only carry multi_query_group_num
        # head groups, so the fused projection is narrower than 3 * projection_size.
        self.qkv_hidden_size = (
            self.projection_size
            + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num
        )
        self.query_key_value = nn.Linear(
            config.hidden_size,
            self.qkv_hidden_size,
            bias=config.add_bias_linear or config.add_qkv_bias,
            device=device,
            dtype=config.torch_dtype,
        )
        self.core_attention = CoreAttention(config, self.layer_number)
        self.dense = nn.Linear(
            self.projection_size,
            config.hidden_size,
            bias=config.add_bias_linear,
            device=device,
            dtype=config.torch_dtype,
        )

    def apply_rotary_pos_emb(
        self, x: torch.Tensor, rope_cache: torch.Tensor
    ) -> torch.Tensor:
        # x: [sq, b, np, hn]
        sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3)
        rot_dim = rope_cache.shape[-2] * 2
        x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
        # truncate to support variable sizes
        rope_cache = rope_cache[:sq]
        xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2)
        rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2)
        x_out2 = torch.stack(
            [
                xshaped[..., 0] * rope_cache[..., 0]
                - xshaped[..., 1] * rope_cache[..., 1],
                xshaped[..., 1] * rope_cache[..., 0]
                + xshaped[..., 0] * rope_cache[..., 1],
            ],
            -1,
        )
        x_out2 = x_out2.flatten(3)
        return torch.cat((x_out2, x_pass), dim=-1)

    def forward(self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None):
        # hidden_states: [sq, b, h]

        # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
        mixed_x_layer = self.query_key_value(hidden_states)

        (query_layer, key_layer, value_layer) = mixed_x_layer.split(
            [
                self.num_attention_heads_per_partition
                * self.hidden_size_per_attention_head,
                self.num_multi_query_groups_per_partition
                * self.hidden_size_per_attention_head,
                self.num_multi_query_groups_per_partition
                * self.hidden_size_per_attention_head,
            ],
            dim=-1,
        )
        query_layer = query_layer.view(
            query_layer.size()[:-1]
            + (
                self.num_attention_heads_per_partition,
                self.hidden_size_per_attention_head,
            )
        )
        key_layer = key_layer.view(
            key_layer.size()[:-1]
            + (
                self.num_multi_query_groups_per_partition,
                self.hidden_size_per_attention_head,
            )
        )
        value_layer = value_layer.view(
            value_layer.size()[:-1]
            + (
                self.num_multi_query_groups_per_partition,
                self.hidden_size_per_attention_head,
            )
        )

        # apply relative positional encoding (rotary embedding)
        if rotary_pos_emb is not None:
            query_layer = self.apply_rotary_pos_emb(query_layer, rotary_pos_emb)
            key_layer = self.apply_rotary_pos_emb(key_layer, rotary_pos_emb)

        kv_cache = (key_layer, value_layer)

        # Expand the K/V head groups so every query head sees a key/value head.
        key_layer = key_layer.unsqueeze(-2)
        key_layer = key_layer.expand(
            -1,
            -1,
            -1,
            self.num_attention_heads_per_partition
            // self.num_multi_query_groups_per_partition,
            -1,
        )
        key_layer = key_layer.contiguous().view(
            key_layer.size()[:2]
            + (
                self.num_attention_heads_per_partition,
                self.hidden_size_per_attention_head,
            )
        )
        value_layer = value_layer.unsqueeze(-2)
        value_layer = value_layer.expand(
            -1,
            -1,
            -1,
            self.num_attention_heads_per_partition
            // self.num_multi_query_groups_per_partition,
            -1,
        )
        value_layer = value_layer.contiguous().view(
            value_layer.size()[:2]
            + (
                self.num_attention_heads_per_partition,
                self.hidden_size_per_attention_head,
            )
        )

        # ==================================
        # core attention computation
        # ==================================
        context_layer = self.core_attention(
            query_layer, key_layer, value_layer, attention_mask
        )

        # =================
        # Output. [sq, b, h]
        # =================
        output = self.dense(context_layer)
        return output, kv_cache
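
# --- Illustrative sketch (added; not part of the original module) ------------
# The hypothetical helper below, never called by the model code, walks through
# the multi-query split and expansion done in SelfAttention.forward. The
# numbers (32 heads, 2 KV groups, 128 channels per head) are assumptions in the
# spirit of ChatGLM2/3-style configs, used only to show the arithmetic.
def _mqa_split_demo():
    num_heads, kv_groups, head_dim = 32, 2, 128
    hidden = num_heads * head_dim                  # 4096
    qkv_width = hidden + 2 * kv_groups * head_dim  # 4096 + 2 * 256 = 4608
    sq, b = 5, 1
    mixed = torch.randn(sq, b, qkv_width)          # output of query_key_value
    q, k, v = mixed.split(
        [num_heads * head_dim, kv_groups * head_dim, kv_groups * head_dim], dim=-1
    )
    q = q.view(sq, b, num_heads, head_dim)
    k = k.view(sq, b, kv_groups, head_dim)
    # Each KV group is shared by num_heads // kv_groups = 16 query heads.
    k = (
        k.unsqueeze(-2)
        .expand(-1, -1, -1, num_heads // kv_groups, -1)
        .contiguous()
        .view(sq, b, num_heads, head_dim)
    )
    assert q.shape == k.shape == (sq, b, num_heads, head_dim)
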

class MLP(torch.nn.Module):
    """MLP.

    MLP will take the input with h hidden state, project it to 4*h hidden
    dimension, perform nonlinear transformation, and project the state back
    into h hidden dimension.
    """

    def __init__(self, config: ChatGLMConfig, device=None):
        super(MLP, self).__init__()
        self.add_bias = config.add_bias_linear
        # Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
        self.dense_h_to_4h = nn.Linear(
            config.hidden_size,
            config.ffn_hidden_size * 2,
            bias=self.add_bias,
            device=device,
            dtype=config.torch_dtype,
        )

        def swiglu(x):
            x = torch.chunk(x, 2, dim=-1)
            return F.silu(x[0]) * x[1]

        self.activation_func = swiglu
        # Project back to h.
        self.dense_4h_to_h = nn.Linear(
            config.ffn_hidden_size,
            config.hidden_size,
            bias=self.add_bias,
            device=device,
            dtype=config.torch_dtype,
        )

    def forward(self, hidden_states):
        # [s, b, 4hp]
        intermediate_parallel = self.dense_h_to_4h(hidden_states)
        intermediate_parallel = self.activation_func(intermediate_parallel)
        # [s, b, h]
        output = self.dense_4h_to_h(intermediate_parallel)
        return output


class GLMBlock(torch.nn.Module):
    """A single transformer layer.

    Transformer layer takes input with size [s, b, h] and returns an output of
    the same size.
    """

    def __init__(self, config: ChatGLMConfig, layer_number, device=None):
        super(GLMBlock, self).__init__()
        self.layer_number = layer_number
        self.apply_residual_connection_post_layernorm = (
            config.apply_residual_connection_post_layernorm
        )
        self.fp32_residual_connection = config.fp32_residual_connection
        LayerNormFunc = RMSNorm
        # Layernorm on the input data.
        self.input_layernorm = LayerNormFunc(
            config.hidden_size,
            eps=config.layernorm_epsilon,
            device=device,
            dtype=config.torch_dtype,
        )
        # Self attention.
        self.self_attention = SelfAttention(config, layer_number, device=device)
        self.hidden_dropout = config.hidden_dropout
        # Layernorm on the attention output
        self.post_attention_layernorm = LayerNormFunc(
            config.hidden_size,
            eps=config.layernorm_epsilon,
            device=device,
            dtype=config.torch_dtype,
        )
        # MLP
        self.mlp = MLP(config, device=device)

    def forward(self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None):
        # hidden_states: [s, b, h]

        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)
        # Self attention.
        attention_output, kv_cache = self.self_attention(
            layernorm_output, attention_mask, rotary_pos_emb, kv_cache=kv_cache
        )
        residual = hidden_states
        layernorm_input = torch.nn.functional.dropout(
            attention_output, p=self.hidden_dropout, training=self.training
        )
        layernorm_input = residual + layernorm_input

        # Layer norm post the self attention.
        layernorm_output = self.post_attention_layernorm(layernorm_input)
        # MLP.
        mlp_output = self.mlp(layernorm_output)
        residual = layernorm_input
        output = torch.nn.functional.dropout(
            mlp_output, p=self.hidden_dropout, training=self.training
        )
        output = residual + output
        return output, kv_cache
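
# --- Illustrative sketch (added; not part of the original module) ------------
# The hypothetical helper below, never called by the model code, shows why
# dense_h_to_4h produces twice the FFN width: swiglu consumes two halves as
# silu(gate) * value and returns a tensor of the plain FFN width again.
def _swiglu_demo():
    ffn_hidden = 8
    x = torch.randn(3, 2 * ffn_hidden)  # stands in for the dense_h_to_4h output
    gate, value = torch.chunk(x, 2, dim=-1)
    out = F.silu(gate) * value          # what MLP.activation_func computes
    assert out.shape == (3, ffn_hidden)
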

class GLMTransformer(torch.nn.Module):
    def __init__(self, config: ChatGLMConfig, device=None):
        super(GLMTransformer, self).__init__()
        self.fp32_residual_connection = config.fp32_residual_connection
        self.post_layer_norm = config.post_layer_norm
        self.num_layers = config.num_layers
        self.layers = []
        for i in range(self.num_layers):
            self.layers.append(GLMBlock(config, i + 1, device=device))
        self.layers = torch.nn.ModuleList(self.layers)
        self.final_layernorm = RMSNorm(
            config.hidden_size,
            eps=config.layernorm_epsilon,
            device=device,
            dtype=config.torch_dtype,
        )

    def forward(
        self,
        hidden_states,
        attention_mask,
        rotary_pos_emb,
        kv_caches=None,
        use_cache: Optional[bool] = True,
    ):
        if kv_caches is None:
            kv_caches = [None for _ in range(self.num_layers)]
        presents = () if use_cache else None
        for index in range(self.num_layers):
            layer = self.layers[index]
            hidden_states, kv_cache = layer(
                hidden_states,
                attention_mask,
                rotary_pos_emb,
                kv_cache=kv_caches[index],
            )
            if use_cache:
                presents = presents + (kv_cache,)
        # Note: presents is collected but not returned in this simplified version.
        hidden_states = self.final_layernorm(hidden_states)
        return hidden_states


class Embedding(torch.nn.Module):
    def __init__(self, config: ChatGLMConfig, device=None):
        super(Embedding, self).__init__()
        self.hidden_size = config.hidden_size
        self.word_embeddings = nn.Embedding(
            config.padded_vocab_size,
            self.hidden_size,
            dtype=config.torch_dtype,
            device=device,
        )

    def forward(self, input_ids):
        # Embeddings.
        words_embeddings = self.word_embeddings(input_ids)
        embeddings = words_embeddings
        # Data format change to avoid explicit transposes: [b s h] --> [s b h].
        embeddings = embeddings.transpose(0, 1).contiguous()
        return embeddings
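
# --- Illustrative sketch (added; not part of the original module) ------------
# The hypothetical helper below, never called by the model code, shows the
# layout convention used throughout: Embedding turns [batch, seq] token ids
# into sequence-first [seq, batch, hidden] states, which is what every GLMBlock
# expects and what ChatGLMModel transposes back only for the final logits.
def _layout_demo():
    vocab, hidden = 10, 4
    emb = nn.Embedding(vocab, hidden)            # stand-in for Embedding.word_embeddings
    input_ids = torch.randint(0, vocab, (2, 5))  # [b, s]
    states = emb(input_ids).transpose(0, 1).contiguous()
    assert states.shape == (5, 2, hidden)        # [s, b, h]
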

class ChatGLMModel(nn.Module):
    def __init__(self, config: ChatGLMConfig, device=None, empty_init=True):
        super().__init__()
        init_method = skip_init
        init_kwargs = {}
        if device is not None:
            init_kwargs["device"] = device
        self.embedding = init_method(Embedding, config, **init_kwargs)
        self.num_layers = config.num_layers
        self.multi_query_group_num = config.multi_query_group_num
        self.kv_channels = config.kv_channels
        self.config = config

        # Rotary positional embeddings
        self.seq_length = config.seq_length
        rotary_dim = (
            config.hidden_size // config.num_attention_heads
            if config.kv_channels is None
            else config.kv_channels
        )
        self.rotary_pos_emb = RotaryEmbedding(
            rotary_dim // 2,
            original_impl=config.original_rope,
            device=device,
            dtype=config.torch_dtype,
        )
        self.encoder = init_method(GLMTransformer, config, **init_kwargs)
        self.output_layer = init_method(
            nn.Linear,
            config.hidden_size,
            config.padded_vocab_size,
            bias=False,
            dtype=config.torch_dtype,
            **init_kwargs,
        )

    def forward(
        self,
        input_ids,
        position_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.BoolTensor] = None,
        full_attention_mask: Optional[torch.BoolTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        return_last_logit: Optional[bool] = False,
    ):
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        batch_size, seq_length = input_ids.shape
        if inputs_embeds is None:
            inputs_embeds = self.embedding(input_ids)

        # Rotary positional embeddings
        rotary_pos_emb = self.rotary_pos_emb(self.seq_length)

        from tools import show

        show.DumpTensorToImage(rotary_pos_emb[:, :, 0], "plot.png", scale=0.1)

        if position_ids is not None:
            rotary_pos_emb = rotary_pos_emb[position_ids]
        else:
            rotary_pos_emb = rotary_pos_emb[None, :seq_length]
        rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()

        hidden_states = self.encoder(
            inputs_embeds,
            full_attention_mask,
            rotary_pos_emb=rotary_pos_emb,
            kv_caches=past_key_values,
            use_cache=use_cache,
        )

        if return_last_logit:
            hidden_states = hidden_states[-1:]
        lm_logits = self.output_layer(hidden_states)
        lm_logits = lm_logits.transpose(0, 1).contiguous()
        return lm_logits
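
# --- Illustrative sketch (added; not part of the original module) ------------
# The hypothetical helper below, never called by the model code, checks the
# rotary-cache indexing done in ChatGLMModel.forward: the table is built once
# for the configured seq_length, then either gathered with explicit
# position_ids of shape [b, s] or sliced to the first seq_length rows, before
# being transposed to the sequence-first layout the encoder expects.
def _rope_indexing_demo():
    seq_length_cfg, b, s = 32, 2, 5
    cache = RotaryEmbedding(dim=8, dtype=torch.float32)(seq_length_cfg)  # [32, 4, 2]
    position_ids = torch.arange(s).unsqueeze(0).repeat(b, 1)             # [b, s]
    gathered = cache[position_ids]                                       # [b, s, 4, 2]
    sliced = cache[None, :s]                                             # [1, s, 4, 2]
    assert gathered.shape == (b, s, 4, 2) and sliced.shape == (1, s, 4, 2)
    # transpose(0, 1) then yields the [s, b, ...] layout used by the blocks.
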

class ChatGLMForConditionalGeneration(nn.Module):
    def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):
        super().__init__()
        self.max_sequence_length = config.max_length
        self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device)
        self.config = config
        self.main_input_name = "input_ids"
        self.name_or_path = config.name_or_path
        self.warnings_issued = {}
        self.generation_config = GenerationConfig.from_model_config(config)

    def from_pretrained(
        self,
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
        config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
        **kwargs,
    ):
        state_dict = kwargs.pop("state_dict", None)
        _ = kwargs.pop("mirror", None)
        load_in_8bit = kwargs.pop("load_in_8bit", False)
        load_in_4bit = kwargs.pop("load_in_4bit", False)
        subfolder = kwargs.pop("subfolder", "")

        # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
        # index of the files.
        sharded_metadata = None

        pretrained_model_name_or_path = str(pretrained_model_name_or_path)
        archive_file = os.path.join(
            pretrained_model_name_or_path,
            subfolder,
            WEIGHTS_INDEX_NAME,
        )
        print(f"loading weights file {archive_file}")
        resolved_archive_file = archive_file

        with open(resolved_archive_file, "r") as f:
            index = json.loads(f.read())

        shard_filenames = sorted(set(index["weight_map"].values()))
        sharded_metadata = index["metadata"]
        sharded_metadata["all_checkpoint_keys"] = list(index["weight_map"].keys())
        sharded_metadata["weight_map"] = index["weight_map"].copy()
        resolved_archive_file = [
            os.path.join(pretrained_model_name_or_path, subfolder, f)
            for f in shard_filenames
        ]
        loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]

        config.name_or_path = pretrained_model_name_or_path
        config = copy.deepcopy(config)
        # make sure we use the model's config since the __init__ call might have copied it
        config = self.config

        model = self._load_pretrained_model(
            state_dict,
            loaded_state_dict_keys,
            resolved_archive_file,
            pretrained_model_name_or_path,
        )
        model.is_loaded_in_4bit = load_in_4bit
        model.is_loaded_in_8bit = load_in_8bit

        # Set model in evaluation mode to deactivate DropOut modules by default
        model.eval()
        return model

    def _load_state_dict_into_model(self, model_to_load, state_dict, start_prefix):
        metadata = getattr(state_dict, "_metadata", None)
        state_dict = state_dict.copy()
        if metadata is not None:
            state_dict._metadata = metadata
        error_msgs = []

        def load(module: nn.Module, state_dict, prefix=""):
            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
            args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
            if len([key for key in state_dict if key.startswith(prefix)]) > 0:
                module._load_from_state_dict(*args)
            for name, child in module._modules.items():
                if child is not None:
                    load(child, state_dict, prefix + name + ".")

        load(model_to_load, state_dict, prefix=start_prefix)
        del state_dict
        return error_msgs
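
    # Note (added for clarity): the shard loading below relies on the standard
    # "pytorch_model.bin.index.json" layout, roughly:
    #   {
    #     "metadata": {"total_size": ...},
    #     "weight_map": {"transformer.embedding.word_embeddings.weight": "pytorch_model-00001-of-00007.bin", ...}
    #   }
    # `weight_map` maps every parameter name to the shard file that stores it;
    # the key and filenames shown here are only an example of the usual pattern.
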
    def _load_pretrained_model(
        self,
        state_dict,
        loaded_keys,
        resolved_archive_file,
        pretrained_model_name_or_path,
    ):
        # Retrieve missing & unexpected_keys
        model_state_dict = self.state_dict()
        expected_keys = list(model_state_dict.keys())
        loaded_keys = [key for key in loaded_keys]
        unexpected_keys = set(loaded_keys) - set(expected_keys)
        model_buffers = {n for n, _ in self.named_buffers()}
        unexpected_keys = list(unexpected_keys - model_buffers)

        ptrs = collections.defaultdict(list)
        for name, tensor in self.state_dict().items():
            ptrs[(tensor.device, storage_ptr(tensor), storage_size(tensor))].append(
                name
            )

        # retrieve uninitialized modules and initialize before maybe overriding that with the pretrained weights.
        _loaded_keys = loaded_keys
        for module_name, module in self.named_modules():
            loaded_keys = [
                k.replace(f"{module_name}.", "")
                for k in _loaded_keys
                if k.startswith(f"{module_name}.")
            ]
            if len(set(module.state_dict().keys()) - set(loaded_keys)) == 0:
                module._is_hf_initialized = True

        # Make sure we are able to load base models as well as derived models (with heads)
        start_prefix = ""
        model_to_load = self
        error_msgs = []
        if len(resolved_archive_file) > 1:
            resolved_archive_file = tqdm_lib.tqdm(
                resolved_archive_file, desc="Loading checkpoint shards"
            )
        for shard_file in resolved_archive_file:
            state_dict = torch.load(shard_file, map_location="cpu")

            # Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
            # matching the weights in the model.
            error_msgs += self._load_state_dict_into_model(
                model_to_load, state_dict, start_prefix
            )
            # force memory release
            del state_dict
            gc.collect()

        print(
            f"All model checkpoint weights were used when initializing {self.__class__.__name__}.\n"
        )
        print(
            f"All the weights of {self.__class__.__name__} were initialized from the model checkpoint at"
            f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint"
            f" was trained on, you can already use {self.__class__.__name__} for predictions without further"
            " training."
        )
        return self

    @torch.inference_mode()
    def chat(
        self,
        tokenizer,
        query: str,
        history: List[Tuple[str, str]] = None,
        role: str = "user",
    ):
        if history is None:
            history = []
        inputs = tokenizer.build_chat_input(query, history=history, role=role)
        inputs = inputs.to(next(self.parameters()).device)
        generation_config = copy.deepcopy(self.generation_config)
        inputs_tensor = inputs["input_ids"]
        input_ids = inputs_tensor.repeat_interleave(
            generation_config.num_return_sequences, dim=0
        )
        outputs = self.sample(
            input_ids,
            pad_token_id=generation_config.pad_token_id,
            eos_token_id=generation_config.eos_token_id,
            output_hidden_states=generation_config.output_hidden_states,
            use_cache=generation_config.use_cache,
        )
        outputs = outputs.tolist()[0][len(inputs["input_ids"][0]) : -1]
        response = tokenizer.decode(outputs)
        history.append({"role": role, "content": query})
        return response, history
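
    # Note (added for clarity): `sample` below re-runs the full prefix at every
    # step (past_key_values is always None), draws the next token by plain
    # multinomial sampling over the softmaxed logits, and tracks termination
    # with `unfinished_sequences`: for each sequence, the product over all EOS
    # ids of `next_token != eos_id` stays 1 until an EOS is produced, after
    # which the sequence keeps emitting `pad_token_id`.
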
    def sample(
        self,
        input_ids: torch.LongTensor,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[Union[int, List[int]]] = None,
        output_hidden_states: Optional[bool] = None,
        use_cache: Optional[bool] = None,
    ):
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device)
        unfinished_sequences = torch.ones(
            input_ids.shape[0], dtype=torch.long, device=input_ids.device
        )
        this_peer_finished = False  # used by synced_gpus only
        while True:
            input_ids_in = input_ids
            batch_size, seq_length = input_ids_in.shape
            position_ids_in = (
                torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
                .unsqueeze(0)
                .repeat(batch_size, 1)
            )
            use_cache = use_cache if use_cache is not None else self.config.use_cache
            model_inputs = {
                "input_ids": input_ids_in,
                "past_key_values": None,
                "position_ids": position_ids_in,
                "return_last_logit": True,
                "use_cache": use_cache,
            }
            logits = self.transformer(
                **model_inputs,
                return_dict=True,
                output_hidden_states=output_hidden_states,
            )
            next_token_logits = logits[:, -1, :]
            next_token_scores = next_token_logits
            probs = nn.functional.softmax(next_token_scores, dim=-1)
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

            # finished sentences should have their next token be a padding token
            if eos_token_id is not None:
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
                    1 - unfinished_sequences
                )

            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)

            # if eos_token was found in one sentence, set sentence to finished
            if eos_token_id_tensor is not None:
                unfinished_sequences = unfinished_sequences.mul(
                    next_tokens.tile(eos_token_id_tensor.shape[0], 1)
                    .ne(eos_token_id_tensor.unsqueeze(1))
                    .prod(dim=0)
                )
                if unfinished_sequences.max() == 0:
                    this_peer_finished = True
            if this_peer_finished:
                break
        return input_ids
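
# --- Usage sketch (added; not part of the original module) -------------------
# A minimal, hypothetical way to wire the pieces together, assuming a local
# checkpoint directory that contains pytorch_model.bin.index.json plus its
# shards, and a tokenizer exposing build_chat_input()/decode() as used by
# chat(). The path, the no-argument ChatGLMConfig() call and the AutoTokenizer
# import are placeholders/assumptions, not guarantees about this repository.
if __name__ == "__main__":
    from transformers import AutoTokenizer  # assumes a ChatGLM-style tokenizer

    checkpoint_dir = "./chatglm3-6b"  # hypothetical local path
    config = ChatGLMConfig()          # assumes defaults matching the checkpoint
    tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir, trust_remote_code=True)

    model = ChatGLMForConditionalGeneration(config, device="cuda")
    model = model.from_pretrained(checkpoint_dir, config).half().cuda()

    response, history = model.chat(tokenizer, "Hello", history=[])
    print(response)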