import collections
import math
import copy
import os
import gc
import json
import hashlib

import torch
import torch.utils.checkpoint
import torch.nn.functional as F
from torch import nn
from torch.nn.utils import skip_init
from typing import Optional, Tuple, Union, List, Dict, Any
from tqdm import auto as tqdm_lib
from safetensors.torch import storage_ptr, storage_size

from transformers.configuration_utils import PretrainedConfig
from transformers.generation import GenerationConfig

from configuration_chatglm import ChatGLMConfig


class RotaryEmbedding(nn.Module):
    """Builds the cos/sin lookup table used for rotary position embeddings."""

    def __init__(self, dim: int, original_impl=False, device=None, dtype=None):
        """Register the inverse-frequency buffer for a rotary dimension `dim`.

        Args:
            dim: Rotary dimension; frequencies are generated for every second
                index in `[0, dim)`.
            original_impl: Stored on the module; not consulted in this file.
            device: Device on which the buffer is created.
            dtype: Dtype of the `inv_freq` buffer. `forward` also uses it to
                decide the dtype of the returned cache.
        """
        super().__init__()
        inv_freq = 1.0 / (
            10000 ** (torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim)
        )
        self.register_buffer("inv_freq", inv_freq)
        self.dim = dim
        self.original_impl = original_impl

    def forward(self, max_seq_len: int, base: int = 10000):
        """Return the rotary cache for positions `0 .. max_seq_len - 1`.

        Args:
            max_seq_len: Number of positions to generate.
            base: Base of the geometric frequency progression.

        Returns:
            A tensor whose last axis holds `(cos, sin)` of
            `position * theta_i`, with the position axis first and the
            frequency axis second.
        """
        dtype = self.inv_freq.dtype
        device = self.inv_freq.device

        # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
        theta = 1.0 / (
            base
            ** (
                torch.arange(0, self.dim, 2, dtype=torch.float, device=device)
                / self.dim
            )
        )

        # Create position indexes `[0, 1, ..., max_seq_len - 1]`
        seq_idx = torch.arange(max_seq_len, dtype=torch.float, device=device)

        # Calculate the product of position index and $\theta_i$
        idx_theta = torch.outer(seq_idx, theta).float()

        cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)

        # this is to mimic the behaviour of complex32, else we will get different results
        if dtype in (torch.float16, torch.bfloat16, torch.int8):
            cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half()
        return cache


class RMSNorm(torch.nn.Module):
    """Root-mean-square normalisation over the last axis with a learned scale."""

    def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
        """Allocate the (uninitialised) scale parameter.

        Args:
            normalized_shape: Shape of the `weight` parameter.
            eps: Constant added to the mean square before the reciprocal
                square root.
            device: Device for the parameter.
            dtype: Dtype for the parameter.
            **kwargs: Accepted and ignored.
        """
        super().__init__()
        # torch.empty: values are expected to be filled from a checkpoint.
        self.weight = torch.nn.Parameter(
            torch.empty(normalized_shape, device=device, dtype=dtype)
        )
        self.eps = eps

    def forward(self, hidden_states: torch.Tensor):
        """Normalise `hidden_states` and cast the result back to its input dtype."""
        input_dtype = hidden_states.dtype
        # The mean square is computed in float32 regardless of the input dtype.
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)

        return (self.weight * hidden_states).to(input_dtype)


