Init code.
This commit is contained in:
commit
a451def299
|
@ -0,0 +1,2 @@
|
||||||
|
__pycache__
|
||||||
|
.vscode
|
|
@ -0,0 +1,7 @@
|
||||||
|
|
||||||
|
from chatglm.configuration_chatglm import ChatGLMConfig
|
||||||
|
from chatglm.tokenization_chatglm import ChatGLMTokenizer
|
||||||
|
from chatglm.modeling_chatglm import ChatGLMForConditionalGeneration
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,61 @@
|
||||||
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
|
|
||||||
|
class ChatGLMConfig(PretrainedConfig):
    """Configuration object holding the ChatGLM hyper-parameters.

    Every named keyword argument is stored verbatim as an attribute of the
    same name.  ``padded_vocab_size`` is additionally mirrored to
    ``vocab_size``.  Any remaining keyword arguments are forwarded to
    ``PretrainedConfig.__init__``.

    Args:
        original_rope: Stored as ``self.original_rope``.  It is declared
            explicitly (instead of relying on ``**kwargs``) so that a
            default-constructed config always exposes the attribute.
        **kwargs: Passed through to ``PretrainedConfig``.
    """

    model_type = "chatglm"

    def __init__(
        self,
        num_layers=28,
        padded_vocab_size=65024,
        hidden_size=4096,
        ffn_hidden_size=13696,
        kv_channels=128,
        num_attention_heads=32,
        seq_length=2048,
        hidden_dropout=0.0,
        classifier_dropout=None,
        attention_dropout=0.0,
        layernorm_epsilon=1e-5,
        rmsnorm=True,
        apply_residual_connection_post_layernorm=False,
        post_layer_norm=True,
        add_bias_linear=False,
        add_qkv_bias=False,
        bias_dropout_fusion=True,
        multi_query_attention=False,
        multi_query_group_num=1,
        apply_query_key_layer_scaling=True,
        attention_softmax_in_fp32=True,
        fp32_residual_connection=False,
        quantization_bit=0,
        pre_seq_len=None,
        prefix_projection=False,
        original_rope=True,
        **kwargs
    ):
        self.num_layers = num_layers
        # The vocabulary size is exposed under both names.
        self.vocab_size = padded_vocab_size
        self.padded_vocab_size = padded_vocab_size
        self.hidden_size = hidden_size
        self.ffn_hidden_size = ffn_hidden_size
        self.kv_channels = kv_channels
        self.num_attention_heads = num_attention_heads
        self.seq_length = seq_length
        self.hidden_dropout = hidden_dropout
        self.classifier_dropout = classifier_dropout
        self.attention_dropout = attention_dropout
        self.layernorm_epsilon = layernorm_epsilon
        self.rmsnorm = rmsnorm
        self.apply_residual_connection_post_layernorm = (
            apply_residual_connection_post_layernorm
        )
        self.post_layer_norm = post_layer_norm
        self.add_bias_linear = add_bias_linear
        self.add_qkv_bias = add_qkv_bias
        self.bias_dropout_fusion = bias_dropout_fusion
        self.multi_query_attention = multi_query_attention
        self.multi_query_group_num = multi_query_group_num
        self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = attention_softmax_in_fp32
        self.fp32_residual_connection = fp32_residual_connection
        self.quantization_bit = quantization_bit
        self.pre_seq_len = pre_seq_len
        self.prefix_projection = prefix_projection
        self.original_rope = original_rope
        super().__init__(**kwargs)
|
|
@ -0,0 +1,786 @@
|
||||||
|
import collections
|
||||||
|
import math
|
||||||
|
import copy
|
||||||
|
import os
|
||||||
|
import gc
|
||||||
|
import json
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.utils.checkpoint
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch import nn
|
||||||
|
from torch.nn.utils import skip_init
|
||||||
|
from typing import Optional, Tuple, Union, List, Dict, Any
|
||||||
|
|
||||||
|
from tqdm import auto as tqdm_lib
|
||||||
|
from safetensors.torch import storage_ptr, storage_size
|
||||||
|
|
||||||
|
from transformers.configuration_utils import PretrainedConfig
|
||||||
|
from transformers.generation import GenerationConfig
|
||||||
|
|
||||||
|
from chatglm import ChatGLMConfig
|
||||||
|
|
||||||
|
WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
|
||||||
|
|
||||||
|
|
||||||
|
class RotaryEmbedding(nn.Module):
    """Builds the cos/sin lookup table used for rotary position embeddings."""

    def __init__(self, dim: int, original_impl=False, device=None, dtype=None):
        """Register the ``inv_freq`` buffer and remember ``dim``.

        Args:
            dim: Rotary dimension; frequencies are generated for every
                second index in ``[0, dim)``.
            original_impl: Stored on the instance; not read inside this class.
            device: Device on which the buffer is created.
            dtype: Dtype of the ``inv_freq`` buffer.  ``forward`` uses the
                buffer's dtype to decide the dtype of the returned table.
        """
        super().__init__()
        self.dim = dim
        self.original_impl = original_impl
        exponents = torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim
        self.register_buffer("inv_freq", 1.0 / (10000 ** exponents))

    def forward(self, max_seq_len: int, base: int = 10000):
        """Return a ``[max_seq_len, dim // 2, 2]`` table of (cos, sin) pairs.

        Args:
            max_seq_len: Number of positions to generate.
            base: Base of the geometric frequency progression.
        """
        reference = self.inv_freq
        target_device = reference.device

        # theta_i = base^(-2i / dim) for i in [0, dim / 2)
        even_steps = torch.arange(
            0, self.dim, 2, dtype=torch.float, device=target_device
        )
        theta = 1.0 / (base ** (even_steps / self.dim))

        # Position indexes [0, 1, ..., max_seq_len - 1]
        positions = torch.arange(max_seq_len, dtype=torch.float, device=target_device)

        # Angle for every (position, frequency) pair.
        angles = torch.outer(positions, theta).float()
        table = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1)

        # Mimic the behaviour of complex32, otherwise results differ.
        if reference.dtype == torch.bfloat16:
            return table.bfloat16()
        if reference.dtype in (torch.float16, torch.int8):
            return table.half()
        return table
