Update code.
parent: ebe48f8efc
commit: fa7078b72d

Readme.md (61 lines changed)

@@ -33,3 +33,64 @@ variance = hidden_states.pow(2).mean(-1, keepdim=True) -> [6, 1, 1]
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)   (reciprocal of the square root)
self.weight -> [4096]
return (self.weight * hidden_states) -> [6, 1, 4096]
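
A minimal sketch of this RMSNorm step, matching the shape trace above; hidden_size = 4096 comes from the notes, while the class name and eps value are illustrative assumptions:

```python
import torch


class RMSNormSketch(torch.nn.Module):
    def __init__(self, hidden_size: int = 4096, eps: float = 1e-5):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))  # [4096]
        self.eps = eps  # illustrative value

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # [6, 1, 4096] -> variance [6, 1, 1]
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        # scale by the reciprocal square root of the variance
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
        return self.weight * hidden_states  # [6, 1, 4096]


print(RMSNormSketch()(torch.randn(6, 1, 4096)).shape)  # torch.Size([6, 1, 4096])
```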

## MLP

Linear(hidden_states) no bias -> [6, 1, 27392]   (chunked into two [6, 1, 13696] halves for the gated activation)
silu(x) = [6, 1, 13696] * sigmoid([6, 1, 13696])
Linear(intermediate_parallel) no bias -> [6, 1, 4096]
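
A sketch of this MLP, assuming the [6, 1, 27392] projection is split into two [6, 1, 13696] halves for a SwiGLU-style gate (silu(a) * b); the layer names here are illustrative:

```python
import torch
import torch.nn.functional as F


class MLPSketch(torch.nn.Module):
    def __init__(self, hidden_size: int = 4096, ffn_hidden_size: int = 13696):
        super().__init__()
        # a single projection produces both gate branches: 2 * 13696 = 27392
        self.up = torch.nn.Linear(hidden_size, 2 * ffn_hidden_size, bias=False)
        self.down = torch.nn.Linear(ffn_hidden_size, hidden_size, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        x = self.up(hidden_states)                # [6, 1, 27392]
        a, b = torch.chunk(x, 2, dim=-1)          # two [6, 1, 13696] halves
        intermediate_parallel = F.silu(a) * b     # silu(a) = a * sigmoid(a)
        return self.down(intermediate_parallel)   # [6, 1, 4096]


print(MLPSketch()(torch.randn(6, 1, 4096)).shape)  # torch.Size([6, 1, 4096])
```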

## core_attention

query_layer = query_layer.permute(1, 2, 0, 3) -> [1, 32, 6, 128]
key_layer = key_layer.permute(1, 2, 0, 3) -> [1, 32, 6, 128]
value_layer = value_layer.permute(1, 2, 0, 3) -> [1, 32, 6, 128]
context_layer = scaled_dot_product_attention(query_layer, key_layer, value_layer) -> [1, 32, 6, 128]
softmax(QK^T / sqrt(d_k)) V, with d_k = 128 (head dim)
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
att = F.softmax(att, dim=-1)
y = att @ v -> (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
context_layer = context_layer.permute(2, 0, 1, 3)
context_layer = context_layer.reshape() -> [6, 1, 4096]
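
A small sketch comparing the fused kernel with the manual formula above (no mask or dropout; whether a causal mask is applied at this point is not shown in the trace):

```python
import math

import torch
import torch.nn.functional as F

# after the permutes: [b, nh, T, hs] = [1, 32, 6, 128]
q, k, v = (torch.randn(1, 32, 6, 128) for _ in range(3))

att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))  # [1, 32, 6, 6]
att = F.softmax(att, dim=-1)
y = att @ v                                                      # [1, 32, 6, 128]

# the fused kernel computes the same softmax(QK^T / sqrt(d_k)) V
print(torch.allclose(y, F.scaled_dot_product_attention(q, k, v), atol=1e-5))

# back to [sq, b, h]: permute to [6, 1, 32, 128], then merge the 32 heads of 128 dims
context_layer = y.permute(2, 0, 1, 3).reshape(6, 1, 4096)
print(context_layer.shape)  # torch.Size([6, 1, 4096])
```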

## self_attention

hidden_states: [s, b, h]
mixed_x_layer = Linear(hidden_states) -> [6, 1, 4608]   (4608 = 4096 + 256 + 256)

(query_layer, key_layer, value_layer) = mixed_x_layer.split -> [6, 1, 4096], [6, 1, 256], [6, 1, 256]
query_layer = query_layer.view -> [6, 1, 32, 128]
key_layer = key_layer.view -> [6, 1, 2, 128]
value_layer = value_layer.view -> [6, 1, 2, 128]

query_layer = self.apply_rotary_pos_emb(query_layer, rotary_pos_emb)
key_layer = self.apply_rotary_pos_emb(key_layer, rotary_pos_emb)

key_layer = key_layer.unsqueeze(-2) -> [6, 1, 2, 1, 128]
key_layer = key_layer.expand -> [6, 1, 2, 16, 128]
key_layer = key_layer.contiguous().view -> [6, 1, 32, 128]

value_layer = value_layer.unsqueeze(-2) -> [6, 1, 2, 1, 128]
value_layer = value_layer.expand -> [6, 1, 2, 16, 128]
value_layer = value_layer.contiguous().view -> [6, 1, 32, 128]

context_layer = self.core_attention(query_layer, key_layer, value_layer) -> [6, 1, 4096]
return Linear(context_layer) -> [6, 1, 4096]
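
The unsqueeze/expand/view steps above are grouped-query attention: 2 key/value heads are shared across 32 query heads, so each KV head is repeated 16 times. A sketch of just that expansion, with sizes taken from the notes:

```python
import torch

num_attention_heads, num_kv_heads, head_dim = 32, 2, 128
key_layer = torch.randn(6, 1, num_kv_heads, head_dim)          # [6, 1, 2, 128]

key_layer = key_layer.unsqueeze(-2)                             # [6, 1, 2, 1, 128]
key_layer = key_layer.expand(
    -1, -1, -1, num_attention_heads // num_kv_heads, -1         # repeat each KV head 16x
)                                                               # [6, 1, 2, 16, 128]
key_layer = key_layer.contiguous().view(6, 1, num_attention_heads, head_dim)
print(key_layer.shape)                                          # torch.Size([6, 1, 32, 128])
```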

## GLMBlock

input
| \
| RMSNorm
| self_attention
| dropout
| /
Add
| \
| RMSNorm
| mlp
| dropout
| /
Add

All of the outputs have shape [6, 1, 4096], where 6 = sequence_length, 1 = batch_num, 4096 = hidden_size.
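
A sketch of the residual wiring in the diagram; the sub-modules are passed in, and the dropout probability is an illustrative placeholder:

```python
import torch


class GLMBlockSketch(torch.nn.Module):
    def __init__(self, input_norm, self_attention, post_norm, mlp, dropout_p: float = 0.0):
        super().__init__()
        self.input_norm, self.self_attention = input_norm, self_attention
        self.post_norm, self.mlp = post_norm, mlp
        self.dropout = torch.nn.Dropout(dropout_p)

    def forward(self, hidden_states, rotary_pos_emb):
        # attention branch: RMSNorm -> self_attention -> dropout, then Add
        attn_out = self.self_attention(self.input_norm(hidden_states), rotary_pos_emb)
        hidden_states = hidden_states + self.dropout(attn_out)
        # MLP branch: RMSNorm -> mlp -> dropout, then Add
        mlp_out = self.mlp(self.post_norm(hidden_states))
        hidden_states = hidden_states + self.dropout(mlp_out)
        return hidden_states  # stays [6, 1, 4096]
```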

@@ -4,6 +4,7 @@ import copy
 import os
 import gc
 import json
+import hashlib
 
 import torch
 import torch.utils.checkpoint

@@ -148,28 +149,20 @@ class SelfAttention(torch.nn.Module):
             dtype=config.torch_dtype,
         )
 
-    def apply_rotary_pos_emb(
-        self, x: torch.Tensor, rope_cache: torch.Tensor
-    ) -> torch.Tensor:
+    def apply_rotary_pos_emb(self, x: torch.Tensor, rope: torch.Tensor) -> torch.Tensor:
         # x: [sq, b, np, hn]
         sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3)
-        rot_dim = rope_cache.shape[-2] * 2
-        x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
-        # truncate to support variable sizes
-        rope_cache = rope_cache[:sq]
-        xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2)
-        rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2)
-        x_out2 = torch.stack(
-            [
-                xshaped[..., 0] * rope_cache[..., 0]
-                - xshaped[..., 1] * rope_cache[..., 1],
-                xshaped[..., 1] * rope_cache[..., 0]
-                + xshaped[..., 0] * rope_cache[..., 1],
-            ],
-            -1,
-        )
-        x_out2 = x_out2.flatten(3)
-        return torch.cat((x_out2, x_pass), dim=-1)
+        if rope.size(0) != sq:
+            raise ValueError("Error rotary_pos_emb size")
+        x_rope = x[..., : hn // 2]
+        x_pass = x[..., hn // 2 :]
+        x_rope = x_rope.reshape(sq, -1, np, hn // 4, 1, 2)
+        rope = rope.view(sq, -1, 1, hn // 4, 1, 2)
+        roped1 = x_rope[..., 0] * rope[..., 0] - x_rope[..., 1] * rope[..., 1]
+        roped2 = x_rope[..., 1] * rope[..., 0] + x_rope[..., 0] * rope[..., 1]
+        x_out = torch.cat((roped1, roped2), -1)
+        x_out = x_out.flatten(3)
+        return torch.cat((x_out, x_pass), dim=-1)
 
     def forward(self, hidden_states, rotary_pos_emb):
         # hidden_states: [sq, b, h]
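
A quick self-contained check (not part of the commit) that the rewritten method matches the old one when the rope cache covers half the head dimension, i.e. rot_dim == hn // 2; the cache layout [sq, b, hn // 4, 2] holding (cos, sin) pairs is an assumption of this sketch:

```python
import torch

sq, b, np_, hn = 6, 1, 32, 128
x = torch.randn(sq, b, np_, hn)
rope = torch.randn(sq, b, hn // 4, 2)  # assumed (cos, sin) pair layout

def old_apply(x, rope_cache):
    rot_dim = rope_cache.shape[-2] * 2
    x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
    rope_cache = rope_cache[: x.size(0)]
    xshaped = x.reshape(x.size(0), -1, x.size(2), rot_dim // 2, 2)
    rope_cache = rope_cache.view(x.size(0), -1, 1, xshaped.size(3), 2)
    x_out2 = torch.stack(
        [
            xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
            xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
        ],
        -1,
    ).flatten(3)
    return torch.cat((x_out2, x_pass), dim=-1)

def new_apply(x, rope):
    sq, b, np_, hn = x.size(0), x.size(1), x.size(2), x.size(3)
    x_rope, x_pass = x[..., : hn // 2], x[..., hn // 2 :]
    x_rope = x_rope.reshape(sq, -1, np_, hn // 4, 1, 2)
    rope = rope.view(sq, -1, 1, hn // 4, 1, 2)
    roped1 = x_rope[..., 0] * rope[..., 0] - x_rope[..., 1] * rope[..., 1]
    roped2 = x_rope[..., 1] * rope[..., 0] + x_rope[..., 0] * rope[..., 1]
    return torch.cat((torch.cat((roped1, roped2), -1).flatten(3), x_pass), dim=-1)

print(torch.allclose(old_apply(x, rope), new_apply(x, rope)))  # True
```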

demo.py (12 lines changed)

@@ -1,11 +1,15 @@
 import json
+import torch
 
 from chatglm import ChatGLMForConditionalGeneration
 from chatglm import ChatGLMTokenizer
 
 from transformers import AutoConfig
 
+seed = 1234
+torch.manual_seed(seed)
+torch.cuda.manual_seed_all(seed)
 
 pretrained_model_name_or_path = "../ZhipuAI/chatglm3-6b"
 config, kwargs = AutoConfig.from_pretrained(
     pretrained_model_name_or_path,

@@ -38,9 +42,15 @@ glm = glm.eval()
 query = "colin"
 response, history = glm.chat(tokenizer, query, history=[])
 print(response)
+if response[1:] != " Hello! How can I assist you today":
+    raise AssertionError()
+
 query = "你好"
 response, history = glm.chat(tokenizer, query, history=history)
 print(response)
+if response[1:] != " 你好!有什么我可以帮助你的吗":
+    raise AssertionError()
+
 # response, history = glm.chat(tokenizer, "你是一个心理学专家,请问晚上睡不着应该怎么办", history=history)
 # print(response)

tensor.py (13 lines changed)

@@ -30,7 +30,20 @@ print()
 print(x.unsqueeze(1).shape)
 print(x.unsqueeze(1).squeeze(1).shape)
 
 
 x = torch.tensor([[1, 2], [3, 4]]).to(float)
 print(x.mean(1))
 print(x.mean(0))
 print(x.mean(0, keepdim=True))
+
+print()
+print()
+x = torch.tensor([[1, 2], [3, 4]])
+print(x.flatten(0))
+
+x = torch.tensor([[1, 2], [3, 4]])
+print(torch.stack((x, x), 1))
+print(torch.cat((x, x), 1))
+# So if A and B are of shape (3, 4):
+# torch.cat([A, B], dim=0) will be of shape (6, 4)
+# torch.stack([A, B], dim=0) will be of shape (2, 3, 4)
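
For the x = [[1, 2], [3, 4]] used above, the shapes work out as follows (a quick illustration, not part of the commit):

```python
import torch

x = torch.tensor([[1, 2], [3, 4]])
print(x.flatten(0).shape)            # torch.Size([4])
print(torch.stack((x, x), 1).shape)  # torch.Size([2, 2, 2])  stack adds a new dim
print(torch.cat((x, x), 1).shape)    # torch.Size([2, 4])     cat grows an existing dim
```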