class CoreAttention(torch.nn.Module):
    """Causal scaled-dot-product attention over already-projected q/k/v."""

    def __init__(self, config: ChatGLMConfig, layer_number):
        """Derive per-head sizes from `config`.

        Args:
            config: Model configuration; reads `attention_softmax_in_fp32`,
                `kv_channels`, `num_attention_heads` and `attention_dropout`.
            layer_number: Layer index; clamped to at least 1.
        """
        super(CoreAttention, self).__init__()

        self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
        self.layer_number = max(1, layer_number)

        projection_size = config.kv_channels * config.num_attention_heads

        # Per attention head and per partition values.
        self.hidden_size_per_partition = projection_size
        self.hidden_size_per_attention_head = (
            projection_size // config.num_attention_heads
        )
        self.num_attention_heads_per_partition = config.num_attention_heads

        # NOTE: norm_factor / coeff / attention_dropout are stored but not used
        # by forward() below, which delegates to scaled_dot_product_attention.
        coeff = self.layer_number
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        self.norm_factor *= coeff
        self.coeff = coeff

        self.attention_dropout = torch.nn.Dropout(config.attention_dropout)

    def forward(self, query_layer, key_layer, value_layer):
        """Run causal attention.

        Args:
            query_layer, key_layer, value_layer: 4-D tensors; axes are permuted
                `(1, 2, 0, 3)` before attention (callers in this file pass
                `[sq, b, np, hn]`).

        Returns:
            The context tensor with the last two axes merged into
            `hidden_size_per_partition`.

        Raises:
            ValueError: If query and key sequence lengths differ. Only the
                equal-length causal case is implemented here.
        """
        query_layer, key_layer, value_layer = [
            k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]
        ]
        if query_layer.shape[2] != key_layer.shape[2]:
            # Previously this path fell through and failed with an opaque
            # UnboundLocalError on `context_layer`.
            raise ValueError(
                "CoreAttention only supports query and key of equal sequence "
                f"length, got {query_layer.shape[2]} and {key_layer.shape[2]}"
            )

        context_layer = torch.nn.functional.scaled_dot_product_attention(
            query_layer, key_layer, value_layer, is_causal=True
        )
        context_layer = context_layer.permute(2, 0, 1, 3)
        new_context_layer_shape = context_layer.size()[:-2] + (
            self.hidden_size_per_partition,
        )
        context_layer = context_layer.reshape(*new_context_layer_shape)

        return context_layer


