Update code.

This commit is contained in:
Colin 2023-12-25 16:22:45 +08:00
parent ebe48f8efc
commit fa7078b72d
4 changed files with 99 additions and 22 deletions

View File

@ -33,3 +33,64 @@ variance = hidden_states.pow(2).mean(-1, keepdim=True) -> [6, 1, 1]
hidden_states = hidden_states * torch.rsqrt(variance + self.eps) 平方根倒数 hidden_states = hidden_states * torch.rsqrt(variance + self.eps) 平方根倒数
self.weight -> [4096] self.weight -> [4096]
return (self.weight * hidden_states) -> [6, 1, 4096] return (self.weight * hidden_states) -> [6, 1, 4096]
## MLP
Linear(hidden_states) no bias -> [6, 1, 27392]
silu (x) = [6, 1, 13696] * sigmoid([6, 1, 13696])
Linear(intermediate_parallel) no bias -> [6, 1, 4096]
## core_attention
query_layer=query_layer.permute(1, 2, 0, 3) -> [1, 32, 6, 128]
key_layer=key_layer.permute(1, 2, 0, 3) -> [1, 32, 6, 128]
value_layer=value_layer.permute(1, 2, 0, 3) -> [1, 32, 6, 128]
context_layer = scaled_dot_product_attention(query_layer, key_layer, value_layer) -> [1, 32, 6, 128]
softmax(QK^T/sqrt(in_dim))V
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
att = F.softmax(att, dim=-1)
y = att @ v -> (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
context_layer = context_layer.permute(2, 0, 1, 3)
context_layer = context_layer.reshape() -> [6, 1, 4096]
## self_attention
hidden_states: [s, b, h]
mixed_x_layer = Linear(hidden_states) -> [6, 1, 4608] 4608:4096+256+256
(query_layer, key_layer, value_layer) = mixed_x_layer.split -> [6, 1, 4096], [6, 1, 256], [6, 1, 256]
query_layer = query_layer.view -> [6, 1, 32, 128]
key_layer = key_layer.view -> [6, 1, 2, 128]
value_layer = value_layer.view -> [6, 1, 2, 128]
query_layer = self.apply_rotary_pos_emb(query_layer, rotary_pos_emb)
key_layer = self.apply_rotary_pos_emb(key_layer, rotary_pos_emb)
key_layer = key_layer.unsqueeze(-2) -> [6, 1, 2, 1, 128]
key_layer = key_layer.expand -> [6, 1, 2, 16, 128]
key_layer = key_layer.contiguous().view -> [6, 1, 32, 128]
value_layer = value_layer.unsqueeze(-2) -> [6, 1, 2, 1, 128]
value_layer = value_layer.expand -> [6, 1, 2, 16, 128]
value_layer = value_layer.contiguous().view -> [6, 1, 32, 128]
context_layer = self.core_attention(query_layer, key_layer, value_layer) -> [6, 1, 4096]
return Linear(context_layer) -> [6, 1, 4096]
## GLMBlock
input
| \
| RMSNorm
| self_attention
| dropout
| /
Add
| \
| RMSNorm
| mlp
| dropout
| /
Add
所有的输出shape都是[6, 1, 4096], 6:sequence_length 1:batch_num 4096:hidden_size

View File

@ -4,6 +4,7 @@ import copy
import os import os
import gc import gc
import json import json
import hashlib
import torch import torch
import torch.utils.checkpoint import torch.utils.checkpoint
@ -148,28 +149,20 @@ class SelfAttention(torch.nn.Module):
dtype=config.torch_dtype, dtype=config.torch_dtype,
) )
def apply_rotary_pos_emb( def apply_rotary_pos_emb(self, x: torch.Tensor, rope: torch.Tensor) -> torch.Tensor:
self, x: torch.Tensor, rope_cache: torch.Tensor
) -> torch.Tensor:
# x: [sq, b, np, hn] # x: [sq, b, np, hn]
sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3) sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3)
rot_dim = rope_cache.shape[-2] * 2 if rope.size(0) != sq:
x, x_pass = x[..., :rot_dim], x[..., rot_dim:] raise ("Error rotary_pos_emb size")
# truncate to support variable sizes x_rope = x[..., : hn // 2]
rope_cache = rope_cache[:sq] x_pass = x[..., hn // 2 :]
xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2) x_rope = x_rope.reshape(sq, -1, np, hn // 4, 1, 2)
rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2) rope = rope.view(sq, -1, 1, hn // 4, 1, 2)
x_out2 = torch.stack( roped1 = x_rope[..., 0] * rope[..., 0] - x_rope[..., 1] * rope[..., 1]
[ roped2 = x_rope[..., 1] * rope[..., 0] + x_rope[..., 0] * rope[..., 1]
xshaped[..., 0] * rope_cache[..., 0] x_out = torch.cat((roped1, roped2), -1)
- xshaped[..., 1] * rope_cache[..., 1], x_out = x_out.flatten(3)
xshaped[..., 1] * rope_cache[..., 0] return torch.cat((x_out, x_pass), dim=-1)
+ xshaped[..., 0] * rope_cache[..., 1],
],
-1,
)
x_out2 = x_out2.flatten(3)
return torch.cat((x_out2, x_pass), dim=-1)
def forward(self, hidden_states, rotary_pos_emb): def forward(self, hidden_states, rotary_pos_emb):
# hidden_states: [sq, b, h] # hidden_states: [sq, b, h]

12
demo.py
View File

@ -1,11 +1,15 @@
import json import json
import torch
from chatglm import ChatGLMForConditionalGeneration from chatglm import ChatGLMForConditionalGeneration
from chatglm import ChatGLMTokenizer from chatglm import ChatGLMTokenizer
from transformers import AutoConfig from transformers import AutoConfig
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
pretrained_model_name_or_path = "../ZhipuAI/chatglm3-6b" pretrained_model_name_or_path = "../ZhipuAI/chatglm3-6b"
config, kwargs = AutoConfig.from_pretrained( config, kwargs = AutoConfig.from_pretrained(
pretrained_model_name_or_path, pretrained_model_name_or_path,
@ -38,9 +42,15 @@ glm = glm.eval()
query = "colin" query = "colin"
response, history = glm.chat(tokenizer, query, history=[]) response, history = glm.chat(tokenizer, query, history=[])
print(response) print(response)
if response[1:] != " Hello! How can I assist you today":
raise ()
query = "你好" query = "你好"
response, history = glm.chat(tokenizer, query, history=history) response, history = glm.chat(tokenizer, query, history=history)
print(response) print(response)
if response[1:] != " 你好!有什么我可以帮助你的吗":
raise ()
# response, history = glm.chat(tokenizer, "你是一个心理学专家,请问晚上睡不着应该怎么办", history=history) # response, history = glm.chat(tokenizer, "你是一个心理学专家,请问晚上睡不着应该怎么办", history=history)
# print(response) # print(response)

View File

@ -30,7 +30,20 @@ print()
print(x.unsqueeze(1).shape) print(x.unsqueeze(1).shape)
print(x.unsqueeze(1).squeeze(1).shape) print(x.unsqueeze(1).squeeze(1).shape)
x = torch.tensor([[1, 2], [3, 4]]).to(float) x = torch.tensor([[1, 2], [3, 4]]).to(float)
print(x.mean(1)) print(x.mean(1))
print(x.mean(0)) print(x.mean(0))
print(x.mean(0, keepdim=True)) print(x.mean(0, keepdim=True))
print()
print()
x = torch.tensor([[1, 2], [3, 4]])
print(x.flatten(0))
x = torch.tensor([[1, 2], [3, 4]])
print(torch.stack((x, x), 1))
print(torch.cat((x, x), 1))
# So if A and B are of shape (3, 4):
# torch.cat([A, B], dim=0) will be of shape (6, 4)
# torch.stack([A, B], dim=0) will be of shape (2, 3, 4)