Add qwen and refine folders.

Colin 2024-01-03 20:26:26 +08:00
parent 0fa38b7815
commit 3a4e99f7e3
9 changed files with 1576 additions and 43 deletions

View File

@@ -1,19 +1,23 @@
+import sys
+sys.path.append("..")
 import json
 import torch
-from chatglm import ChatGLMForConditionalGeneration
-from chatglm import ChatGLMTokenizer
+from modeling_chatglm import ChatGLMForConditionalGeneration
+from tokenization_chatglm import ChatGLMTokenizer
+from modelscope import snapshot_download
+from transformers import AutoConfig
 from tools import show
-from transformers import AutoConfig
 seed = 4321
 torch.manual_seed(seed)
 torch.cuda.manual_seed_all(seed)
-pretrained_model_name_or_path = "../ZhipuAI/chatglm3-6b"
+pretrained_model_name_or_path = snapshot_download("ZhipuAI/chatglm3-6b")
 config, kwargs = AutoConfig.from_pretrained(
     pretrained_model_name_or_path,
     return_unused_kwargs=True,
@@ -24,7 +28,7 @@ config, kwargs = AutoConfig.from_pretrained(
 glm = ChatGLMForConditionalGeneration(config)
-tokenizer_config_file = "./chatglm/tokenizer_config.json"
+tokenizer_config_file = "./tokenizer_config.json"
 if tokenizer_config_file is not None:
     with open(tokenizer_config_file, encoding="utf-8") as tokenizer_config_handle:
         init_kwargs = json.load(tokenizer_config_handle)
@@ -32,7 +36,7 @@ if tokenizer_config_file is not None:
         init_kwargs.pop("tokenizer_file", None)
         saved_init_inputs = init_kwargs.pop("init_inputs", ())
         init_inputs = saved_init_inputs
-        init_kwargs["vocab_file"] = "./chatglm/tokenizer.model"
+        init_kwargs["vocab_file"] = "./tokenizer.model"
         init_kwargs["added_tokens_file"] = None
         init_kwargs["special_tokens_map_file"] = None
         init_kwargs["tokenizer_file"] = None

View File

@@ -19,14 +19,16 @@ from safetensors.torch import storage_ptr, storage_size
 from transformers.configuration_utils import PretrainedConfig
 from transformers.generation import GenerationConfig
-from chatglm import ChatGLMConfig
+from configuration_chatglm import ChatGLMConfig
 from tools import show
 class RotaryEmbedding(nn.Module):
     def __init__(self, dim: int, original_impl=False, device=None, dtype=None):
         super().__init__()
-        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim))
+        inv_freq = 1.0 / (
+            10000 ** (torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim)
+        )
         self.register_buffer("inv_freq", inv_freq)
         self.dim = dim
         self.original_impl = original_impl
@@ -35,7 +37,13 @@ class RotaryEmbedding(nn.Module):
         dtype = self.inv_freq.dtype
         device = self.inv_freq.device
         # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
-        theta = 1.0 / (base ** (torch.arange(0, self.dim, 2, dtype=torch.float, device=device) / self.dim))
+        theta = 1.0 / (
+            base
+            ** (
+                torch.arange(0, self.dim, 2, dtype=torch.float, device=device)
+                / self.dim
+            )
+        )
         # Create position indexes `[0, 1, ..., max_seq_len - 1]`
         seq_idx = torch.arange(max_seq_len, dtype=torch.float, device=device)
         # Calculate the product of position index and $\theta_i$
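Note: the theta computed in this hunk is the standard RoPE frequency schedule, and the black-style reflow does not change its value. Written out with base 10000 and head dimension $d$ = self.dim:

$\theta_i = 10000^{-\frac{2(i-1)}{d}} = \frac{1}{10000^{2(i-1)/d}}, \quad i \in [1, 2, \ldots, \tfrac{d}{2}]$

and, as the surrounding comments state, the angles cached for each position $m \in [0, \text{max\_seq\_len} - 1]$ are the products $m\,\theta_i$.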
@@ -50,7 +58,9 @@ class RotaryEmbedding(nn.Module):
 class RMSNorm(torch.nn.Module):
     def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
         super().__init__()
-        self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype))
+        self.weight = torch.nn.Parameter(
+            torch.empty(normalized_shape, device=device, dtype=dtype)
+        )
         self.eps = eps
     def forward(self, hidden_states: torch.Tensor):
@@ -70,7 +80,9 @@ class CoreAttention(torch.nn.Module):
         projection_size = config.kv_channels * config.num_attention_heads
         # Per attention head and per partition values.
         self.hidden_size_per_partition = projection_size
-        self.hidden_size_per_attention_head = projection_size // config.num_attention_heads
+        self.hidden_size_per_attention_head = (
+            projection_size // config.num_attention_heads
+        )
         self.num_attention_heads_per_partition = config.num_attention_heads
         coeff = None
@@ -82,13 +94,17 @@
         self.attention_dropout = torch.nn.Dropout(config.attention_dropout)
     def forward(self, query_layer, key_layer, value_layer):
-        query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]]
+        query_layer, key_layer, value_layer = [
+            k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]
+        ]
         if query_layer.shape[2] == key_layer.shape[2]:
             context_layer = torch.nn.functional.scaled_dot_product_attention(
                 query_layer, key_layer, value_layer, is_causal=True
             )
         context_layer = context_layer.permute(2, 0, 1, 3)
-        new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
+        new_context_layer_shape = context_layer.size()[:-2] + (
+            self.hidden_size_per_partition,
+        )
         context_layer = context_layer.reshape(*new_context_layer_shape)
         return context_layer
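Note: for readers new to the fused kernel called above, here is a minimal self-contained sketch (not part of the commit; the tensor shapes are made up) showing that scaled_dot_product_attention(..., is_causal=True) matches the usual softmax(QK^T / sqrt(d)) V with a causal mask:

import math
import torch

# Toy shapes: batch=1, heads=2, seq_len=4, head_dim=8 (illustrative only).
q = torch.randn(1, 2, 4, 8)
k = torch.randn(1, 2, 4, 8)
v = torch.randn(1, 2, 4, 8)

# Fused path, as used in CoreAttention.forward above.
fused = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)

# Reference path: scaled scores, causal mask, softmax, weighted sum of values.
scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
causal_mask = torch.triu(torch.ones(4, 4, dtype=torch.bool), diagonal=1)
scores = scores.masked_fill(causal_mask, float("-inf"))
manual = torch.softmax(scores, dim=-1) @ v

assert torch.allclose(fused, manual, atol=1e-5)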
@@ -98,13 +114,16 @@ class SelfAttention(torch.nn.Module):
         super(SelfAttention, self).__init__()
         self.layer_number = max(1, layer_number)
         self.projection_size = config.kv_channels * config.num_attention_heads
-        self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads
+        self.hidden_size_per_attention_head = (
+            self.projection_size // config.num_attention_heads
+        )
         self.num_attention_heads_per_partition = config.num_attention_heads
         self.multi_query_attention = config.multi_query_attention
         self.qkv_hidden_size = 3 * self.projection_size
         self.num_multi_query_groups_per_partition = config.multi_query_group_num
         self.qkv_hidden_size = (
-            self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num
+            self.projection_size
+            + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num
         )
         self.query_key_value = nn.Linear(
             config.hidden_size,
@@ -144,9 +163,12 @@ class SelfAttention(torch.nn.Module):
         (query_layer, key_layer, value_layer) = mixed_x_layer.split(
             [
-                self.num_attention_heads_per_partition * self.hidden_size_per_attention_head,
-                self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
-                self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
+                self.num_attention_heads_per_partition
+                * self.hidden_size_per_attention_head,
+                self.num_multi_query_groups_per_partition
+                * self.hidden_size_per_attention_head,
+                self.num_multi_query_groups_per_partition
+                * self.hidden_size_per_attention_head,
             ],
             dim=-1,
         )
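Note: a quick worked example of the three split sizes produced above. The config numbers are illustrative assumptions in the style of ChatGLM3-6B (32 attention heads, kv_channels 128, 2 multi-query groups); they are not stated anywhere in this diff:

num_attention_heads = 32       # assumed, not taken from the diff
kv_channels = 128              # assumed, not taken from the diff
multi_query_group_num = 2      # assumed, not taken from the diff

projection_size = kv_channels * num_attention_heads                       # 4096
hidden_size_per_attention_head = projection_size // num_attention_heads   # 128

split_sizes = [
    num_attention_heads * hidden_size_per_attention_head,     # query part: 4096
    multi_query_group_num * hidden_size_per_attention_head,   # key part: 256
    multi_query_group_num * hidden_size_per_attention_head,   # value part: 256
]
qkv_hidden_size = projection_size + 2 * hidden_size_per_attention_head * multi_query_group_num
assert sum(split_sizes) == qkv_hidden_size == 4608
print(split_sizes)  # [4096, 256, 256]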
@@ -182,7 +204,8 @@ class SelfAttention(torch.nn.Module):
             -1,
             -1,
             -1,
-            self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition,
+            self.num_attention_heads_per_partition
+            // self.num_multi_query_groups_per_partition,
             -1,
         )
         key_layer = key_layer.contiguous().view(
@@ -197,7 +220,8 @@ class SelfAttention(torch.nn.Module):
             -1,
             -1,
             -1,
-            self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition,
+            self.num_attention_heads_per_partition
+            // self.num_multi_query_groups_per_partition,
             -1,
         )
         value_layer = value_layer.contiguous().view(
@@ -224,9 +248,11 @@ class MLP(torch.nn.Module):
             device=device,
             dtype=config.torch_dtype,
         )
+
         def swiglu(x):
             x = torch.chunk(x, 2, dim=-1)
             return F.silu(x[0]) * x[1]
+
         self.activation_func = swiglu
         self.dense_4h_to_h = nn.Linear(
             config.ffn_hidden_size,
@@ -254,7 +280,9 @@ class GLMBlock(torch.nn.Module):
         super(GLMBlock, self).__init__()
         self.layer_number = layer_number
-        self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
+        self.apply_residual_connection_post_layernorm = (
+            config.apply_residual_connection_post_layernorm
+        )
         self.fp32_residual_connection = config.fp32_residual_connection
@@ -286,7 +314,9 @@ class GLMBlock(torch.nn.Module):
         attention_output = self.self_attention(layernorm_output, rotary_pos_emb)
         residual = hidden_states
-        layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training)
+        layernorm_input = torch.nn.functional.dropout(
+            attention_output, p=self.hidden_dropout, training=self.training
+        )
         layernorm_input = residual + layernorm_input
         # Layer norm post the self attention.
@@ -297,7 +327,9 @@ class GLMBlock(torch.nn.Module):
         residual = layernorm_input
-        output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training)
+        output = torch.nn.functional.dropout(
+            mlp_output, p=self.hidden_dropout, training=self.training
+        )
         output = residual + output
         return output
@@ -365,7 +397,9 @@ class ChatGLMModel(nn.Module):
         # Rotary positional embeddings
         self.seq_length = config.seq_length
         rotary_dim = (
-            config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
+            config.hidden_size // config.num_attention_heads
+            if config.kv_channels is None
+            else config.kv_channels
         )
         self.rotary_pos_emb = RotaryEmbedding(
@@ -392,7 +426,9 @@ class ChatGLMModel(nn.Module):
         tokenizer=None,
     ):
         output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+            output_hidden_states
+            if output_hidden_states is not None
+            else self.config.output_hidden_states
         )
         inputs_embeds = self.embedding(input_ids)
@@ -410,7 +446,7 @@ class ChatGLMModel(nn.Module):
         probs = nn.functional.softmax(next_token_logits, dim=-1)
         next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
-        return next_tokens
+        return probs, next_tokens
 class ChatGLMForConditionalGeneration(nn.Module):
@@ -427,21 +463,26 @@ class ChatGLMForConditionalGeneration(nn.Module):
         self.warnings_issued = {}
         self.generation_config = GenerationConfig.from_model_config(config)
-    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]]):
+    def from_pretrained(
+        cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]]
+    ):
         load_in_8bit = False
         load_in_4bit = False
         pretrained_model_name_or_path = str(pretrained_model_name_or_path)
-        resolved_archive_file = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin.index.json")
+        resolved_archive_file = os.path.join(
+            pretrained_model_name_or_path, "pytorch_model.bin.index.json"
+        )
         print(f"loading weights file {resolved_archive_file}")
         with open(resolved_archive_file, "r") as f:
             index = json.loads(f.read())
         shard_filenames = sorted(set(index["weight_map"].values()))
-        resolved_archive_file = [os.path.join(pretrained_model_name_or_path, f) for f in shard_filenames]
+        resolved_archive_file = [
+            os.path.join(pretrained_model_name_or_path, f) for f in shard_filenames
+        ]
         model = cls._load_pretrained_model(resolved_archive_file)
         model.is_loaded_in_4bit = load_in_4bit
         model.is_loaded_in_8bit = load_in_8bit
-        model.eval()  # Set model in evaluation mode to deactivate DropOut modules by default
         return model
     def _load_state_dict_into_model(self, model_to_load, state_dict, start_prefix):
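Note: a standalone sketch of what the pytorch_model.bin.index.json read above contains, for readers who have not seen sharded checkpoints. The directory path is a hypothetical example matching the demos in this commit; the key names follow the standard Hugging Face sharding format:

import json
import os

checkpoint_dir = "../ZhipuAI/chatglm3-6b"  # hypothetical path, adjust to your layout
index_file = os.path.join(checkpoint_dir, "pytorch_model.bin.index.json")

with open(index_file, "r") as f:
    index = json.load(f)

# "weight_map" maps every parameter name to the shard file that stores it;
# the loader above collects the unique shard names and loads them one by one.
weight_map = index["weight_map"]
shards = sorted(set(weight_map.values()))
print(f"{len(weight_map)} tensors stored across {len(shards)} shard files")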
@@ -470,15 +511,21 @@ class ChatGLMForConditionalGeneration(nn.Module):
         model_to_load = cls
         error_msgs = []
         if len(resolved_archive_file) > 1:
-            resolved_archive_file = tqdm_lib.tqdm(resolved_archive_file, desc="Loading checkpoint shards")
+            resolved_archive_file = tqdm_lib.tqdm(
+                resolved_archive_file, desc="Loading checkpoint shards"
+            )
         for shard_file in resolved_archive_file:
             state_dict = torch.load(shard_file, map_location="cpu")
-            error_msgs += cls._load_state_dict_into_model(model_to_load, state_dict, start_prefix)
+            error_msgs += cls._load_state_dict_into_model(
+                model_to_load, state_dict, start_prefix
+            )
             del state_dict  # force memory release
             gc.collect()
-        print(f"All model checkpoint weights were used when initializing {cls.__class__.__name__}.\n")
+        print(
+            f"All model checkpoint weights were used when initializing {cls.__class__.__name__}.\n"
+        )
         return cls
     @torch.inference_mode()
@@ -496,7 +543,9 @@ class ChatGLMForConditionalGeneration(nn.Module):
         generation_config = copy.deepcopy(self.generation_config)
         inputs_tensor = inputs["input_ids"]
-        input_ids = inputs_tensor.repeat_interleave(generation_config.num_return_sequences, dim=0)
+        input_ids = inputs_tensor.repeat_interleave(
+            generation_config.num_return_sequences, dim=0
+        )
         outputs = self.sample(
             input_ids,
@@ -523,17 +572,21 @@
             eos_token_id = [eos_token_id]
         eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device)
-        isFinished = torch.zeros(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
+        isFinished = torch.zeros(
+            input_ids.shape[0], dtype=torch.long, device=input_ids.device
+        )
         # token_count = 0
         while True:
             input_ids_in = input_ids
             batch_size, seq_length = input_ids_in.shape
             position_ids_in = (
-                torch.arange(seq_length, dtype=torch.long, device=input_ids.device).unsqueeze(0).repeat(batch_size, 1)
+                torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
+                .unsqueeze(0)
+                .repeat(batch_size, 1)
             )
             model_inputs = {"input_ids": input_ids_in, "position_ids": position_ids_in}
-            next_tokens = self.transformer(
+            probs, next_tokens = self.transformer(
                 **model_inputs,
                 output_hidden_states=output_hidden_states,
                 tokenizer=tokenizer,
@@ -549,3 +602,41 @@ class ChatGLMForConditionalGeneration(nn.Module):
             input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
         return input_ids
+
+    def backward(
+        self,
+        tokenizer,
+        query: str,
+    ):
+        inputs = tokenizer.build_chat_input(query, history=[], role="user")
+        inputs = inputs.to(next(self.parameters()).device)
+        generation_config = copy.deepcopy(self.generation_config)
+        inputs_tensor = inputs["input_ids"]
+        input_ids = inputs_tensor.repeat_interleave(
+            generation_config.num_return_sequences, dim=0
+        )
+        input_ids_in = input_ids
+        batch_size, seq_length = input_ids_in.shape
+        position_ids_in = (
+            torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
+            .unsqueeze(0)
+            .repeat(batch_size, 1)
+        )
+        model_inputs = {"input_ids": input_ids_in, "position_ids": position_ids_in}
+        probs, next_tokens = self.transformer(
+            **model_inputs,
+            output_hidden_states=None,
+            tokenizer=tokenizer,
+        )
+        next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+        # probs_target = probs
+        # probs_target[0, next_tokens] = probs_target[0, next_tokens] * 1.1
+        loss = probs[0, next_tokens]
+        loss.backward()
+        return loss

View File

@@ -1,13 +1,18 @@
+import sys
+sys.path.append("..")
 import json
 import torch
 from tools import show
+from chatglm import ChatGLMTokenizer
+from modelscope import snapshot_download
-pretrained_model_name_or_path = "../ZhipuAI/chatglm3-6b"
+pretrained_model_name_or_path = snapshot_download("ZhipuAI/chatglm3-6b")
-tokenizer_config_file = "./chatglm/tokenizer_config.json"
+tokenizer_config_file = "./tokenizer_config.json"
 if tokenizer_config_file is not None:
     with open(tokenizer_config_file, encoding="utf-8") as tokenizer_config_handle:
         init_kwargs = json.load(tokenizer_config_handle)
@@ -15,7 +20,7 @@ if tokenizer_config_file is not None:
         init_kwargs.pop("tokenizer_file", None)
         saved_init_inputs = init_kwargs.pop("init_inputs", ())
         init_inputs = saved_init_inputs
-        init_kwargs["vocab_file"] = "./chatglm/tokenizer.model"
+        init_kwargs["vocab_file"] = "./tokenizer.model"
         init_kwargs["added_tokens_file"] = None
         init_kwargs["special_tokens_map_file"] = None
         init_kwargs["tokenizer_file"] = None
@@ -30,7 +35,7 @@ b = tokenizer.decode([236, 173, 140])
 token = []
 for i in range(64798):
     token.append(str(i) + " : " + tokenizer.decode(i))
-show.DumpListToFile(token, "generated/token.log")
+show.DumpListToFile(token, "../generated/token.log")
 # print("=======================")
 # for i in range(hidden_states_en.shape[0]):

qwen/demo.py (new file, 25 additions)

@@ -0,0 +1,25 @@
from modelscope import snapshot_download
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
model_dir = snapshot_download("qwen/Qwen-1_8B-Chat")
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_dir, device_map="auto", trust_remote_code=True
).eval()
# Different generation lengths, top_p, and other related hyperparameters can be set here.
model.generation_config = GenerationConfig.from_pretrained(
    model_dir, trust_remote_code=True
)
# First round of dialogue
response, history = model.chat(tokenizer, "你好", history=None)
print(response)
# 你好!很高兴为你提供帮助。 (Hello! I'm glad to help you.)
# Second round of dialogue
response, history = model.chat(tokenizer, "给我讲一个年轻人奋斗创业最终取得成功的故事。", history=history)
print(response)

qwen/modeling_qwen.py (new file, 1363 additions; diff not shown because the file is too large)

train.py (new file, 45 additions)

@@ -0,0 +1,45 @@
import json
import torch
from chatglm import ChatGLMForConditionalGeneration
from chatglm import ChatGLMTokenizer
from tools import show
from transformers import AutoConfig
seed = 4321
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
pretrained_model_name_or_path = "../ZhipuAI/chatglm3-6b"
config, kwargs = AutoConfig.from_pretrained(
    pretrained_model_name_or_path,
    return_unused_kwargs=True,
    trust_remote_code=True,
    code_revision=None,
    _commit_hash=None,
)
glm = ChatGLMForConditionalGeneration(config)
tokenizer_config_file = "./chatglm/tokenizer_config.json"
if tokenizer_config_file is not None:
    with open(tokenizer_config_file, encoding="utf-8") as tokenizer_config_handle:
        init_kwargs = json.load(tokenizer_config_handle)
        init_kwargs.pop("tokenizer_class", None)
        init_kwargs.pop("tokenizer_file", None)
        saved_init_inputs = init_kwargs.pop("init_inputs", ())
        init_inputs = saved_init_inputs
        init_kwargs["vocab_file"] = "./chatglm/tokenizer.model"
        init_kwargs["added_tokens_file"] = None
        init_kwargs["special_tokens_map_file"] = None
        init_kwargs["tokenizer_file"] = None
        init_kwargs["name_or_path"] = pretrained_model_name_or_path
tokenizer = ChatGLMTokenizer(*init_inputs, **init_kwargs)
glm = glm.from_pretrained(pretrained_model_name_or_path).half().cuda()
query = "你好"
response = glm.backward(tokenizer, query)
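A possible sanity check after the glm.backward(tokenizer, query) call above; this is a sketch, not part of the commit, and relies only on glm being an nn.Module as defined in the modeling code:

# Count how many parameters received a gradient from the single backward pass.
grad_stats = {
    name: param.grad.abs().mean().item()
    for name, param in glm.named_parameters()
    if param.grad is not None
}
print(f"{len(grad_stats)} parameters received gradients")
if grad_stats:
    # Peek at one entry, e.g. the first in iteration order.
    name, mean_abs_grad = next(iter(grad_stats.items()))
    print(name, mean_abs_grad)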