class SelfAttention(torch.nn.Module):
    """Self-attention layer: QKV projection, rotary embedding, attention, output projection."""

    def __init__(self, config: ChatGLMConfig, layer_number, device=None):
        """Create the fused QKV projection, the attention core and the output projection.

        Args:
            config: Model configuration.
            layer_number: Layer index; clamped to at least 1.
            device: Device for the linear layers.
        """
        super(SelfAttention, self).__init__()
        self.layer_number = max(1, layer_number)

        self.projection_size = config.kv_channels * config.num_attention_heads

        # Per attention head and per partition values.
        self.hidden_size_per_attention_head = (
            self.projection_size // config.num_attention_heads
        )
        self.num_attention_heads_per_partition = config.num_attention_heads

        # Stored for reference; the multi-query layout below is used unconditionally.
        self.multi_query_attention = config.multi_query_attention
        self.num_multi_query_groups_per_partition = config.multi_query_group_num
        # Full-width queries plus one key and one value slice per query group.
        self.qkv_hidden_size = (
            self.projection_size
            + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num
        )
        self.query_key_value = nn.Linear(
            config.hidden_size,
            self.qkv_hidden_size,
            bias=config.add_bias_linear or config.add_qkv_bias,
            device=device,
            dtype=config.torch_dtype,
        )

        self.core_attention = CoreAttention(config, self.layer_number)

        # Output.
        self.dense = nn.Linear(
            self.projection_size,
            config.hidden_size,
            bias=config.add_bias_linear,
            device=device,
            dtype=config.torch_dtype,
        )

    def apply_rotary_pos_emb(self, x: torch.Tensor, rope: torch.Tensor) -> torch.Tensor:
        """Rotate the first half of the last axis of `x` using the `rope` cache.

        Args:
            x: Tensor laid out as `[sq, b, np, hn]`.
            rope: Rotary cache whose first axis must have length `sq` and
                whose last axis holds `(cos, sin)`.

        Returns:
            A tensor of the same layout as `x`; the second half of the last
            axis is passed through unchanged.

        Raises:
            ValueError: If `rope.size(0)` does not equal `sq`.
        """
        # x: [sq, b, np, hn]
        sq, b, n_heads, hn = x.size(0), x.size(1), x.size(2), x.size(3)
        if rope.size(0) != sq:
            # The original raised a bare string, which is itself a TypeError.
            raise ValueError("Error rotary_pos_emb size")

        x_rope = x[..., : hn // 2]
        x_pass = x[..., hn // 2 :]

        # Pair up adjacent channels: last axis is (real, imag) of each pair.
        x_rope = x_rope.reshape(sq, -1, n_heads, hn // 4, 1, 2)
        rope = rope.view(sq, -1, 1, hn // 4, 1, 2)

        # Complex multiplication (a + ib)(cos + i sin), component-wise.
        roped1 = x_rope[..., 0] * rope[..., 0] - x_rope[..., 1] * rope[..., 1]
        roped2 = x_rope[..., 1] * rope[..., 0] + x_rope[..., 0] * rope[..., 1]

        # Concatenating along the trailing size-1 axis re-interleaves the pairs.
        x_out = torch.cat((roped1, roped2), -1)
        x_out = x_out.flatten(3)
        return torch.cat((x_out, x_pass), dim=-1)

    def forward(self, hidden_states, rotary_pos_emb):
        """Apply self-attention.

        Args:
            hidden_states: Input laid out as `[sq, b, h]`.
            rotary_pos_emb: Rotary cache for the current positions, or `None`
                to skip the rotary embedding.

        Returns:
            The attention output, `[sq, b, h]`.
        """
        # hidden_states: [sq, b, h]

        head_dim = self.hidden_size_per_attention_head
        n_heads = self.num_attention_heads_per_partition
        n_groups = self.num_multi_query_groups_per_partition

        # Attention heads [sq, b, h] --> [sq, b, (np * hn + 2 * groups * hn)]
        mixed_x_layer = self.query_key_value(hidden_states)

        (query_layer, key_layer, value_layer) = mixed_x_layer.split(
            [n_heads * head_dim, n_groups * head_dim, n_groups * head_dim],
            dim=-1,
        )
        query_layer = query_layer.view(
            query_layer.size()[:-1] + (n_heads, head_dim)
        )
        key_layer = key_layer.view(key_layer.size()[:-1] + (n_groups, head_dim))
        value_layer = value_layer.view(
            value_layer.size()[:-1] + (n_groups, head_dim)
        )

        # apply relative positional encoding (rotary embedding)
        if rotary_pos_emb is not None:
            query_layer = self.apply_rotary_pos_emb(query_layer, rotary_pos_emb)
            key_layer = self.apply_rotary_pos_emb(key_layer, rotary_pos_emb)

        # Repeat each key/value group so that every query head has a partner.
        key_layer = key_layer.unsqueeze(-2)
        key_layer = key_layer.expand(-1, -1, -1, n_heads // n_groups, -1)
        key_layer = key_layer.contiguous().view(
            key_layer.size()[:2] + (n_heads, head_dim)
        )
        value_layer = value_layer.unsqueeze(-2)
        value_layer = value_layer.expand(-1, -1, -1, n_heads // n_groups, -1)
        value_layer = value_layer.contiguous().view(
            value_layer.size()[:2] + (n_heads, head_dim)
        )

        context_layer = self.core_attention(query_layer, key_layer, value_layer)

        # Output. [sq, b, h]
        output = self.dense(context_layer)

        return output


class MLP(torch.nn.Module):
    """Feed-forward block: up-projection, SwiGLU activation, down-projection."""

    def __init__(self, config: ChatGLMConfig, device=None):
        """Create the two linear layers and the SwiGLU activation.

        Args:
            config: Model configuration; reads `add_bias_linear`,
                `hidden_size`, `ffn_hidden_size` and `torch_dtype`.
            device: Device for the linear layers.
        """
        super(MLP, self).__init__()

        self.add_bias = config.add_bias_linear

        # Project to 4h. If using swiglu double the output width,
        # see https://arxiv.org/pdf/2002.05202.pdf
        self.dense_h_to_4h = nn.Linear(
            config.hidden_size,
            config.ffn_hidden_size * 2,
            bias=self.add_bias,
            device=device,
            dtype=config.torch_dtype,
        )

        def swiglu(x):
            # Split the doubled width into a gate half and a value half.
            x = torch.chunk(x, 2, dim=-1)
            return F.silu(x[0]) * x[1]

        self.activation_func = swiglu

        # Project back to h.
        self.dense_4h_to_h = nn.Linear(
            config.ffn_hidden_size,
            config.hidden_size,
            bias=self.add_bias,
            device=device,
            dtype=config.torch_dtype,
        )

    def forward(self, hidden_states):
        """Apply up-projection, SwiGLU and down-projection to `hidden_states`."""
        intermediate_parallel = self.dense_h_to_4h(hidden_states)
        intermediate_parallel = self.activation_func(intermediate_parallel)
        output = self.dense_4h_to_h(intermediate_parallel)
        return output


class GLMBlock(torch.nn.Module):
    """A single transformer layer.

    Transformer layer takes input with size [s, b, h] and returns an
    output of the same size.
    """

    def __init__(self, config: ChatGLMConfig, layer_number, device=None):
        """Create the two norms, the self-attention module and the MLP.

        Args:
            config: Model configuration.
            layer_number: Index of this layer, forwarded to `SelfAttention`.
            device: Device for all sub-modules.
        """
        super(GLMBlock, self).__init__()
        self.layer_number = layer_number

        # Both flags are stored but not consulted by forward() in this file.
        self.apply_residual_connection_post_layernorm = (
            config.apply_residual_connection_post_layernorm
        )
        self.fp32_residual_connection = config.fp32_residual_connection

        LayerNormFunc = RMSNorm
        # Layernorm on the input data.
        self.input_layernorm = LayerNormFunc(
            config.hidden_size,
            eps=config.layernorm_epsilon,
            device=device,
            dtype=config.torch_dtype,
        )

        # Self attention.
        self.self_attention = SelfAttention(config, layer_number, device=device)
        self.hidden_dropout = config.hidden_dropout

        # Layernorm on the attention output
        self.post_attention_layernorm = LayerNormFunc(
            config.hidden_size,
            eps=config.layernorm_epsilon,
            device=device,
            dtype=config.torch_dtype,
        )

        # MLP
        self.mlp = MLP(config, device=device)

    def forward(self, hidden_states, rotary_pos_emb):
        """Run pre-norm attention and MLP sub-layers with residual connections.

        Args:
            hidden_states: Input, `[s, b, h]`.
            rotary_pos_emb: Rotary cache forwarded to the attention module.

        Returns:
            Output of the same size as `hidden_states`.
        """
        # hidden_states: [s, b, h]

        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)
        # Self attention.
        attention_output = self.self_attention(layernorm_output, rotary_pos_emb)

        # The residual is taken from the un-normalised input.
        residual = hidden_states
        layernorm_input = torch.nn.functional.dropout(
            attention_output, p=self.hidden_dropout, training=self.training
        )
        layernorm_input = residual + layernorm_input

        # Layer norm post the self attention.
        layernorm_output = self.post_attention_layernorm(layernorm_input)

        # MLP.
        mlp_output = self.mlp(layernorm_output)

        # Second residual connection.
        residual = layernorm_input
        output = torch.nn.functional.dropout(
            mlp_output, p=self.hidden_dropout, training=self.training
        )
        output = residual + output

        return output


class GLMTransformer(torch.nn.Module):
    """Stack of `GLMBlock` layers followed by a final RMSNorm."""

    def __init__(self, config: ChatGLMConfig, device=None):
        """Build `config.num_layers` blocks (numbered from 1) and the final norm."""
        super(GLMTransformer, self).__init__()

        self.fp32_residual_connection = config.fp32_residual_connection
        self.post_layer_norm = config.post_layer_norm

        # Number of layers.
        self.num_layers = config.num_layers

        # Transformer layers.
        self.layers = torch.nn.ModuleList(
            [GLMBlock(config, i + 1, device=device) for i in range(self.num_layers)]
        )

        # Applied unconditionally in forward(); `post_layer_norm` is not consulted.
        self.final_layernorm = RMSNorm(
            config.hidden_size,
            eps=config.layernorm_epsilon,
            device=device,
            dtype=config.torch_dtype,
        )

    def forward(self, hidden_states, rotary_pos_emb):
        """Pass `hidden_states` through every layer, then the final norm."""
        for layer in self.layers:
            hidden_states = layer(hidden_states, rotary_pos_emb)
        hidden_states = self.final_layernorm(hidden_states)

        return hidden_states


class Embedding(torch.nn.Module):
    """Token embedding that returns sequence-first activations."""

    def __init__(self, config: ChatGLMConfig, device=None):
        """Create the word-embedding table of size `padded_vocab_size x hidden_size`."""
        super(Embedding, self).__init__()

        self.hidden_size = config.hidden_size
        # Word embeddings (parallel).
        self.word_embeddings = nn.Embedding(
            config.padded_vocab_size,
            self.hidden_size,
            dtype=config.torch_dtype,
            device=device,
        )

    def forward(self, input_ids):
        """Look up `input_ids` and swap the first two axes ([b s h] --> [s b h])."""
        # Embeddings.
        words_embeddings = self.word_embeddings(input_ids)
        embeddings = words_embeddings
        # Data format change to avoid explicit transposes : [b s h] --> [s b h].
        embeddings = embeddings.transpose(0, 1).contiguous()
        # NOTE: no fp32 residual conversion is performed here.
        return embeddings


class ChatGLMModel(nn.Module):
    """Embedding, transformer stack and output head; samples one next token per call."""

    def __init__(self, config: ChatGLMConfig, device=None, empty_init=True):
        """Build all sub-modules without initialising their parameters.

        Args:
            config: Model configuration.
            device: Optional device forwarded to every sub-module.
            empty_init: Accepted for interface compatibility; `skip_init` is
                always used.
        """
        super().__init__()
        # skip_init avoids running parameter initialisers; weights are
        # expected to come from a checkpoint.
        init_method = skip_init
        init_kwargs = {}
        if device is not None:
            init_kwargs["device"] = device
        self.embedding = init_method(Embedding, config, **init_kwargs)
        self.num_layers = config.num_layers
        self.multi_query_group_num = config.multi_query_group_num
        self.kv_channels = config.kv_channels
        self.config = config

        # Rotary positional embeddings
        self.seq_length = config.seq_length
        rotary_dim = (
            config.hidden_size // config.num_attention_heads
            if config.kv_channels is None
            else config.kv_channels
        )

        # Only half of each head's channels are rotated.
        self.rotary_pos_emb = RotaryEmbedding(
            rotary_dim // 2,
            original_impl=config.original_rope,
            device=device,
            dtype=config.torch_dtype,
        )
        self.encoder = init_method(GLMTransformer, config, **init_kwargs)
        self.output_layer = init_method(
            nn.Linear,
            config.hidden_size,
            config.padded_vocab_size,
            bias=False,
            dtype=config.torch_dtype,
            **init_kwargs,
        )

    def forward(
        self,
        input_ids,
        position_ids: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        tokenizer=None,
    ):
        """Run the model and sample the next token for every batch row.

        Args:
            input_ids: Token ids, batch-first.
            position_ids: Indices into the rotary cache, one per token.
                NOTE(review): callers in this file always pass a tensor;
                `None` is not handled specially.
            output_hidden_states: Resolved against the config but otherwise
                unused here.
            tokenizer: Unused.

        Returns:
            `(probs, next_tokens)`: the softmax distribution over the
            vocabulary at the last position, and one token sampled from it
            per batch row.
        """
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )

        inputs_embeds = self.embedding(input_ids)

        # Rotary positional embeddings: build the full table, then gather
        # the rows for the requested positions and make it sequence-first.
        rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
        rotary_pos_emb = rotary_pos_emb[position_ids]
        rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()

        # Run encoder.
        hidden_states_en = self.encoder(inputs_embeds, rotary_pos_emb)

        # Only the final sequence position is projected to logits.
        hidden_states = hidden_states_en[-1:]
        lm_logits = self.output_layer(hidden_states)
        lm_logits = lm_logits.transpose(0, 1).contiguous()
        next_token_logits = lm_logits[:, -1, :]

        probs = nn.functional.softmax(next_token_logits, dim=-1)
        next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

        return probs, next_tokens


class ChatGLMForConditionalGeneration(nn.Module):
    """Generation wrapper around `ChatGLMModel` with checkpoint loading and sampling."""

    def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):
        """Create the underlying model and derive the generation config."""
        super().__init__()

        self.max_sequence_length = config.max_length
        self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device)
        self.config = config
        self.main_input_name = "input_ids"
        self.name_or_path = config.name_or_path
        self.warnings_issued = {}
        self.generation_config = GenerationConfig.from_model_config(config)

    def from_pretrained(
        cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]]
    ):
        """Load sharded weights from a checkpoint directory into this model.

        Despite the parameter name `cls`, this is a plain method: it must be
        called on an already-constructed instance, which is loaded in place
        and returned.

        Args:
            pretrained_model_name_or_path: Directory containing
                `pytorch_model.bin.index.json` and the shard files it lists.

        Returns:
            The instance, with `is_loaded_in_4bit` / `is_loaded_in_8bit` set
            to `False`.
        """
        load_in_8bit = False
        load_in_4bit = False

        pretrained_model_name_or_path = str(pretrained_model_name_or_path)

        resolved_archive_file = os.path.join(
            pretrained_model_name_or_path, "pytorch_model.bin.index.json"
        )
        print(f"loading weights file {resolved_archive_file}")
        with open(resolved_archive_file, "r", encoding="utf-8") as f:
            index = json.loads(f.read())

        # Several weights map to the same shard; load each shard once.
        shard_filenames = sorted(set(index["weight_map"].values()))
        resolved_archive_file = [
            os.path.join(pretrained_model_name_or_path, f) for f in shard_filenames
        ]
        model = cls._load_pretrained_model(resolved_archive_file)
        model.is_loaded_in_4bit = load_in_4bit
        model.is_loaded_in_8bit = load_in_8bit
        return model

    def _load_state_dict_into_model(self, model_to_load, state_dict, start_prefix):
        """Recursively copy `state_dict` entries into `model_to_load`.

        Args:
            model_to_load: Root module to populate.
            state_dict: Mapping of parameter names to tensors.
            start_prefix: Key prefix corresponding to the root module.

        Returns:
            The list of error messages collected by
            `Module._load_from_state_dict`.
        """
        # Work on a copy so that the caller's mapping is not mutated, while
        # keeping the optional `_metadata` attribute attached.
        metadata = getattr(state_dict, "_metadata", None)
        state_dict = state_dict.copy()
        if metadata is not None:
            state_dict._metadata = metadata

        error_msgs = []

        def load(module: nn.Module, state_dict, prefix=""):
            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
            args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
            # Skip modules that have nothing in this (possibly partial) shard.
            if any(key.startswith(prefix) for key in state_dict):
                module._load_from_state_dict(*args)

            for name, child in module._modules.items():
                if child is not None:
                    load(child, state_dict, prefix + name + ".")

        load(model_to_load, state_dict, prefix=start_prefix)
        del state_dict

        return error_msgs

    def _load_pretrained_model(cls, resolved_archive_file):
        """Load every shard in `resolved_archive_file` into this instance.

        As with `from_pretrained`, `cls` is the model instance.

        Args:
            resolved_archive_file: List of shard file paths.

        Returns:
            The instance, after all shards have been loaded.

        Raises:
            RuntimeError: If any shard reported an error while its tensors
                were being copied into the model.
        """
        start_prefix = ""
        model_to_load = cls
        error_msgs = []
        if len(resolved_archive_file) > 1:
            resolved_archive_file = tqdm_lib.tqdm(
                resolved_archive_file, desc="Loading checkpoint shards"
            )
        for shard_file in resolved_archive_file:
            # NOTE(security): torch.load unpickles the file; only load
            # checkpoints that come from a trusted source.
            state_dict = torch.load(shard_file, map_location="cpu")
            error_msgs += cls._load_state_dict_into_model(
                model_to_load, state_dict, start_prefix
            )
            # force memory release
            del state_dict
            gc.collect()

        if error_msgs:
            # Previously these were collected and silently dropped, leaving
            # the affected parameters uninitialised.
            error_msg = "\n\t".join(error_msgs)
            raise RuntimeError(
                f"Error(s) in loading state_dict for {cls.__class__.__name__}:"
                f"\n\t{error_msg}"
            )

        print(
            f"All model checkpoint weights were used when initializing {cls.__class__.__name__}.\n"
        )
        return cls

    @torch.inference_mode()
    def chat(
        self,
        tokenizer,
        query: str,
        history: Optional[List[Dict[str, str]]] = None,
        role: str = "user",
    ):
        """Generate a reply to `query`.

        Args:
            tokenizer: Must provide `build_chat_input` and `decode`.
            query: The new message.
            history: Earlier turns; a fresh list is used when `None`. The
                list passed in is extended in place.
            role: Role recorded for `query`.

        Returns:
            `(response, history)`. `response` is the decoding of the whole
            first output sequence (prompt tokens included); `history` has the
            query appended as `{"role": role, "content": query}`.
        """
        if history is None:
            history = []
        inputs = tokenizer.build_chat_input(query, history=history, role=role)
        inputs = inputs.to(next(self.parameters()).device)

        # Copy so that per-call use never touches the shared config.
        generation_config = copy.deepcopy(self.generation_config)

        inputs_tensor = inputs["input_ids"]
        input_ids = inputs_tensor.repeat_interleave(
            generation_config.num_return_sequences, dim=0
        )

        outputs = self.sample(
            input_ids,
            generation_config.pad_token_id,
            generation_config.eos_token_id,
            generation_config.output_hidden_states,
            tokenizer,
        )

        outputs = outputs.tolist()[0][:]
        response = tokenizer.decode(outputs)
        history.append({"role": role, "content": query})
        return response, history

    def sample(
        self,
        input_ids: torch.LongTensor,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[Union[int, List[int]]] = None,
        output_hidden_states: Optional[bool] = None,
        tokenizer=None,
    ):
        """Autoregressively sample tokens until every row has produced an EOS id.

        Args:
            input_ids: Prompt token ids, `(batch, seq)`.
            pad_token_id: Token written for rows that have already finished.
            eos_token_id: One id or a list of ids; sampling any of them marks
                a row as finished.
            output_hidden_states: Forwarded to the model.
            tokenizer: Forwarded to the model.

        Returns:
            `input_ids` extended with the sampled tokens. The token sampled in
            the step on which the last row finishes is not appended.

        Note:
            There is no length cap; the loop only ends once every row has
            emitted an EOS id.
        """
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device)

        # 1 for rows that have already emitted an EOS id, 0 otherwise.
        is_finished = torch.zeros(
            input_ids.shape[0], dtype=torch.long, device=input_ids.device
        )

        while True:
            batch_size, seq_length = input_ids.shape
            # The full sequence is re-encoded on every step (no KV cache).
            position_ids = (
                torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
                .unsqueeze(0)
                .repeat(batch_size, 1)
            )

            probs, next_tokens = self.transformer(
                input_ids=input_ids,
                position_ids=position_ids,
                output_hidden_states=output_hidden_states,
                tokenizer=tokenizer,
            )

            # finished sentences should add a padding token to next
            pad_token = pad_token_id * is_finished
            next_tokens = next_tokens * (1 - is_finished) + pad_token

            # Compare every row against every EOS id, then reduce per row.
            # A plain `next_tokens.eq(eos_token_id_tensor)` would broadcast
            # (batch,) against (n_eos,) and break for more than one EOS id.
            hit_eos = next_tokens.unsqueeze(-1).eq(eos_token_id_tensor).any(dim=-1)
            is_finished = is_finished | hit_eos

            if is_finished.min() == 1:  # all batch is finish
                break

            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)

        return input_ids

    def backward(
        self,
        tokenizer,
        query: str,
    ):
        """Run one forward pass on `query` and back-propagate a sampled-token probability.

        Args:
            tokenizer: Must provide `build_chat_input`.
            query: Prompt text, encoded with empty history and role `"user"`.

        Returns:
            The loss tensor: the first row's probability of the token(s)
            freshly sampled from `probs`. `loss.backward()` has already been
            called on it.
        """
        inputs = tokenizer.build_chat_input(query, history=[], role="user")
        inputs = inputs.to(next(self.parameters()).device)

        generation_config = copy.deepcopy(self.generation_config)

        inputs_tensor = inputs["input_ids"]
        input_ids = inputs_tensor.repeat_interleave(
            generation_config.num_return_sequences, dim=0
        )

        batch_size, seq_length = input_ids.shape
        position_ids = (
            torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
            .unsqueeze(0)
            .repeat(batch_size, 1)
        )

        probs, next_tokens = self.transformer(
            input_ids=input_ids,
            position_ids=position_ids,
            output_hidden_states=None,
            tokenizer=tokenizer,
        )

        # The token returned by the model is discarded and a new one drawn.
        next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

        loss = probs[0, next_tokens]
        loss.backward()

        return loss