commit a451def2997b3eca199b1a1f4e7960e1f9622092 Author: Colin Date: Thu Dec 21 16:53:47 2023 +0800 Init code. diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..7b62f12 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +__pycache__ +.vscode \ No newline at end of file diff --git a/chatglm/__init__.py b/chatglm/__init__.py new file mode 100644 index 0000000..75e2fe3 --- /dev/null +++ b/chatglm/__init__.py @@ -0,0 +1,7 @@ + +from chatglm.configuration_chatglm import ChatGLMConfig +from chatglm.tokenization_chatglm import ChatGLMTokenizer +from chatglm.modeling_chatglm import ChatGLMForConditionalGeneration + + + diff --git a/chatglm/configuration_chatglm.py b/chatglm/configuration_chatglm.py new file mode 100644 index 0000000..3560018 --- /dev/null +++ b/chatglm/configuration_chatglm.py @@ -0,0 +1,61 @@ +from transformers import PretrainedConfig + + +class ChatGLMConfig(PretrainedConfig): + model_type = "chatglm" + def __init__( + self, + num_layers=28, + padded_vocab_size=65024, + hidden_size=4096, + ffn_hidden_size=13696, + kv_channels=128, + num_attention_heads=32, + seq_length=2048, + hidden_dropout=0.0, + classifier_dropout=None, + attention_dropout=0.0, + layernorm_epsilon=1e-5, + rmsnorm=True, + apply_residual_connection_post_layernorm=False, + post_layer_norm=True, + add_bias_linear=False, + add_qkv_bias=False, + bias_dropout_fusion=True, + multi_query_attention=False, + multi_query_group_num=1, + apply_query_key_layer_scaling=True, + attention_softmax_in_fp32=True, + fp32_residual_connection=False, + quantization_bit=0, + pre_seq_len=None, + prefix_projection=False, + **kwargs + ): + self.num_layers = num_layers + self.vocab_size = padded_vocab_size + self.padded_vocab_size = padded_vocab_size + self.hidden_size = hidden_size + self.ffn_hidden_size = ffn_hidden_size + self.kv_channels = kv_channels + self.num_attention_heads = num_attention_heads + self.seq_length = seq_length + self.hidden_dropout = hidden_dropout + self.classifier_dropout = classifier_dropout + self.attention_dropout = attention_dropout + self.layernorm_epsilon = layernorm_epsilon + self.rmsnorm = rmsnorm + self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm + self.post_layer_norm = post_layer_norm + self.add_bias_linear = add_bias_linear + self.add_qkv_bias = add_qkv_bias + self.bias_dropout_fusion = bias_dropout_fusion + self.multi_query_attention = multi_query_attention + self.multi_query_group_num = multi_query_group_num + self.apply_query_key_layer_scaling = apply_query_key_layer_scaling + self.attention_softmax_in_fp32 = attention_softmax_in_fp32 + self.fp32_residual_connection = fp32_residual_connection + self.quantization_bit = quantization_bit + self.pre_seq_len = pre_seq_len + self.prefix_projection = prefix_projection + super().__init__(**kwargs) \ No newline at end of file diff --git a/chatglm/modeling_chatglm.py b/chatglm/modeling_chatglm.py new file mode 100644 index 0000000..98bced2 --- /dev/null +++ b/chatglm/modeling_chatglm.py @@ -0,0 +1,786 @@ +import collections +import math +import copy +import os +import gc +import json + +import torch +import torch.utils.checkpoint +import torch.nn.functional as F +from torch import nn +from torch.nn.utils import skip_init +from typing import Optional, Tuple, Union, List, Dict, Any + +from tqdm import auto as tqdm_lib +from safetensors.torch import storage_ptr, storage_size + +from transformers.configuration_utils import PretrainedConfig +from transformers.generation import GenerationConfig + +from 
chatglm import ChatGLMConfig + +WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json" + + +class RotaryEmbedding(nn.Module): + def __init__(self, dim: int, original_impl=False, device=None, dtype=None): + super().__init__() + inv_freq = 1.0 / ( + 10000 ** (torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim) + ) + self.register_buffer("inv_freq", inv_freq) + self.dim = dim + self.original_impl = original_impl + + def forward(self, max_seq_len: int, base: int = 10000): + dtype = self.inv_freq.dtype + device = self.inv_freq.device + # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$ + theta = 1.0 / ( + base + ** ( + torch.arange(0, self.dim, 2, dtype=torch.float, device=device) + / self.dim + ) + ) + # Create position indexes `[0, 1, ..., max_seq_len - 1]` + seq_idx = torch.arange(max_seq_len, dtype=torch.float, device=device) + # Calculate the product of position index and $\theta_i$ + idx_theta = torch.outer(seq_idx, theta).float() + cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1) + # this is to mimic the behaviour of complex32, else we will get different results + if dtype in (torch.float16, torch.bfloat16, torch.int8): + cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half() + return cache + + +class RMSNorm(torch.nn.Module): + def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs): + super().__init__() + self.weight = torch.nn.Parameter( + torch.empty(normalized_shape, device=device, dtype=dtype) + ) + self.eps = eps + + def forward(self, hidden_states: torch.Tensor): + input_dtype = hidden_states.dtype + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.eps) + return (self.weight * hidden_states).to(input_dtype) + + +class CoreAttention(torch.nn.Module): + def __init__(self, config: ChatGLMConfig, layer_number): + super(CoreAttention, self).__init__() + + self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32 + self.layer_number = max(1, layer_number) + projection_size = config.kv_channels * config.num_attention_heads + # Per attention head and per partition values. + self.hidden_size_per_partition = projection_size + self.hidden_size_per_attention_head = ( + projection_size // config.num_attention_heads + ) + self.num_attention_heads_per_partition = config.num_attention_heads + + coeff = None + self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) + coeff = self.layer_number + self.norm_factor *= coeff + self.coeff = coeff + + self.attention_dropout = torch.nn.Dropout(config.attention_dropout) + + def forward(self, query_layer, key_layer, value_layer, attention_mask): + query_layer, key_layer, value_layer = [ + k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer] + ] + if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]: + context_layer = torch.nn.functional.scaled_dot_product_attention( + query_layer, key_layer, value_layer, is_causal=True + ) + context_layer = context_layer.permute(2, 0, 1, 3) + new_context_layer_shape = context_layer.size()[:-2] + ( + self.hidden_size_per_partition, + ) + context_layer = context_layer.reshape(*new_context_layer_shape) + return context_layer + + +class SelfAttention(torch.nn.Module): + """Parallel self-attention layer abstract class. + + Self-attention layer takes input with size [s, b, h] + and returns output of the same size. 
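+    Key/value projections use config.multi_query_group_num grouped heads that are
+    broadcast across the query heads before the attention product
+    (multi-/grouped-query attention).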
+ """ + + def __init__(self, config: ChatGLMConfig, layer_number, device=None): + super(SelfAttention, self).__init__() + self.layer_number = max(1, layer_number) + self.projection_size = config.kv_channels * config.num_attention_heads + self.hidden_size_per_attention_head = ( + self.projection_size // config.num_attention_heads + ) + self.num_attention_heads_per_partition = config.num_attention_heads + self.multi_query_attention = config.multi_query_attention + self.qkv_hidden_size = 3 * self.projection_size + self.num_multi_query_groups_per_partition = config.multi_query_group_num + self.qkv_hidden_size = ( + self.projection_size + + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num + ) + self.query_key_value = nn.Linear( + config.hidden_size, + self.qkv_hidden_size, + bias=config.add_bias_linear or config.add_qkv_bias, + device=device, + dtype=config.torch_dtype, + ) + self.core_attention = CoreAttention(config, self.layer_number) + self.dense = nn.Linear( + self.projection_size, + config.hidden_size, + bias=config.add_bias_linear, + device=device, + dtype=config.torch_dtype, + ) + + def apply_rotary_pos_emb( + self, x: torch.Tensor, rope_cache: torch.Tensor + ) -> torch.Tensor: + # x: [sq, b, np, hn] + sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3) + rot_dim = rope_cache.shape[-2] * 2 + x, x_pass = x[..., :rot_dim], x[..., rot_dim:] + # truncate to support variable sizes + rope_cache = rope_cache[:sq] + xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2) + rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2) + x_out2 = torch.stack( + [ + xshaped[..., 0] * rope_cache[..., 0] + - xshaped[..., 1] * rope_cache[..., 1], + xshaped[..., 1] * rope_cache[..., 0] + + xshaped[..., 0] * rope_cache[..., 1], + ], + -1, + ) + x_out2 = x_out2.flatten(3) + return torch.cat((x_out2, x_pass), dim=-1) + + def forward(self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None): + # hidden_states: [sq, b, h] + # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] + mixed_x_layer = self.query_key_value(hidden_states) + + (query_layer, key_layer, value_layer) = mixed_x_layer.split( + [ + self.num_attention_heads_per_partition + * self.hidden_size_per_attention_head, + self.num_multi_query_groups_per_partition + * self.hidden_size_per_attention_head, + self.num_multi_query_groups_per_partition + * self.hidden_size_per_attention_head, + ], + dim=-1, + ) + query_layer = query_layer.view( + query_layer.size()[:-1] + + ( + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + ) + ) + key_layer = key_layer.view( + key_layer.size()[:-1] + + ( + self.num_multi_query_groups_per_partition, + self.hidden_size_per_attention_head, + ) + ) + value_layer = value_layer.view( + value_layer.size()[:-1] + + ( + self.num_multi_query_groups_per_partition, + self.hidden_size_per_attention_head, + ) + ) + + # apply relative positional encoding (rotary embedding) + if rotary_pos_emb is not None: + query_layer = self.apply_rotary_pos_emb(query_layer, rotary_pos_emb) + key_layer = self.apply_rotary_pos_emb(key_layer, rotary_pos_emb) + + kv_cache = (key_layer, value_layer) + + key_layer = key_layer.unsqueeze(-2) + key_layer = key_layer.expand( + -1, + -1, + -1, + self.num_attention_heads_per_partition + // self.num_multi_query_groups_per_partition, + -1, + ) + key_layer = key_layer.contiguous().view( + key_layer.size()[:2] + + ( + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + ) + ) + value_layer = value_layer.unsqueeze(-2) + 
value_layer = value_layer.expand( + -1, + -1, + -1, + self.num_attention_heads_per_partition + // self.num_multi_query_groups_per_partition, + -1, + ) + value_layer = value_layer.contiguous().view( + value_layer.size()[:2] + + ( + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + ) + ) + # ================================== + # core attention computation + # ================================== + context_layer = self.core_attention( + query_layer, key_layer, value_layer, attention_mask + ) + # ================= + # Output. [sq, b, h] + # ================= + output = self.dense(context_layer) + return output, kv_cache + + +class MLP(torch.nn.Module): + """MLP. + + MLP will take the input with h hidden state, project it to 4*h + hidden dimension, perform nonlinear transformation, and project the + state back into h hidden dimension. + """ + + def __init__(self, config: ChatGLMConfig, device=None): + super(MLP, self).__init__() + + self.add_bias = config.add_bias_linear + + # Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf + self.dense_h_to_4h = nn.Linear( + config.hidden_size, + config.ffn_hidden_size * 2, + bias=self.add_bias, + device=device, + dtype=config.torch_dtype, + ) + + def swiglu(x): + x = torch.chunk(x, 2, dim=-1) + return F.silu(x[0]) * x[1] + + self.activation_func = swiglu + + # Project back to h. + self.dense_4h_to_h = nn.Linear( + config.ffn_hidden_size, + config.hidden_size, + bias=self.add_bias, + device=device, + dtype=config.torch_dtype, + ) + + def forward(self, hidden_states): + # [s, b, 4hp] + intermediate_parallel = self.dense_h_to_4h(hidden_states) + intermediate_parallel = self.activation_func(intermediate_parallel) + # [s, b, h] + output = self.dense_4h_to_h(intermediate_parallel) + return output + + +class GLMBlock(torch.nn.Module): + """A single transformer layer. + + Transformer layer takes input with size [s, b, h] and returns an + output of the same size. + """ + + def __init__(self, config: ChatGLMConfig, layer_number, device=None): + super(GLMBlock, self).__init__() + self.layer_number = layer_number + + self.apply_residual_connection_post_layernorm = ( + config.apply_residual_connection_post_layernorm + ) + + self.fp32_residual_connection = config.fp32_residual_connection + + LayerNormFunc = RMSNorm + # Layernorm on the input data. + self.input_layernorm = LayerNormFunc( + config.hidden_size, + eps=config.layernorm_epsilon, + device=device, + dtype=config.torch_dtype, + ) + # Self attention. + self.self_attention = SelfAttention(config, layer_number, device=device) + self.hidden_dropout = config.hidden_dropout + # Layernorm on the attention output + self.post_attention_layernorm = LayerNormFunc( + config.hidden_size, + eps=config.layernorm_epsilon, + device=device, + dtype=config.torch_dtype, + ) + # MLP + self.mlp = MLP(config, device=device) + + def forward(self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None): + # hidden_states: [s, b, h] + # Layer norm at the beginning of the transformer layer. + layernorm_output = self.input_layernorm(hidden_states) + # Self attention. + attention_output, kv_cache = self.self_attention( + layernorm_output, attention_mask, rotary_pos_emb, kv_cache=kv_cache + ) + residual = hidden_states + + layernorm_input = torch.nn.functional.dropout( + attention_output, p=self.hidden_dropout, training=self.training + ) + layernorm_input = residual + layernorm_input + + # Layer norm post the self attention. 
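+        # (Residual connections in this block are always taken from the pre-layernorm
+        # inputs; apply_residual_connection_post_layernorm is stored in __init__ but
+        # not branched on in this simplified implementation.)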
+ layernorm_output = self.post_attention_layernorm(layernorm_input) + + # MLP. + mlp_output = self.mlp(layernorm_output) + + residual = layernorm_input + + output = torch.nn.functional.dropout( + mlp_output, p=self.hidden_dropout, training=self.training + ) + output = residual + output + return output, kv_cache + + +class GLMTransformer(torch.nn.Module): + def __init__(self, config: ChatGLMConfig, device=None): + super(GLMTransformer, self).__init__() + self.fp32_residual_connection = config.fp32_residual_connection + self.post_layer_norm = config.post_layer_norm + self.num_layers = config.num_layers + self.layers = [] + for i in range(self.num_layers): + self.layers.append(GLMBlock(config, i + 1, device=device)) + self.layers = torch.nn.ModuleList(self.layers) + self.final_layernorm = RMSNorm( + config.hidden_size, + eps=config.layernorm_epsilon, + device=device, + dtype=config.torch_dtype, + ) + + def forward( + self, + hidden_states, + attention_mask, + rotary_pos_emb, + kv_caches=None, + use_cache: Optional[bool] = True, + ): + kv_caches = [None for _ in range(self.num_layers)] + presents = () if use_cache else None + + for index in range(self.num_layers): + layer = self.layers[index] + hidden_states, kv_cache = layer( + hidden_states, attention_mask, rotary_pos_emb, kv_cache=kv_caches[index] + ) + if use_cache: + presents = presents + (kv_cache,) + hidden_states = self.final_layernorm(hidden_states) + return hidden_states + + +class Embedding(torch.nn.Module): + def __init__(self, config: ChatGLMConfig, device=None): + super(Embedding, self).__init__() + + self.hidden_size = config.hidden_size + self.word_embeddings = nn.Embedding( + config.padded_vocab_size, + self.hidden_size, + dtype=config.torch_dtype, + device=device, + ) + + def forward(self, input_ids): + # Embeddings. + words_embeddings = self.word_embeddings(input_ids) + embeddings = words_embeddings + # Data format change to avoid explicit tranposes : [b s h] --> [s b h]. + embeddings = embeddings.transpose(0, 1).contiguous() + # If the input flag for fp32 residual connection is set, convert for float. 
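+        # (The fp32 cast described above is omitted in this simplified implementation;
+        # embeddings stay in config.torch_dtype.)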
+ return embeddings + + +class ChatGLMModel(nn.Module): + def __init__(self, config: ChatGLMConfig, device=None, empty_init=True): + super().__init__() + init_method = skip_init + init_kwargs = {} + if device is not None: + init_kwargs["device"] = device + self.embedding = init_method(Embedding, config, **init_kwargs) + self.num_layers = config.num_layers + self.multi_query_group_num = config.multi_query_group_num + self.kv_channels = config.kv_channels + self.config = config + + # Rotary positional embeddings + self.seq_length = config.seq_length + rotary_dim = ( + config.hidden_size // config.num_attention_heads + if config.kv_channels is None + else config.kv_channels + ) + + self.rotary_pos_emb = RotaryEmbedding( + rotary_dim // 2, + original_impl=config.original_rope, + device=device, + dtype=config.torch_dtype, + ) + self.encoder = init_method(GLMTransformer, config, **init_kwargs) + self.output_layer = init_method( + nn.Linear, + config.hidden_size, + config.padded_vocab_size, + bias=False, + dtype=config.torch_dtype, + **init_kwargs, + ) + + def forward( + self, + input_ids, + position_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.BoolTensor] = None, + full_attention_mask: Optional[torch.BoolTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + return_last_logit: Optional[bool] = False, + ): + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + batch_size, seq_length = input_ids.shape + + if inputs_embeds is None: + inputs_embeds = self.embedding(input_ids) + + # Rotary positional embeddings + rotary_pos_emb = self.rotary_pos_emb(self.seq_length) + + + import plotly_express as px + img = px.imshow((rotary_pos_emb[:,:,0]*256).byte().cpu()) + img.write_image("plot.png") + + + + + if position_ids is not None: + rotary_pos_emb = rotary_pos_emb[position_ids] + else: + rotary_pos_emb = rotary_pos_emb[None, :seq_length] + rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous() + hidden_states = self.encoder( + inputs_embeds, + full_attention_mask, + rotary_pos_emb=rotary_pos_emb, + kv_caches=past_key_values, + use_cache=use_cache, + ) + if return_last_logit: + hidden_states = hidden_states[-1:] + lm_logits = self.output_layer(hidden_states) + lm_logits = lm_logits.transpose(0, 1).contiguous() + + return lm_logits + + +class ChatGLMForConditionalGeneration(nn.Module): + def __init__(self, config: ChatGLMConfig, empty_init=True, device=None): + super().__init__() + + self.max_sequence_length = config.max_length + self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device) + self.config = config + + self.main_input_name = "input_ids" + self.config = config + self.name_or_path = config.name_or_path + self.warnings_issued = {} + self.generation_config = GenerationConfig.from_model_config(config) + + def from_pretrained( + cls, + pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], + config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None, + **kwargs, + ): + state_dict = kwargs.pop("state_dict", None) + _ = kwargs.pop("mirror", None) + load_in_8bit = 
kwargs.pop("load_in_8bit", False) + load_in_4bit = kwargs.pop("load_in_4bit", False) + subfolder = kwargs.pop("subfolder", "") + + # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the + # index of the files. + sharded_metadata = None + + # if pretrained_model_name_or_path is not None: + pretrained_model_name_or_path = str(pretrained_model_name_or_path) + archive_file = os.path.join( + pretrained_model_name_or_path, + subfolder, + WEIGHTS_INDEX_NAME, + ) + print(f"loading weights file {archive_file}") + resolved_archive_file = archive_file + + with open(resolved_archive_file, "r") as f: + index = json.loads(f.read()) + + shard_filenames = sorted(set(index["weight_map"].values())) + sharded_metadata = index["metadata"] + sharded_metadata["all_checkpoint_keys"] = list(index["weight_map"].keys()) + sharded_metadata["weight_map"] = index["weight_map"].copy() + + resolved_archive_file = [ + os.path.join(pretrained_model_name_or_path, subfolder, f) + for f in shard_filenames + ] + + loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"] + config.name_or_path = pretrained_model_name_or_path + config = copy.deepcopy(config) + # make sure we use the model's config since the __init__ call might have copied it + config = cls.config + # restore default dtype + model = cls._load_pretrained_model( + state_dict, + loaded_state_dict_keys, + resolved_archive_file, + pretrained_model_name_or_path, + ) + model.is_loaded_in_4bit = load_in_4bit + model.is_loaded_in_8bit = load_in_8bit + # Set model in evaluation mode to deactivate DropOut modules by default + model.eval() + return model + + def _load_state_dict_into_model(self, model_to_load, state_dict, start_prefix): + metadata = getattr(state_dict, "_metadata", None) + state_dict = state_dict.copy() + if metadata is not None: + state_dict._metadata = metadata + error_msgs = [] + + def load(module: nn.Module, state_dict, prefix=""): + local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) + args = (state_dict, prefix, local_metadata, True, [], [], error_msgs) + if len([key for key in state_dict if key.startswith(prefix)]) > 0: + module._load_from_state_dict(*args) + + for name, child in module._modules.items(): + if child is not None: + load(child, state_dict, prefix + name + ".") + + load(model_to_load, state_dict, prefix=start_prefix) + del state_dict + return error_msgs + + def _load_pretrained_model( + cls, + state_dict, + loaded_keys, + resolved_archive_file, + pretrained_model_name_or_path, + ): + # Retrieve missing & unexpected_keys + model_state_dict = cls.state_dict() + expected_keys = list(model_state_dict.keys()) + loaded_keys = [key for key in loaded_keys] + + unexpected_keys = set(loaded_keys) - set(expected_keys) + model_buffers = {n for n, _ in cls.named_buffers()} + unexpected_keys = list(unexpected_keys - model_buffers) + + ptrs = collections.defaultdict(list) + for name, tensor in cls.state_dict().items(): + ptrs[(tensor.device, storage_ptr(tensor), storage_size(tensor))].append( + name + ) + # retrieve unintialized modules and initialize before maybe overriding that with the pretrained weights. 
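+        # (`ptrs` above groups parameter names whose tensors share the same underlying
+        # storage, i.e. tied weights; it is collected but not otherwise used by this
+        # simplified loader.)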
+ + _loaded_keys = loaded_keys + for module_name, module in cls.named_modules(): + loaded_keys = [ + k.replace(f"{module_name}.", "") + for k in _loaded_keys + if k.startswith(f"{module_name}.") + ] + if len(set(module.state_dict().keys()) - set(loaded_keys)) == 0: + module._is_hf_initialized = True + + # Make sure we are able to load base models as well as derived models (with heads) + start_prefix = "" + model_to_load = cls + error_msgs = [] + if len(resolved_archive_file) > 1: + resolved_archive_file = tqdm_lib.tqdm( + resolved_archive_file, desc="Loading checkpoint shards" + ) + for shard_file in resolved_archive_file: + # Skip the load for shards that only contain disk-offloaded weights when using safetensors for the offload. + state_dict = torch.load(shard_file, map_location="cpu") + # Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not + # matching the weights in the model. + error_msgs += cls._load_state_dict_into_model( + model_to_load, state_dict, start_prefix + ) + # force memory release + del state_dict + gc.collect() + + print( + f"All model checkpoint weights were used when initializing {cls.__class__.__name__}.\n" + ) + print( + f"All the weights of {cls.__class__.__name__} were initialized from the model checkpoint at" + f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint" + f" was trained on, you can already use {cls.__class__.__name__} for predictions without further" + " training." + ) + return cls + + @torch.inference_mode() + def chat( + self, + tokenizer, + query: str, + history: List[Tuple[str, str]] = None, + role: str = "user", + ): + if history is None: + history = [] + inputs = tokenizer.build_chat_input(query, history=history, role=role) + inputs = inputs.to(next(self.parameters()).device) + + generation_config = copy.deepcopy(self.generation_config) + inputs_tensor = inputs["input_ids"] + input_ids = inputs_tensor.repeat_interleave( + generation_config.num_return_sequences, dim=0 + ) + + outputs = self.sample( + input_ids, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_hidden_states = generation_config.output_hidden_states, + use_cache = generation_config.use_cache + ) + + outputs = outputs.tolist()[0][len(inputs["input_ids"][0]) : -1] + response = tokenizer.decode(outputs) + history.append({"role": role, "content": query}) + return response, history + + def sample( + self, + input_ids: torch.LongTensor, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[Union[int, List[int]]] = None, + output_hidden_states: Optional[bool] = None, + use_cache: Optional[bool] = None + ): + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) + + unfinished_sequences = torch.ones( + input_ids.shape[0], dtype=torch.long, device=input_ids.device + ) + + this_peer_finished = False # used by synced_gpus only + while True: + input_ids_in = input_ids + batch_size, seq_length = input_ids_in.shape + position_ids_in = ( + torch.arange(seq_length, dtype=torch.long, device=input_ids.device) + .unsqueeze(0) + .repeat(batch_size, 1) + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + model_inputs = { + "input_ids": input_ids_in, + "past_key_values": None, + "position_ids": position_ids_in, + "return_last_logit": True, + "use_cache": use_cache, + } + + logits = self.transformer( + **model_inputs, + return_dict=True, 
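+                # (ChatGLMModel.forward returns the logits tensor directly; return_dict
+                # and output_hidden_states are accepted for interface compatibility but
+                # are not used in this simplified implementation.)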
+                output_hidden_states=output_hidden_states,
+            )
+            next_token_logits = logits[:, -1, :]
+            next_token_scores = next_token_logits
+            probs = nn.functional.softmax(next_token_scores, dim=-1)
+            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+
+            # finished sentences should have their next token be a padding token
+            if eos_token_id is not None:
+                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
+                    1 - unfinished_sequences
+                )
+
+            # update generated ids, model inputs, and length for next step
+            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
+
+            # if eos_token was found in one sentence, set sentence to finished
+            if eos_token_id_tensor is not None:
+                unfinished_sequences = unfinished_sequences.mul(
+                    next_tokens.tile(eos_token_id_tensor.shape[0], 1)
+                    .ne(eos_token_id_tensor.unsqueeze(1))
+                    .prod(dim=0)
+                )
+                if unfinished_sequences.max() == 0:
+                    this_peer_finished = True
+            if this_peer_finished:
+                break
+
+        return input_ids
diff --git a/chatglm/tokenization_chatglm.py b/chatglm/tokenization_chatglm.py
new file mode 100644
index 0000000..f62a1f9
--- /dev/null
+++ b/chatglm/tokenization_chatglm.py
@@ -0,0 +1,161 @@
+import json
+import os
+import torch
+from typing import List, Optional, Union, Dict
+from sentencepiece import SentencePieceProcessor
+from transformers import PreTrainedTokenizer
+
+
+class SPTokenizer:
+    def __init__(self, model_path: str):
+        # reload tokenizer
+        assert os.path.isfile(model_path), model_path
+        self.sp_model = SentencePieceProcessor(model_file=model_path)
+
+        # BOS / EOS token IDs
+        self.n_words: int = self.sp_model.vocab_size()
+        self.bos_id: int = self.sp_model.bos_id()
+        self.eos_id: int = self.sp_model.eos_id()
+        self.pad_id: int = self.sp_model.unk_id()
+        assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
+
+        special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "sop", "eop", "<|system|>", "<|user|>", "<|assistant|>",
+                          "<|observation|>"]
+        self.special_tokens = {}
+        self.index_special_tokens = {}
+        for token in special_tokens:
+            self.special_tokens[token] = self.n_words
+            self.index_special_tokens[self.n_words] = token
+            self.n_words += 1
+
+    def tokenize(self, s: str):
+        return self.sp_model.EncodeAsPieces(s)
+
+    def encode(self, s: str, bos: bool = False, eos: bool = False) -> List[int]:
+        assert type(s) is str
+        t = self.sp_model.encode(s)
+        if bos:
+            t = [self.bos_id] + t
+        if eos:
+            t = t + [self.eos_id]
+        return t
+
+    def decode(self, t: List[int]) -> str:
+        text, buffer = "", []
+        for token in t:
+            if token in self.index_special_tokens:
+                if buffer:
+                    text += self.sp_model.decode(buffer)
+                    buffer = []
+                text += self.index_special_tokens[token]
+            else:
+                buffer.append(token)
+        if buffer:
+            text += self.sp_model.decode(buffer)
+        return text
+
+    def decode_tokens(self, tokens: List[str]) -> str:
+        text = self.sp_model.DecodePieces(tokens)
+        return text
+
+    def convert_token_to_id(self, token):
+        """Converts a token (str) to an id using the vocab; added special tokens are looked up first."""
+        if token in self.special_tokens:
+            return self.special_tokens[token]
+        return self.sp_model.PieceToId(token)
+
+    def convert_id_to_token(self, index):
+        """Converts an index (integer) to a token (str) using the vocab."""
+        if index in self.index_special_tokens:
+            return self.index_special_tokens[index]
+        if index in [self.eos_id, self.bos_id, self.pad_id] or index < 0:
+            return ""
+        return self.sp_model.IdToPiece(index)
+
+
+class ChatGLMTokenizer(PreTrainedTokenizer):
+    vocab_files_names = {"vocab_file": "tokenizer.model"}
+
+    model_input_names = ["input_ids", "attention_mask", "position_ids"]
+
+    def __init__(self, vocab_file, padding_side="left", clean_up_tokenization_spaces=False, **kwargs):
+        self.name = "GLMTokenizer"
+
+        self.vocab_file = vocab_file
+        self.tokenizer = SPTokenizer(vocab_file)
+        self.special_tokens = {
+            "<bos>": self.tokenizer.bos_id,
+            "<eos>": self.tokenizer.eos_id,
+            "<pad>": self.tokenizer.pad_id
+        }
+        super().__init__(padding_side=padding_side, clean_up_tokenization_spaces=clean_up_tokenization_spaces, **kwargs)
+
+    def get_command(self, token):
+        if token in self.special_tokens:
+            return self.special_tokens[token]
+        assert token in self.tokenizer.special_tokens, f"{token} is not a special token for {self.name}"
+        return self.tokenizer.special_tokens[token]
+
+    @property
+    def unk_token(self) -> str:
+        return "<unk>"
+
+    @property
+    def pad_token(self) -> str:
+        return "<pad>"
+
+    @property
+    def pad_token_id(self):
+        return self.get_command("<pad>")
+
+    @property
+    def eos_token(self) -> str:
+        return "<eos>"
+
+    @property
+    def eos_token_id(self):
+        return self.get_command("<eos>")
+
+    @property
+    def vocab_size(self):
+        return self.tokenizer.n_words
+
+    def get_vocab(self):
+        """Returns the vocab as a dict."""
+        vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)}
+        vocab.update(self.added_tokens_encoder)
+        return vocab
+
+    def _tokenize(self, text, **kwargs):
+        return self.tokenizer.tokenize(text)
+
+    def _convert_token_to_id(self, token):
+        """ Converts a token (str) in an id using the vocab. 
""" + return self.tokenizer.convert_token_to_id(token) + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + return self.tokenizer.convert_id_to_token(index) + + def convert_tokens_to_string(self, tokens: List[str]) -> str: + return self.tokenizer.decode_tokens(tokens) + + def build_single_message(self, role, metadata, message): + assert role in ["system", "user", "assistant", "observation"], role + role_tokens = [self.get_command(f"<|{role}|>")] + self.tokenizer.encode(f"{metadata}\n") + message_tokens = self.tokenizer.encode(message) + tokens = role_tokens + message_tokens + return tokens + + def build_chat_input(self, query, history=None, role="user"): + if history is None: + history = [] + input_ids = [] + for item in history: + content = item["content"] + if item["role"] == "system" and "tools" in item: + content = content + "\n" + json.dumps(item["tools"], indent=4, ensure_ascii=False) + input_ids.extend(self.build_single_message(item["role"], item.get("metadata", ""), content)) + input_ids.extend(self.build_single_message(role, "", query)) + input_ids.extend([self.get_command("<|assistant|>")]) + return self.batch_encode_plus([input_ids], return_tensors="pt", is_split_into_words=True) diff --git a/chatglm/tokenizer.model b/chatglm/tokenizer.model new file mode 100644 index 0000000..c8336ad Binary files /dev/null and b/chatglm/tokenizer.model differ diff --git a/chatglm/tokenizer_config.json b/chatglm/tokenizer_config.json new file mode 100644 index 0000000..16882eb --- /dev/null +++ b/chatglm/tokenizer_config.json @@ -0,0 +1,12 @@ +{ + "name_or_path": "THUDM/chatglm3-6b", + "remove_space": false, + "do_lower_case": false, + "tokenizer_class": "ChatGLMTokenizer", + "auto_map": { + "AutoTokenizer": [ + "tokenization_chatglm.ChatGLMTokenizer", + null + ] + } +} diff --git a/demo.py b/demo.py new file mode 100644 index 0000000..f1e9985 --- /dev/null +++ b/demo.py @@ -0,0 +1,64 @@ +import json + + +from chatglm import ChatGLMForConditionalGeneration +from chatglm import ChatGLMTokenizer + +from transformers import AutoConfig + +pretrained_model_name_or_path = "../ZhipuAI/chatglm3-6b" +config, kwargs = AutoConfig.from_pretrained( + pretrained_model_name_or_path, + return_unused_kwargs=True, + trust_remote_code=True, + code_revision=None, + _commit_hash=None, +) +glm = ChatGLMForConditionalGeneration(config) + + +tokenizer_config_file = "./chatglm/tokenizer_config.json" +if tokenizer_config_file is not None: + with open(tokenizer_config_file, encoding="utf-8") as tokenizer_config_handle: + init_kwargs = json.load(tokenizer_config_handle) + init_kwargs.pop("tokenizer_class", None) + init_kwargs.pop("tokenizer_file", None) + saved_init_inputs = init_kwargs.pop("init_inputs", ()) + init_inputs = saved_init_inputs +init_kwargs["vocab_file"] = './chatglm/tokenizer.model' +init_kwargs["added_tokens_file"] = None +init_kwargs["special_tokens_map_file"] = None +init_kwargs["tokenizer_file"] = None +init_kwargs["name_or_path"] = pretrained_model_name_or_path +tokenizer = ChatGLMTokenizer(*init_inputs, **init_kwargs) + + +glm = glm.from_pretrained(pretrained_model_name_or_path, config=config).half().cuda() +glm = glm.eval() +response, history = glm.chat(tokenizer, "colin", history=[]) +print(response) +response, history = glm.chat(tokenizer, "你好", history=history) +print(response) +# response, history = glm.chat(tokenizer, "你是一个心理学专家,请问晚上睡不着应该怎么办", history=history) +# print(response) + + +# import plotly_express as px +# px.imshow(ron) +# 
gapminder = px.data.gapminder() +# gapminder2007 = gapminder.query('year == 2007') +# px.scatter(gapminder2007, x='gdpPercap', y='lifeExp') + + + +# from modelscope import AutoTokenizer, AutoModel, snapshot_download +# model_dir = snapshot_download("ZhipuAI/chatglm3-6b", cache_dir="./chatglm", revision="v1.0.0") +# model = AutoModel.from_pretrained(model_dir, trust_remote_code=True).half().cuda() +# tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True) +# model = model.eval() +# response, history = model.chat(tokenizer, "colin", history=[]) +# print(response) +# response, history = model.chat(tokenizer, "你好", history=history) +# print(response) +# # response, history = model.chat(tokenizer, "你是一个心理学专家,请问晚上睡不着应该怎么办", history=history) +# # print(response) diff --git a/plot.png b/plot.png new file mode 100644 index 0000000..5272237 Binary files /dev/null and b/plot.png differ
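A quick note on the prompt format implemented by ChatGLMTokenizer.build_chat_input above: each history turn becomes a <|role|> command token followed by the encoded "{metadata}\n" header and the encoded message, and the prompt is terminated with the <|assistant|> command token so that sampling in chat() continues as the assistant. A minimal sketch of this, assuming only the bundled chatglm/tokenizer.model and the dependencies already used by this repo:

from chatglm import ChatGLMTokenizer

# Build the tokenizer directly from the SentencePiece model shipped in this commit.
tokenizer = ChatGLMTokenizer(vocab_file="./chatglm/tokenizer.model")

history = [{"role": "system", "content": "You are a helpful assistant."}]
inputs = tokenizer.build_chat_input("你好", history=history, role="user")

print(inputs["input_ids"].shape)          # torch.Size([1, prompt_length])
print(inputs["input_ids"][0, -1].item())  # id of the trailing <|assistant|> command token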