|
||||||
|
|
||||||
|
|
||||||
|
class RMSNorm(torch.nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned scale."""

    def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
        """Allocate the scale parameter.

        Args:
            normalized_shape: Shape of the ``weight`` parameter.
            eps: Value added to the mean square before the reciprocal root.
            device: Device of the parameter.
            dtype: Dtype of the parameter.
            **kwargs: Accepted and ignored.

        Note:
            ``weight`` is created with ``torch.empty`` and is therefore
            uninitialised until values are assigned or loaded.
        """
        super().__init__()
        blank = torch.empty(normalized_shape, device=device, dtype=dtype)
        self.weight = torch.nn.Parameter(blank)
        self.eps = eps

    def forward(self, hidden_states: torch.Tensor):
        """Normalise ``hidden_states`` and return it in its original dtype."""
        original_dtype = hidden_states.dtype
        # The mean square is accumulated in float32.
        as_fp32 = hidden_states.to(torch.float32)
        mean_square = as_fp32.pow(2).mean(-1, keepdim=True)
        inverse_rms = torch.rsqrt(mean_square + self.eps)
        normalised = hidden_states * inverse_rms
        scaled = self.weight * normalised
        return scaled.to(original_dtype)
|
||||||
|
|
||||||
|
|
||||||
|
class CoreAttention(torch.nn.Module):
    """Scaled dot-product attention over already-projected Q/K/V tensors."""

    def __init__(self, config: ChatGLMConfig, layer_number):
        """Derive the head geometry from ``config``.

        Args:
            config: Supplies ``attention_softmax_in_fp32``, ``kv_channels``,
                ``num_attention_heads`` and ``attention_dropout``.
            layer_number: Layer index; clamped to at least 1.
        """
        super(CoreAttention, self).__init__()

        self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
        self.layer_number = max(1, layer_number)

        projection_size = config.kv_channels * config.num_attention_heads
        # Per attention head and per partition values.
        self.hidden_size_per_partition = projection_size
        self.hidden_size_per_attention_head = (
            projection_size // config.num_attention_heads
        )
        self.num_attention_heads_per_partition = config.num_attention_heads

        # ``coeff`` / ``norm_factor`` are kept as attributes; ``forward``
        # below does not read them.
        self.coeff = self.layer_number
        self.norm_factor = (
            math.sqrt(self.hidden_size_per_attention_head) * self.coeff
        )

        self.attention_dropout = torch.nn.Dropout(config.attention_dropout)

    def forward(self, query_layer, key_layer, value_layer, attention_mask):
        """Run attention and merge the heads back together.

        The three inputs are permuted with ``(1, 2, 0, 3)`` before attention
        and the result is permuted back with ``(2, 0, 1, 3)``; the last two
        dimensions are then flattened to ``hidden_size_per_partition``.

        Args:
            query_layer: Query tensor (4-D).
            key_layer: Key tensor (4-D).
            value_layer: Value tensor (4-D).
            attention_mask: Optional boolean mask.  When it is ``None`` and
                query/key lengths match, causal attention is used.

        Returns:
            The context tensor with heads merged into the last dimension.
        """
        query_layer, key_layer, value_layer = (
            tensor.permute(1, 2, 0, 3)
            for tensor in (query_layer, key_layer, value_layer)
        )
        if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
            context_layer = torch.nn.functional.scaled_dot_product_attention(
                query_layer, key_layer, value_layer, is_causal=True
            )
        else:
            # Previously this path left ``context_layer`` unbound.
            if attention_mask is not None:
                # NOTE(review): assumes the upstream ChatGLM convention where
                # True marks positions that must NOT be attended, which is the
                # opposite of what scaled_dot_product_attention expects --
                # confirm against the code that builds the mask.
                attention_mask = ~attention_mask
            context_layer = torch.nn.functional.scaled_dot_product_attention(
                query_layer, key_layer, value_layer, attention_mask
            )
        context_layer = context_layer.permute(2, 0, 1, 3)
        new_context_layer_shape = context_layer.size()[:-2] + (
            self.hidden_size_per_partition,
        )
        return context_layer.reshape(*new_context_layer_shape)
|
||||||
|
|
||||||
|
|
||||||
|
class SelfAttention(torch.nn.Module):
    """Self-attention layer with grouped key/value heads.

    Self-attention layer takes input with size [s, b, h]
    and returns output of the same size.
    """

    def __init__(self, config: ChatGLMConfig, layer_number, device=None):
        """Create the fused QKV projection, the attention core and the output projection.

        Args:
            config: Supplies head geometry, bias flags and ``torch_dtype``.
            layer_number: Layer index; clamped to at least 1.
            device: Device on which the linear layers are created.
        """
        super(SelfAttention, self).__init__()
        self.layer_number = max(1, layer_number)
        self.projection_size = config.kv_channels * config.num_attention_heads
        self.hidden_size_per_attention_head = (
            self.projection_size // config.num_attention_heads
        )
        self.num_attention_heads_per_partition = config.num_attention_heads
        # Stored for reference only: the grouped K/V layout below is used
        # regardless of this flag.
        self.multi_query_attention = config.multi_query_attention
        self.num_multi_query_groups_per_partition = config.multi_query_group_num
        # Full-width queries plus one key and one value slice per group.
        self.qkv_hidden_size = (
            self.projection_size
            + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num
        )
        self.query_key_value = nn.Linear(
            config.hidden_size,
            self.qkv_hidden_size,
            bias=config.add_bias_linear or config.add_qkv_bias,
            device=device,
            dtype=config.torch_dtype,
        )
        self.core_attention = CoreAttention(config, self.layer_number)
        self.dense = nn.Linear(
            self.projection_size,
            config.hidden_size,
            bias=config.add_bias_linear,
            device=device,
            dtype=config.torch_dtype,
        )

    def apply_rotary_pos_emb(
        self, x: torch.Tensor, rope_cache: torch.Tensor
    ) -> torch.Tensor:
        """Rotate the leading ``rot_dim`` channels of ``x`` by the cached angles.

        Channels beyond ``rot_dim`` are passed through unchanged.
        """
        # x: [sq, b, np, hn]
        seq_len, num_heads = x.size(0), x.size(2)
        rot_dim = rope_cache.shape[-2] * 2
        x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
        # truncate to support variable sizes
        rope_cache = rope_cache[:seq_len]
        xshaped = x.reshape(seq_len, -1, num_heads, rot_dim // 2, 2)
        rope_cache = rope_cache.view(seq_len, -1, 1, xshaped.size(3), 2)
        cos, sin = rope_cache[..., 0], rope_cache[..., 1]
        x_out2 = torch.stack(
            [
                xshaped[..., 0] * cos - xshaped[..., 1] * sin,
                xshaped[..., 1] * cos + xshaped[..., 0] * sin,
            ],
            -1,
        )
        x_out2 = x_out2.flatten(3)
        return torch.cat((x_out2, x_pass), dim=-1)

    def _split_heads(self, tensor, num_heads):
        """View the last dimension as ``(num_heads, head_size)``."""
        return tensor.view(
            tensor.size()[:-1] + (num_heads, self.hidden_size_per_attention_head)
        )

    def _expand_kv_groups(self, layer):
        """Repeat each K/V group so there is one copy per query head."""
        heads = self.num_attention_heads_per_partition
        groups = self.num_multi_query_groups_per_partition
        layer = layer.unsqueeze(-2)
        layer = layer.expand(-1, -1, -1, heads // groups, -1)
        return layer.contiguous().view(
            layer.size()[:2] + (heads, self.hidden_size_per_attention_head)
        )

    def forward(self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None):
        """Compute self-attention for ``hidden_states``.

        Args:
            hidden_states: Input of size [sq, b, h].
            attention_mask: Forwarded unchanged to the attention core.
            rotary_pos_emb: Rotary table applied to queries and keys, or
                ``None`` to skip rotation.
            kv_cache: Optional ``(keys, values)`` pair from earlier steps;
                when given it is prepended along dimension 0 to the freshly
                computed keys/values.  (Previously this argument was
                silently ignored.)

        Returns:
            ``(output, kv_cache)`` where ``kv_cache`` is the (un-expanded)
            key/value pair covering cached and new positions.
        """
        # hidden_states: [sq, b, h]
        mixed_x_layer = self.query_key_value(hidden_states)

        head_size = self.hidden_size_per_attention_head
        query_width = self.num_attention_heads_per_partition * head_size
        group_width = self.num_multi_query_groups_per_partition * head_size
        (query_layer, key_layer, value_layer) = mixed_x_layer.split(
            [query_width, group_width, group_width],
            dim=-1,
        )
        query_layer = self._split_heads(
            query_layer, self.num_attention_heads_per_partition
        )
        key_layer = self._split_heads(
            key_layer, self.num_multi_query_groups_per_partition
        )
        value_layer = self._split_heads(
            value_layer, self.num_multi_query_groups_per_partition
        )

        # apply relative positional encoding (rotary embedding)
        if rotary_pos_emb is not None:
            query_layer = self.apply_rotary_pos_emb(query_layer, rotary_pos_emb)
            key_layer = self.apply_rotary_pos_emb(key_layer, rotary_pos_emb)

        # Prepend cached keys/values from earlier steps, if any.
        if kv_cache is not None:
            cached_keys, cached_values = kv_cache
            key_layer = torch.cat((cached_keys, key_layer), dim=0)
            value_layer = torch.cat((cached_values, value_layer), dim=0)
        kv_cache = (key_layer, value_layer)

        key_layer = self._expand_kv_groups(key_layer)
        value_layer = self._expand_kv_groups(value_layer)

        # ==================================
        # core attention computation
        # ==================================
        context_layer = self.core_attention(
            query_layer, key_layer, value_layer, attention_mask
        )

        # =================
        # Output. [sq, b, h]
        # =================
        output = self.dense(context_layer)
        return output, kv_cache
|
||||||
|
|
||||||
|
|
||||||
|
class MLP(torch.nn.Module):
    """Gated feed-forward block.

    The input is projected to ``2 * ffn_hidden_size`` channels, split in
    half and combined as ``silu(first) * second``, then projected back to
    ``hidden_size``.
    """

    def __init__(self, config: ChatGLMConfig, device=None):
        """Build the two linear projections.

        Args:
            config: Supplies ``hidden_size``, ``ffn_hidden_size``,
                ``add_bias_linear`` and ``torch_dtype``.
            device: Device on which the linear layers are created.
        """
        super(MLP, self).__init__()

        self.add_bias = config.add_bias_linear
        linear_kwargs = dict(
            bias=self.add_bias, device=device, dtype=config.torch_dtype
        )

        # Up-projection; the width is doubled because swiglu consumes two
        # halves, see https://arxiv.org/pdf/2002.05202.pdf
        self.dense_h_to_4h = nn.Linear(
            config.hidden_size, config.ffn_hidden_size * 2, **linear_kwargs
        )

        def gated_silu(tensor):
            gate, value = torch.chunk(tensor, 2, dim=-1)
            return F.silu(gate) * value

        self.activation_func = gated_silu

        # Down-projection back to the hidden size.
        self.dense_4h_to_h = nn.Linear(
            config.ffn_hidden_size, config.hidden_size, **linear_kwargs
        )

    def forward(self, hidden_states):
        """Apply up-projection, gated activation and down-projection."""
        expanded = self.dense_h_to_4h(hidden_states)
        activated = self.activation_func(expanded)
        return self.dense_4h_to_h(activated)
|
||||||
|
|
||||||
|
|
||||||
|
class GLMBlock(torch.nn.Module):
    """One transformer layer: norm -> attention -> residual, norm -> MLP -> residual.

    Input of size [s, b, h] yields an output of the same size.
    """

    def __init__(self, config: ChatGLMConfig, layer_number, device=None):
        """Create the two RMSNorm layers, the attention block and the MLP.

        Args:
            config: Model configuration.
            layer_number: Index handed to ``SelfAttention``.
            device: Device on which sub-modules are created.
        """
        super(GLMBlock, self).__init__()
        self.layer_number = layer_number

        # Stored from the config; ``forward`` below does not branch on them.
        self.apply_residual_connection_post_layernorm = (
            config.apply_residual_connection_post_layernorm
        )
        self.fp32_residual_connection = config.fp32_residual_connection

        norm_kwargs = dict(
            eps=config.layernorm_epsilon, device=device, dtype=config.torch_dtype
        )

        # Norm applied to the block input.
        self.input_layernorm = RMSNorm(config.hidden_size, **norm_kwargs)
        self.self_attention = SelfAttention(config, layer_number, device=device)
        self.hidden_dropout = config.hidden_dropout
        # Norm applied to the attention output (after the residual add).
        self.post_attention_layernorm = RMSNorm(config.hidden_size, **norm_kwargs)
        self.mlp = MLP(config, device=device)

    def _dropout(self, tensor):
        """Apply hidden dropout, active only in training mode."""
        return torch.nn.functional.dropout(
            tensor, p=self.hidden_dropout, training=self.training
        )

    def forward(self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None):
        """Run the layer and return ``(output, kv_cache)``.

        Args:
            hidden_states: Input of size [s, b, h].
            attention_mask: Forwarded to the attention block.
            rotary_pos_emb: Forwarded to the attention block.
            kv_cache: Forwarded to the attention block.
        """
        # Attention sub-layer; the residual is the un-normalised input.
        normed_input = self.input_layernorm(hidden_states)
        attention_output, kv_cache = self.self_attention(
            normed_input, attention_mask, rotary_pos_emb, kv_cache=kv_cache
        )
        after_attention = hidden_states + self._dropout(attention_output)

        # MLP sub-layer; the residual is the attention sub-layer output.
        normed_attention = self.post_attention_layernorm(after_attention)
        mlp_output = self.mlp(normed_attention)
        block_output = after_attention + self._dropout(mlp_output)

        return block_output, kv_cache
|
||||||
|
|
||||||
|
|
||||||
|
class GLMTransformer(torch.nn.Module):
    """Stack of ``GLMBlock`` layers followed by a final RMSNorm."""

    def __init__(self, config: ChatGLMConfig, device=None):
        """Build ``config.num_layers`` blocks (numbered from 1) and the final norm.

        Args:
            config: Model configuration.
            device: Device on which sub-modules are created.
        """
        super(GLMTransformer, self).__init__()
        # Stored from the config; ``forward`` always applies the final norm.
        self.fp32_residual_connection = config.fp32_residual_connection
        self.post_layer_norm = config.post_layer_norm
        self.num_layers = config.num_layers
        self.layers = torch.nn.ModuleList(
            [
                GLMBlock(config, index + 1, device=device)
                for index in range(self.num_layers)
            ]
        )
        self.final_layernorm = RMSNorm(
            config.hidden_size,
            eps=config.layernorm_epsilon,
            device=device,
            dtype=config.torch_dtype,
        )

    def forward(
        self,
        hidden_states,
        attention_mask,
        rotary_pos_emb,
        kv_caches=None,
        use_cache: Optional[bool] = True,
    ):
        """Run every layer in order and return the normalised hidden states.

        Args:
            hidden_states: Input to the first layer.
            attention_mask: Forwarded to every layer.
            rotary_pos_emb: Forwarded to every layer.
            kv_caches: Optional per-layer caches.  A falsy value means "no
                cache".  (Previously the argument was unconditionally
                replaced by ``None`` entries.)
            use_cache: Accepted for interface compatibility.  Only the
                hidden states are returned, so the per-layer caches produced
                by the layers are not collected.

        Returns:
            The output of ``final_layernorm``.
        """
        if not kv_caches:
            kv_caches = [None] * self.num_layers

        for index, layer in enumerate(self.layers):
            hidden_states, _ = layer(
                hidden_states, attention_mask, rotary_pos_emb, kv_cache=kv_caches[index]
            )
        return self.final_layernorm(hidden_states)
|
||||||
|
|
||||||
|
|
||||||
|
class Embedding(torch.nn.Module):
    """Token embedding that emits sequence-first activations."""

    def __init__(self, config: ChatGLMConfig, device=None):
        """Create the embedding table.

        Args:
            config: Supplies ``hidden_size``, ``padded_vocab_size`` and
                ``torch_dtype``.
            device: Device on which the table is created.
        """
        super(Embedding, self).__init__()

        self.hidden_size = config.hidden_size
        self.word_embeddings = nn.Embedding(
            config.padded_vocab_size,
            self.hidden_size,
            dtype=config.torch_dtype,
            device=device,
        )

    def forward(self, input_ids):
        """Look up ``input_ids`` and swap the first two axes."""
        looked_up = self.word_embeddings(input_ids)
        # Data format change to avoid explicit transposes: [b s h] --> [s b h].
        return looked_up.transpose(0, 1).contiguous()
|
||||||
|
|
||||||
|
|
||||||
|
class ChatGLMModel(nn.Module):
    """Embedding + rotary table + transformer stack + output projection."""

    def __init__(self, config: ChatGLMConfig, device=None, empty_init=True):
        """Assemble the sub-modules.

        Sub-modules are constructed through ``torch.nn.utils.skip_init``.

        Args:
            config: Model configuration.
            device: Optional device forwarded to every sub-module.
            empty_init: Accepted for interface compatibility; not read here.
        """
        super().__init__()
        init_method = skip_init
        init_kwargs = {}
        if device is not None:
            init_kwargs["device"] = device
        self.embedding = init_method(Embedding, config, **init_kwargs)
        self.num_layers = config.num_layers
        self.multi_query_group_num = config.multi_query_group_num
        self.kv_channels = config.kv_channels
        self.config = config

        # Rotary positional embeddings
        self.seq_length = config.seq_length
        rotary_dim = (
            config.hidden_size // config.num_attention_heads
            if config.kv_channels is None
            else config.kv_channels
        )

        self.rotary_pos_emb = RotaryEmbedding(
            rotary_dim // 2,
            original_impl=config.original_rope,
            device=device,
            dtype=config.torch_dtype,
        )
        self.encoder = init_method(GLMTransformer, config, **init_kwargs)
        self.output_layer = init_method(
            nn.Linear,
            config.hidden_size,
            config.padded_vocab_size,
            bias=False,
            dtype=config.torch_dtype,
            **init_kwargs,
        )

    def forward(
        self,
        input_ids,
        position_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.BoolTensor] = None,
        full_attention_mask: Optional[torch.BoolTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        return_last_logit: Optional[bool] = False,
    ):
        """Compute vocabulary logits for ``input_ids``.

        Args:
            input_ids: 2-D tensor of token ids; its second dimension is
                used as the sequence length.
            position_ids: Optional indices into the rotary table.  When
                omitted the first ``seq_length`` rows are used.
            attention_mask: Accepted for interface compatibility; not read.
            full_attention_mask: Mask handed to the encoder.
            past_key_values: Forwarded to the encoder as ``kv_caches``.
            inputs_embeds: Pre-computed embeddings; when ``None`` they are
                produced from ``input_ids``.
            use_cache: Falls back to ``self.config.use_cache`` when ``None``.
            output_hidden_states: Accepted for interface compatibility.
            return_dict: Accepted for interface compatibility.
            return_last_logit: When true only the last position is
                projected to logits.

        Returns:
            The logits tensor, transposed so that the first two axes of the
            encoder output are swapped.
        """
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        _, seq_length = input_ids.shape

        if inputs_embeds is None:
            inputs_embeds = self.embedding(input_ids)

        # Rotary positional embeddings.  (A leftover debugging snippet that
        # rendered this table to "plot.png" on every call has been removed.)
        rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
        if position_ids is not None:
            rotary_pos_emb = rotary_pos_emb[position_ids]
        else:
            rotary_pos_emb = rotary_pos_emb[None, :seq_length]
        rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()

        hidden_states = self.encoder(
            inputs_embeds,
            full_attention_mask,
            rotary_pos_emb=rotary_pos_emb,
            kv_caches=past_key_values,
            use_cache=use_cache,
        )
        if return_last_logit:
            hidden_states = hidden_states[-1:]
        lm_logits = self.output_layer(hidden_states)
        return lm_logits.transpose(0, 1).contiguous()
|
||||||
|
|
||||||
|
|
||||||
|
class ChatGLMForConditionalGeneration(nn.Module):
|
||||||
|
def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):
    """Wrap a ``ChatGLMModel`` and record generation-related settings.

    Args:
        config: Model configuration; also used to derive the
            ``GenerationConfig``.
        empty_init: Forwarded to ``ChatGLMModel``.
        device: Forwarded to ``ChatGLMModel``.
    """
    super().__init__()

    self.max_sequence_length = config.max_length
    self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device)

    # Bookkeeping attributes mirrored from the config.
    self.config = config
    self.main_input_name = "input_ids"
    self.name_or_path = config.name_or_path
    self.warnings_issued = {}
    self.generation_config = GenerationConfig.from_model_config(config)
|
||||||
|
|
||||||
|
def from_pretrained(
    cls,
    pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
    config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
    **kwargs,
):
    """Load sharded checkpoint weights into this model from a local directory.

    NOTE: although the first parameter is named ``cls``, the helper it calls
    (``_load_pretrained_model``) uses ``cls.state_dict()`` and
    ``cls.named_modules()``, so this method must be invoked on an already
    constructed model instance.

    Args:
        pretrained_model_name_or_path: Directory containing the weight
            index file (``WEIGHTS_INDEX_NAME``) and the shard files.
        config: Optional config object; when it is a ``PretrainedConfig``
            its ``name_or_path`` is updated to the checkpoint location.
            ``None`` is accepted.
        **kwargs: ``state_dict``, ``mirror``, ``load_in_8bit``,
            ``load_in_4bit`` and ``subfolder`` are consumed; anything else
            is ignored.

    Returns:
        The model with weights loaded, switched to evaluation mode.

    Raises:
        OSError: If the index file cannot be opened.
        KeyError: If the index has no ``"weight_map"`` entry.
    """
    state_dict = kwargs.pop("state_dict", None)
    kwargs.pop("mirror", None)
    load_in_8bit = kwargs.pop("load_in_8bit", False)
    load_in_4bit = kwargs.pop("load_in_4bit", False)
    subfolder = kwargs.pop("subfolder", "")

    pretrained_model_name_or_path = str(pretrained_model_name_or_path)

    # The archive file of a sharded checkpoint is just the index of shards.
    index_file = os.path.join(
        pretrained_model_name_or_path,
        subfolder,
        WEIGHTS_INDEX_NAME,
    )
    print(f"loading weights file {index_file}")

    with open(index_file, "r", encoding="utf-8") as f:
        index = json.load(f)

    weight_map = index["weight_map"]
    loaded_state_dict_keys = list(weight_map.keys())
    resolved_archive_file = [
        os.path.join(pretrained_model_name_or_path, subfolder, shard_name)
        for shard_name in sorted(set(weight_map.values()))
    ]

    # Record where the weights came from; skip silently when no config
    # object was supplied instead of failing on ``None``.
    if isinstance(config, PretrainedConfig):
        config.name_or_path = pretrained_model_name_or_path

    model = cls._load_pretrained_model(
        state_dict,
        loaded_state_dict_keys,
        resolved_archive_file,
        pretrained_model_name_or_path,
    )
    model.is_loaded_in_4bit = load_in_4bit
    model.is_loaded_in_8bit = load_in_8bit
    # Set model in evaluation mode to deactivate DropOut modules by default
    model.eval()
    return model
|
||||||
|
|
||||||
|
def _load_state_dict_into_model(self, model_to_load, state_dict, start_prefix):
|
||||||
|
metadata = getattr(state_dict, "_metadata", None)
|
||||||
|
state_dict = state_dict.copy()
|
||||||
|
if metadata is not None:
|
||||||
|
state_dict._metadata = metadata
|
||||||
|
error_msgs = []
|
||||||
|
|
||||||
|
def load(module: nn.Module, state_dict, prefix=""):
|
||||||
|
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
|
||||||
|
args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
|
||||||
|
if len([key for key in state_dict if key.startswith(prefix)]) > 0:
|
||||||
|
module._load_from_state_dict(*args)
|
||||||
|
|
||||||
|
for name, child in module._modules.items():
|
||||||
|
if child is not None:
|
||||||
|
load(child, state_dict, prefix + name + ".")
|
||||||
|
|
||||||
|
load(model_to_load, state_dict, prefix=start_prefix)
|
||||||
|
del state_dict
|
||||||
|
return error_msgs
|
||||||
|
|
||||||
|
def _load_pretrained_model(
    cls,
    state_dict,
    loaded_keys,
    resolved_archive_file,
    pretrained_model_name_or_path,
):
    """Load every checkpoint shard into the model and return the model.

    NOTE: ``cls`` is used as a model *instance* here (``named_modules()``
    and ``cls.__class__.__name__`` are called on it).

    Args:
        state_dict: Ignored; each shard is read from disk instead.
        loaded_keys: Names of all tensors present in the checkpoint.
        resolved_archive_file: Paths of the shard files to load.
        pretrained_model_name_or_path: Checkpoint location, used only in
            the printed summary.

    Returns:
        ``cls`` with the shard contents copied in.

    Raises:
        RuntimeError: If any module reported an error (for example a shape
            mismatch) while loading.  Previously such errors were collected
            but never inspected, and a success message was printed anyway.
    """
    checkpoint_keys = list(loaded_keys)

    # Flag modules whose whole state is covered by the checkpoint.
    for module_name, module in cls.named_modules():
        prefix = f"{module_name}."
        keys_for_module = {
            key.replace(prefix, "")
            for key in checkpoint_keys
            if key.startswith(prefix)
        }
        if not set(module.state_dict().keys()) - keys_for_module:
            module._is_hf_initialized = True

    start_prefix = ""
    model_to_load = cls
    error_msgs = []
    if len(resolved_archive_file) > 1:
        resolved_archive_file = tqdm_lib.tqdm(
            resolved_archive_file, desc="Loading checkpoint shards"
        )
    for shard_file in resolved_archive_file:
        # NOTE(security): torch.load unpickles the file; only load
        # checkpoints obtained from a trusted source.
        shard_state_dict = torch.load(shard_file, map_location="cpu")
        error_msgs += cls._load_state_dict_into_model(
            model_to_load, shard_state_dict, start_prefix
        )
        # force memory release
        del shard_state_dict
        gc.collect()

    model_class_name = cls.__class__.__name__
    if error_msgs:
        details = "\n\t".join(error_msgs)
        raise RuntimeError(
            f"Error(s) in loading state_dict for {model_class_name}:\n\t{details}"
        )

    print(
        f"All model checkpoint weights were used when initializing {model_class_name}.\n"
    )
    print(
        f"All the weights of {model_class_name} were initialized from the model checkpoint at"
        f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint"
        f" was trained on, you can already use {model_class_name} for predictions without further"
        " training."
    )
    return cls
|
||||||
|
|
||||||
|
@torch.inference_mode()
def chat(
    self,
    tokenizer,
    query: str,
    history: List[Tuple[str, str]] = None,
    role: str = "user",
):
    """Generate one reply to ``query`` and record the query in ``history``.

    Args:
        tokenizer: Object providing ``build_chat_input`` and ``decode``.
        query: The new message.
        history: Earlier conversation entries; a fresh list is used when
            ``None``.  The list is extended in place with a
            ``{"role": ..., "content": ...}`` entry for ``query``.
        role: Role recorded for ``query``.

    Returns:
        ``(response, history)`` where ``response`` is the decoded text of
        the newly generated tokens (the final token is dropped).
    """
    if history is None:
        history = []

    model_device = next(self.parameters()).device
    inputs = tokenizer.build_chat_input(query, history=history, role=role)
    inputs = inputs.to(model_device)

    generation_config = copy.deepcopy(self.generation_config)
    prompt_ids = inputs["input_ids"]
    input_ids = prompt_ids.repeat_interleave(
        generation_config.num_return_sequences, dim=0
    )

    generated = self.sample(
        input_ids,
        pad_token_id=generation_config.pad_token_id,
        eos_token_id=generation_config.eos_token_id,
        output_hidden_states=generation_config.output_hidden_states,
        use_cache=generation_config.use_cache,
    )

    # Keep only the first sequence, minus the prompt and the final token.
    prompt_length = len(inputs["input_ids"][0])
    first_sequence = generated.tolist()[0]
    response = tokenizer.decode(first_sequence[prompt_length:-1])

    history.append({"role": role, "content": query})
    return response, history
|
||||||
|
|
||||||
|
def sample(
    self,
    input_ids: torch.LongTensor,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[Union[int, List[int]]] = None,
    output_hidden_states: Optional[bool] = None,
    use_cache: Optional[bool] = None,
    max_length: Optional[int] = None,
):
    """Extend ``input_ids`` one token at a time by multinomial sampling.

    On every step the whole sequence generated so far is passed to
    ``self.transformer`` (``past_key_values`` is always ``None``), the scores
    of the last position are turned into probabilities with a softmax, and
    one token per row is drawn from them.

    Args:
        input_ids: Prompt token ids, indexed as ``(batch, sequence)``.
        pad_token_id: Id written into rows that have already finished.
            Required whenever ``eos_token_id`` is given.
        eos_token_id: One id or a list of ids that end a row.
        output_hidden_states: Forwarded to ``self.transformer`` unchanged.
        use_cache: Forwarded to ``self.transformer``; falls back to
            ``self.config.use_cache`` when ``None``.
        max_length: Optional cap on the total sequence length (prompt plus
            generated tokens). At least one token is always generated.

    Returns:
        The input ids with the sampled tokens appended along the last
        dimension.

    Raises:
        ValueError: If neither ``eos_token_id`` nor ``max_length`` is given
            (the loop could never stop), or if ``eos_token_id`` is given
            without ``pad_token_id``.
    """
    if isinstance(eos_token_id, int):
        eos_token_id = [eos_token_id]
    if eos_token_id is None and max_length is None:
        raise ValueError(
            "Either `eos_token_id` or `max_length` must be set, otherwise sampling never stops."
        )
    if eos_token_id is not None and pad_token_id is None:
        raise ValueError(
            "If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
        )
    # torch.tensor(None) raises, so the tensor is only built when there is
    # at least one end-of-sequence id to compare against.
    eos_token_id_tensor = (
        torch.tensor(eos_token_id).to(input_ids.device)
        if eos_token_id is not None
        else None
    )

    # 1 while a row is still generating, 0 once it has emitted an EOS id.
    unfinished_sequences = torch.ones(
        input_ids.shape[0], dtype=torch.long, device=input_ids.device
    )

    # Resolved once; the value does not change between steps.
    use_cache = use_cache if use_cache is not None else self.config.use_cache

    while True:
        batch_size, seq_length = input_ids.shape
        position_ids = (
            torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
            .unsqueeze(0)
            .repeat(batch_size, 1)
        )

        logits = self.transformer(
            input_ids=input_ids,
            past_key_values=None,
            position_ids=position_ids,
            return_last_logit=True,
            use_cache=use_cache,
            return_dict=True,
            output_hidden_states=output_hidden_states,
        )
        next_token_logits = logits[:, -1, :]
        probs = nn.functional.softmax(next_token_logits, dim=-1)
        next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

        # Finished rows keep receiving the padding token.
        if eos_token_id is not None:
            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
                1 - unfinished_sequences
            )

        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)

        # A row is finished as soon as its new token equals any EOS id.
        if eos_token_id_tensor is not None:
            unfinished_sequences = unfinished_sequences.mul(
                next_tokens.tile(eos_token_id_tensor.shape[0], 1)
                .ne(eos_token_id_tensor.unsqueeze(1))
                .prod(dim=0)
            )
            if unfinished_sequences.max() == 0:
                break

        if max_length is not None and input_ids.shape[-1] >= max_length:
            break

    return input_ids
|
|
@ -0,0 +1,161 @@
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import torch
|
||||||
|
from typing import List, Optional, Union, Dict
|
||||||
|
from sentencepiece import SentencePieceProcessor
|
||||||
|
from transformers import PreTrainedTokenizer
|
||||||
|
|
||||||
|
|
||||||
|
class SPTokenizer:
    """SentencePiece wrapper that adds extra special tokens after the base vocabulary."""

    def __init__(self, model_path: str):
        """Load the SentencePiece model at ``model_path`` and register the special tokens."""
        assert os.path.isfile(model_path), model_path
        self.sp_model = SentencePieceProcessor(model_file=model_path)

        # Ids reported by the SentencePiece model; the pad id reuses the unk id.
        self.n_words: int = self.sp_model.vocab_size()
        self.bos_id: int = self.sp_model.bos_id()
        self.eos_id: int = self.sp_model.eos_id()
        self.pad_id: int = self.sp_model.unk_id()
        assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()

        extra_tokens = [
            "[MASK]",
            "[gMASK]",
            "[sMASK]",
            "sop",
            "eop",
            "<|system|>",
            "<|user|>",
            "<|assistant|>",
            "<|observation|>",
        ]
        # Special tokens take consecutive ids starting right after the
        # SentencePiece vocabulary; both lookup directions are stored.
        self.special_tokens = {}
        self.index_special_tokens = {}
        for new_id, name in enumerate(extra_tokens, start=self.n_words):
            self.special_tokens[name] = new_id
            self.index_special_tokens[new_id] = name
        self.n_words += len(extra_tokens)

    def tokenize(self, s: str):
        """Split ``s`` into SentencePiece pieces."""
        return self.sp_model.EncodeAsPieces(s)

    def encode(self, s: str, bos: bool = False, eos: bool = False) -> List[int]:
        """Encode ``s`` to ids, optionally wrapped in the BOS and/or EOS id."""
        assert type(s) is str
        ids = self.sp_model.encode(s)
        if bos:
            ids = [self.bos_id, *ids]
        if eos:
            ids = [*ids, self.eos_id]
        return ids

    def decode(self, t: List[int]) -> str:
        """Decode ids to text, emitting special tokens verbatim.

        Runs of ordinary ids are decoded together by SentencePiece; each
        special-token id flushes the pending run and contributes its own
        literal text.
        """
        pieces = []
        pending = []
        for token_id in t:
            marker = self.index_special_tokens.get(token_id)
            if marker is None:
                pending.append(token_id)
                continue
            if pending:
                pieces.append(self.sp_model.decode(pending))
                pending = []
            pieces.append(marker)
        if pending:
            pieces.append(self.sp_model.decode(pending))
        return "".join(pieces)

    def decode_tokens(self, tokens: List[str]) -> str:
        """Join SentencePiece pieces back into text."""
        return self.sp_model.DecodePieces(tokens)

    def convert_token_to_id(self, token):
        """Map a token string to its id, checking the special tokens first."""
        special_id = self.special_tokens.get(token)
        if special_id is not None:
            return special_id
        return self.sp_model.PieceToId(token)

    def convert_id_to_token(self, index):
        """Map an id to its token string.

        Special-token ids return their literal text; the BOS, EOS and pad
        ids as well as negative ids return an empty string.
        """
        special = self.index_special_tokens.get(index)
        if special is not None:
            return special
        if index in (self.eos_id, self.bos_id, self.pad_id) or index < 0:
            return ""
        return self.sp_model.IdToPiece(index)
|
||||||
|
|
||||||
|
|
||||||
|
class ChatGLMTokenizer(PreTrainedTokenizer):
    """Tokenizer that exposes an ``SPTokenizer`` through the ``PreTrainedTokenizer`` API."""

    vocab_files_names = {"vocab_file": "tokenizer.model"}

    model_input_names = ["input_ids", "attention_mask", "position_ids"]

    def __init__(self, vocab_file, padding_side="left", clean_up_tokenization_spaces=False, **kwargs):
        """Load the SentencePiece vocabulary from ``vocab_file`` and initialise the base class."""
        self.name = "GLMTokenizer"

        # NOTE(review): these attributes are assigned before super().__init__();
        # presumably the base initializer may call back into vocabulary
        # methods. Keep this order.
        self.vocab_file = vocab_file
        self.tokenizer = SPTokenizer(vocab_file)
        self.special_tokens = {
            "<bos>": self.tokenizer.bos_id,
            "<eos>": self.tokenizer.eos_id,
            "<pad>": self.tokenizer.pad_id
        }
        super().__init__(
            padding_side=padding_side,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            **kwargs
        )

    def get_command(self, token):
        """Return the id of a special token.

        The ``<bos>``/``<eos>``/``<pad>`` table of this class is consulted
        first, then the special tokens of the wrapped ``SPTokenizer``.
        """
        own_table = self.special_tokens
        if token not in own_table:
            sp_table = self.tokenizer.special_tokens
            assert token in sp_table, f"{token} is not a special token for {self.name}"
            return sp_table[token]
        return own_table[token]

    @property
    def unk_token(self) -> str:
        """Literal text of the unknown token."""
        return "<unk>"

    @property
    def pad_token(self) -> str:
        """Literal text of the padding token (same text as the unknown token)."""
        return "<unk>"

    @property
    def pad_token_id(self):
        """Id registered under ``<pad>``."""
        return self.get_command("<pad>")

    @property
    def eos_token(self) -> str:
        """Literal text of the end-of-sequence token."""
        return "</s>"

    @property
    def eos_token_id(self):
        """Id registered under ``<eos>``."""
        return self.get_command("<eos>")

    @property
    def vocab_size(self):
        """Number of ids known to the wrapped tokenizer, special tokens included."""
        return self.tokenizer.n_words

    def get_vocab(self):
        """Return the token-to-id mapping, including added tokens."""
        vocab = {}
        for token_id in range(self.vocab_size):
            vocab[self._convert_id_to_token(token_id)] = token_id
        vocab.update(self.added_tokens_encoder)
        return vocab

    def _tokenize(self, text, **kwargs):
        """Split ``text`` into pieces with the wrapped tokenizer."""
        return self.tokenizer.tokenize(text)

    def _convert_token_to_id(self, token):
        """Look up the id of a token string."""
        return self.tokenizer.convert_token_to_id(token)

    def _convert_id_to_token(self, index):
        """Look up the token string of an id."""
        return self.tokenizer.convert_id_to_token(index)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Join pieces back into text."""
        return self.tokenizer.decode_tokens(tokens)

    def build_single_message(self, role, metadata, message):
        """Encode one chat turn.

        The result is the role's special-token id, followed by the ids of
        ``metadata`` plus a newline, followed by the ids of ``message``.
        """
        assert role in ["system", "user", "assistant", "observation"], role
        header = [self.get_command(f"<|{role}|>")]
        header += self.tokenizer.encode(f"{metadata}\n")
        return header + self.tokenizer.encode(message)

    def build_chat_input(self, query, history=None, role="user"):
        """Encode the history plus the new query into model inputs.

        Every history item is encoded as one message; a system item carrying
        a ``"tools"`` entry gets that entry appended to its content as
        indented JSON. The new query follows, and the sequence ends with the
        ``<|assistant|>`` id so that generation continues as the assistant.
        """
        if history is None:
            history = []
        token_ids = []
        for turn in history:
            text = turn["content"]
            if turn["role"] == "system" and "tools" in turn:
                text = text + "\n" + json.dumps(turn["tools"], indent=4, ensure_ascii=False)
            token_ids += self.build_single_message(turn["role"], turn.get("metadata", ""), text)
        token_ids += self.build_single_message(role, "", query)
        token_ids.append(self.get_command("<|assistant|>"))
        return self.batch_encode_plus([token_ids], return_tensors="pt", is_split_into_words=True)
