Update code.
parent 185caa12e9
commit 10268c4414
@@ -0,0 +1,23 @@
## data flow

input_ids = tokenizer.build_chat_input(query, history=history, role=role)

input_ids      -> [1, 6]
inputs_embeds  -> [6, 1, 4096]   # 4096: hidden_size
rotary_pos_emb -> [6, 1, 32, 2]  # 32: positional-encoding dims, 2: cos + sin

hidden_states = inputs_embeds
for layer in layers: hidden_states = GLMBlock(hidden_states, rotary_pos_emb)
hidden_states = self.final_layernorm(hidden_states)
hidden_states = hidden_states[-1:]
lm_logits = self.output_layer(hidden_states)
lm_logits = lm_logits.transpose(0, 1).contiguous()       -> [1, 1, 65024]

probs = softmax(lm_logits)                                -> [1, 65024]
next_tokens = torch.multinomial(probs, num_samples=1)     # sampling -> [1]
input_ids = torch.cat([input_ids, next_tokens])           -> [1, 7]

response = tokenizer.decode(outputs)
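Putting the steps above together, the decode loop looks roughly like the sketch below. This is a minimal illustration of the described data flow, not the repository's code: `sample_reply`, `max_new_tokens`, and the eos check are assumed stand-ins, and the model is assumed to return logits only for the last position (shape [1, 1, 65024]).

```python
import torch

# Minimal sketch of the decode loop described above -- illustrative only.
# `tokenizer` and `model` are assumed stand-ins for an already-loaded ChatGLM3
# tokenizer and model whose forward returns logits for the last position.
@torch.no_grad()
def sample_reply(model, tokenizer, query, history=None, role="user", max_new_tokens=128):
    inputs = tokenizer.build_chat_input(query, history=history or [], role=role)
    input_ids = inputs["input_ids"]                                 # e.g. [1, 6]
    prompt_len = input_ids.shape[1]

    for _ in range(max_new_tokens):
        lm_logits = model(input_ids=input_ids)                      # assumed [1, 1, vocab], e.g. [1, 1, 65024]
        probs = torch.softmax(lm_logits[:, -1, :].float(), dim=-1)  # [1, 65024]
        next_tokens = torch.multinomial(probs, num_samples=1)       # sampling -> [1, 1]
        input_ids = torch.cat([input_ids, next_tokens], dim=-1)     # [1, 7], [1, 8], ...
        if next_tokens.item() == tokenizer.eos_token_id:
            break

    outputs = input_ids.tolist()[0][prompt_len:-1]                  # drop the prompt and trailing eos, as above
    return tokenizer.decode(outputs)
```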
@@ -170,7 +170,7 @@ class SelfAttention(torch.nn.Module):
         x_out2 = x_out2.flatten(3)
         return torch.cat((x_out2, x_pass), dim=-1)

-    def forward(self, hidden_states, rotary_pos_emb, kv_cache=None):
+    def forward(self, hidden_states, rotary_pos_emb):
         # hidden_states: [sq, b, h]
         # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
         mixed_x_layer = self.query_key_value(hidden_states)
@@ -213,8 +213,6 @@ class SelfAttention(torch.nn.Module):
         query_layer = self.apply_rotary_pos_emb(query_layer, rotary_pos_emb)
         key_layer = self.apply_rotary_pos_emb(key_layer, rotary_pos_emb)

-        kv_cache = (key_layer, value_layer)
-
         key_layer = key_layer.unsqueeze(-2)
         key_layer = key_layer.expand(
             -1,
@@ -255,7 +253,7 @@ class SelfAttention(torch.nn.Module):
         # Output. [sq, b, h]
         # =================
         output = self.dense(context_layer)
-        return output, kv_cache
+        return output


 class MLP(torch.nn.Module):
@@ -342,14 +340,12 @@ class GLMBlock(torch.nn.Module):
         # MLP
         self.mlp = MLP(config, device=device)

-    def forward(self, hidden_states, rotary_pos_emb, kv_cache=None):
+    def forward(self, hidden_states, rotary_pos_emb):
         # hidden_states: [s, b, h]
         # Layer norm at the beginning of the transformer layer.
         layernorm_output = self.input_layernorm(hidden_states)
         # Self attention.
-        attention_output, kv_cache = self.self_attention(
-            layernorm_output, rotary_pos_emb, kv_cache=kv_cache
-        )
+        attention_output = self.self_attention(layernorm_output, rotary_pos_emb)
         residual = hidden_states

         layernorm_input = torch.nn.functional.dropout(
@@ -369,7 +365,7 @@ class GLMBlock(torch.nn.Module):
             mlp_output, p=self.hidden_dropout, training=self.training
         )
         output = residual + output
-        return output, kv_cache
+        return output


 class GLMTransformer(torch.nn.Module):
@@ -389,18 +385,10 @@ class GLMTransformer(torch.nn.Module):
             dtype=config.torch_dtype,
         )

-    def forward(
-        self,
-        hidden_states,
-        rotary_pos_emb
-    ):
-        kv_caches = [None for _ in range(self.num_layers)]
-
+    def forward(self, hidden_states, rotary_pos_emb):
         for index in range(self.num_layers):
             layer = self.layers[index]
-            hidden_states, kv_cache = layer(
-                hidden_states, rotary_pos_emb, kv_cache=kv_caches[index]
-            )
+            hidden_states = layer(hidden_states, rotary_pos_emb)
         hidden_states = self.final_layernorm(hidden_states)
         return hidden_states

@@ -469,27 +457,20 @@ class ChatGLMModel(nn.Module):
         input_ids,
         position_ids: Optional[torch.Tensor] = None,
         output_hidden_states: Optional[bool] = None,
-        return_last_logit: Optional[bool] = False,
     ):
         output_hidden_states = (
             output_hidden_states
             if output_hidden_states is not None
             else self.config.output_hidden_states
         )
-        batch_size, seq_length = input_ids.shape
         inputs_embeds = self.embedding(input_ids)

-        # Rotary positional embeddings
         rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
         # show.DumpTensorToImage(rotary_pos_emb[:, :, 0], "rotary_pos_emb.png", scale=0.1)

         rotary_pos_emb = rotary_pos_emb[position_ids]
         rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
-        hidden_states = self.encoder(
-            inputs_embeds,
-            rotary_pos_emb=rotary_pos_emb
-        )
-        if return_last_logit:
+        hidden_states = self.encoder(inputs_embeds, rotary_pos_emb)
         hidden_states = hidden_states[-1:]
         lm_logits = self.output_layer(hidden_states)
         lm_logits = lm_logits.transpose(0, 1).contiguous()
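For context on the rotary_pos_emb tensor handled in the hunk above (and on the [6, 1, 32, 2] shape in the data-flow notes), a cache of that form can be built roughly as in the sketch below. This is a hedged illustration, not the repository's RotaryEmbedding: `build_rotary_cache`, the base of 10000, and `rotary_dim = 64` (half of the 128-dim attention head) are assumptions for illustration.

```python
import torch

def build_rotary_cache(seq_len: int, rotary_dim: int = 64, base: float = 10000.0) -> torch.Tensor:
    # Frequencies for rotary_dim // 2 = 32 coordinate pairs (assumed values, see lead-in).
    theta = 1.0 / (base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim))
    positions = torch.arange(seq_len, dtype=torch.float32)
    angles = torch.outer(positions, theta)                   # [seq_len, 32]
    # Stack cos and sin along the last dim -> [seq_len, 32, 2]
    return torch.stack([torch.cos(angles), torch.sin(angles)], dim=-1)

cache = build_rotary_cache(6)
print(cache.shape)                                           # torch.Size([6, 32, 2])

# Indexing with position_ids of shape [1, 6] and transposing, as in the hunk above,
# yields the [6, 1, 32, 2] tensor from the data-flow notes.
position_ids = torch.arange(6).unsqueeze(0)                  # [1, 6]
rotary_pos_emb = cache[position_ids].transpose(0, 1)         # [6, 1, 32, 2]
print(rotary_pos_emb.shape)
```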
@@ -676,7 +657,7 @@ class ChatGLMForConditionalGeneration(nn.Module):
             input_ids,
             pad_token_id=generation_config.pad_token_id,
             eos_token_id=generation_config.eos_token_id,
-            output_hidden_states=generation_config.output_hidden_states
+            output_hidden_states=generation_config.output_hidden_states,
         )

         outputs = outputs.tolist()[0][len(inputs["input_ids"][0]) : -1]
@@ -689,7 +670,7 @@ class ChatGLMForConditionalGeneration(nn.Module):
         input_ids: torch.LongTensor,
         pad_token_id: Optional[int] = None,
         eos_token_id: Optional[Union[int, List[int]]] = None,
-        output_hidden_states: Optional[bool] = None
+        output_hidden_states: Optional[bool] = None,
     ):
         if isinstance(eos_token_id, int):
             eos_token_id = [eos_token_id]
@@ -708,11 +689,7 @@ class ChatGLMForConditionalGeneration(nn.Module):
             .unsqueeze(0)
             .repeat(batch_size, 1)
         )
-        model_inputs = {
-            "input_ids": input_ids_in,
-            "position_ids": position_ids_in,
-            "return_last_logit": True
-        }
+        model_inputs = {"input_ids": input_ids_in, "position_ids": position_ids_in}

         logits = self.transformer(
             **model_inputs,
@@ -723,7 +700,6 @@ class ChatGLMForConditionalGeneration(nn.Module):
             next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

             # finished sentences should have their next token be a padding token
-            if eos_token_id is not None:
             next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
                 1 - unfinished_sequences
             )
@@ -732,15 +708,12 @@ class ChatGLMForConditionalGeneration(nn.Module):
             input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)

             # if eos_token was found in one sentence, set sentence to finished
-            if eos_token_id_tensor is not None:
             unfinished_sequences = unfinished_sequences.mul(
                 next_tokens.tile(eos_token_id_tensor.shape[0], 1)
                 .ne(eos_token_id_tensor.unsqueeze(1))
                 .prod(dim=0)
             )
             if unfinished_sequences.max() == 0:
-                this_peer_finished = True
-            if this_peer_finished:
                 break

         return input_ids
demo.py (9 changed lines)
@@ -25,7 +25,7 @@ if tokenizer_config_file is not None:
     init_kwargs.pop("tokenizer_file", None)
     saved_init_inputs = init_kwargs.pop("init_inputs", ())
     init_inputs = saved_init_inputs
-    init_kwargs["vocab_file"] = './chatglm/tokenizer.model'
+    init_kwargs["vocab_file"] = "./chatglm/tokenizer.model"
     init_kwargs["added_tokens_file"] = None
     init_kwargs["special_tokens_map_file"] = None
     init_kwargs["tokenizer_file"] = None
@@ -35,9 +35,11 @@ tokenizer = ChatGLMTokenizer(*init_inputs, **init_kwargs)

 glm = glm.from_pretrained(pretrained_model_name_or_path, config=config).half().cuda()
 glm = glm.eval()
-response, history = glm.chat(tokenizer, "colin", history=[])
+query = "colin"
+response, history = glm.chat(tokenizer, query, history=[])
 print(response)
-response, history = glm.chat(tokenizer, "你好", history=history)
+query = "你好"
+response, history = glm.chat(tokenizer, query, history=history)
 print(response)
 # response, history = glm.chat(tokenizer, "你是一个心理学专家,请问晚上睡不着应该怎么办", history=history)
 # print(response)
@@ -50,7 +52,6 @@ print(response)
 # px.scatter(gapminder2007, x='gdpPercap', y='lifeExp')


-
 # from modelscope import AutoTokenizer, AutoModel, snapshot_download
 # model_dir = snapshot_download("ZhipuAI/chatglm3-6b", cache_dir="./chatglm", revision="v1.0.0")
 # model = AutoModel.from_pretrained(model_dir, trust_remote_code=True).half().cuda()
@@ -0,0 +1,28 @@
import torch
import torch.nn as nn

# define the vocabulary size and the embedding dimension
vocab_size = 10000
embedding_dim = 16

# define an Embedding layer
embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)

# define an input tensor of shape (batch_size, sequence_length)
input_tensor = torch.LongTensor([[1, 2], [4, 3]])

# pass the input tensor through the Embedding layer
embedded_tensor = embedding(input_tensor)

print("embedded weight shape:")
print(embedding.weight.shape)
print("embedded weight:")
print(embedding.weight)

# the output shape is (batch_size, sequence_length, embedding_dim)
print("embedded out shape:")
print(embedded_tensor.shape)
print("embedded out:")
print(embedded_tensor)
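# For the values above: embedding.weight.shape is torch.Size([10000, 16]), one 16-dim row per vocab id,
# and embedded_tensor.shape is torch.Size([2, 2, 16]), i.e. (batch_size, sequence_length, embedding_dim).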