|
Binary file not shown.
|
@ -0,0 +1,12 @@
|
||||||
|
{
|
||||||
|
"name_or_path": "THUDM/chatglm3-6b",
|
||||||
|
"remove_space": false,
|
||||||
|
"do_lower_case": false,
|
||||||
|
"tokenizer_class": "ChatGLMTokenizer",
|
||||||
|
"auto_map": {
|
||||||
|
"AutoTokenizer": [
|
||||||
|
"tokenization_chatglm.ChatGLMTokenizer",
|
||||||
|
null
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,64 @@
|
||||||
|
import json
|
||||||
|
|
||||||
|
|
||||||
|
from chatglm import ChatGLMForConditionalGeneration
|
||||||
|
from chatglm import ChatGLMTokenizer
|
||||||
|
|
||||||
|
from transformers import AutoConfig
|
||||||
|
|
||||||
|
# Demo: load the configuration, build the tokenizer by hand, load the
# checkpoint and run a two-turn chat.
pretrained_model_name_or_path = "../ZhipuAI/chatglm3-6b"
config, kwargs = AutoConfig.from_pretrained(
    pretrained_model_name_or_path,
    return_unused_kwargs=True,
    trust_remote_code=True,
    code_revision=None,
    _commit_hash=None,
)
glm = ChatGLMForConditionalGeneration(config)

# Tokenizer keyword arguments come from the JSON config; entries that must
# not reach the constructor are removed and local paths are filled in.
tokenizer_config_file = "./chatglm/tokenizer_config.json"
if tokenizer_config_file is not None:
    with open(tokenizer_config_file, encoding="utf-8") as tokenizer_config_handle:
        init_kwargs = json.load(tokenizer_config_handle)
    for dropped_key in ("tokenizer_class", "tokenizer_file"):
        init_kwargs.pop(dropped_key, None)
    saved_init_inputs = init_kwargs.pop("init_inputs", ())
    init_inputs = saved_init_inputs
init_kwargs.update(
    {
        "vocab_file": './chatglm/tokenizer.model',
        "added_tokens_file": None,
        "special_tokens_map_file": None,
        "tokenizer_file": None,
        "name_or_path": pretrained_model_name_or_path,
    }
)
tokenizer = ChatGLMTokenizer(*init_inputs, **init_kwargs)

# Load the weights, switch to half precision on the GPU and to eval mode.
glm = glm.from_pretrained(pretrained_model_name_or_path, config=config).half().cuda()
glm = glm.eval()

# Two turns; the history returned by the first call feeds the second.
response, history = glm.chat(tokenizer, "colin", history=[])
print(response)
response, history = glm.chat(tokenizer, "你好", history=history)
print(response)
|
||||||
|
# response, history = glm.chat(tokenizer, "你是一个心理学专家,请问晚上睡不着应该怎么办", history=history)
|
||||||
|
# print(response)
|
||||||
|
|
||||||
|
|
||||||
|
# import plotly_express as px
|
||||||
|
# px.imshow(ron)
|
||||||
|
# gapminder = px.data.gapminder()
|
||||||
|
# gapminder2007 = gapminder.query('year == 2007')
|
||||||
|
# px.scatter(gapminder2007, x='gdpPercap', y='lifeExp')
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# from modelscope import AutoTokenizer, AutoModel, snapshot_download
|
||||||
|
# model_dir = snapshot_download("ZhipuAI/chatglm3-6b", cache_dir="./chatglm", revision="v1.0.0")
|
||||||
|
# model = AutoModel.from_pretrained(model_dir, trust_remote_code=True).half().cuda()
|
||||||
|
# tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
|
||||||
|
# model = model.eval()
|
||||||
|
# response, history = model.chat(tokenizer, "colin", history=[])
|
||||||
|
# print(response)
|
||||||
|
# response, history = model.chat(tokenizer, "你好", history=history)
|
||||||
|
# print(response)
|
||||||
|
# # response, history = model.chat(tokenizer, "你是一个心理学专家,请问晚上睡不着应该怎么办", history=history)
|
||||||
|
# # print(response)
|
Loading…
Reference in New Issue