Update code.

commit fa7078b72d
parent ebe48f8efc

Readme.md | 61

@@ -33,3 +33,64 @@ variance = hidden_states.pow(2).mean(-1, keepdim=True)  -> [6, 1, 1]
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)  (reciprocal of the square root)
self.weight -> [4096]
return (self.weight * hidden_states)  -> [6, 1, 4096]
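For reference, a minimal sketch of the whole RMSNorm step traced here, assuming hidden size 4096 and an illustrative eps (class and attribute names mirror these notes, not necessarily the upstream code):

```python
import torch

class RMSNorm(torch.nn.Module):
    # illustrative re-implementation of the normalization traced above
    def __init__(self, hidden_size=4096, eps=1e-5):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))  # [4096]
        self.eps = eps

    def forward(self, hidden_states):                               # [6, 1, 4096]
        variance = hidden_states.pow(2).mean(-1, keepdim=True)      # [6, 1, 1]
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
        return self.weight * hidden_states                          # [6, 1, 4096]

norm = RMSNorm()
print(norm(torch.randn(6, 1, 4096)).shape)  # torch.Size([6, 1, 4096])
```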

## MLP

Linear(hidden_states)  no bias  ->  [6, 1, 27392]
silu(x) = [6, 1, 13696] * sigmoid([6, 1, 13696])
Linear(intermediate_parallel)  no bias  ->  [6, 1, 4096]
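The jump from 27392 to 13696 suggests the first projection is chunked into two halves and combined SwiGLU-style (silu(x0) * x1); a sketch under that assumption, with illustrative layer names:

```python
import torch
import torch.nn.functional as F

class MLP(torch.nn.Module):
    # illustrative sketch: 4096 -> 2*13696 -> swiglu -> 13696 -> 4096, no biases
    def __init__(self, hidden=4096, ffn=13696):
        super().__init__()
        self.dense_h_to_4h = torch.nn.Linear(hidden, 2 * ffn, bias=False)
        self.dense_4h_to_h = torch.nn.Linear(ffn, hidden, bias=False)

    def forward(self, hidden_states):               # [6, 1, 4096]
        x = self.dense_h_to_4h(hidden_states)       # [6, 1, 27392]
        x0, x1 = torch.chunk(x, 2, dim=-1)          # two [6, 1, 13696] halves
        x = F.silu(x0) * x1                         # silu(x) = x * sigmoid(x)
        return self.dense_4h_to_h(x)                # [6, 1, 4096]

mlp = MLP()
print(mlp(torch.randn(6, 1, 4096)).shape)  # torch.Size([6, 1, 4096])
```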

## core_attention

query_layer = query_layer.permute(1, 2, 0, 3)  ->  [1, 32, 6, 128]
key_layer = key_layer.permute(1, 2, 0, 3)  ->  [1, 32, 6, 128]
value_layer = value_layer.permute(1, 2, 0, 3)  ->  [1, 32, 6, 128]
context_layer = scaled_dot_product_attention(query_layer, key_layer, value_layer)  ->  [1, 32, 6, 128]
    softmax(QK^T / sqrt(in_dim)) V
    att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
    att = F.softmax(att, dim=-1)
    y = att @ v  ->  (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
context_layer = context_layer.permute(2, 0, 1, 3)
context_layer = context_layer.reshape()  ->  [6, 1, 4096]
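A standalone shape check of this path, assuming PyTorch 2.x's torch.nn.functional.scaled_dot_product_attention:

```python
import torch
import torch.nn.functional as F

sq, b, nh, hs = 6, 1, 32, 128                          # seq len, batch, heads, head size
q = torch.randn(sq, b, nh, hs)                         # [6, 1, 32, 128]
k = torch.randn(sq, b, nh, hs)
v = torch.randn(sq, b, nh, hs)

q, k, v = (t.permute(1, 2, 0, 3) for t in (q, k, v))   # [1, 32, 6, 128]
ctx = F.scaled_dot_product_attention(q, k, v)          # softmax(QK^T / sqrt(hs)) V -> [1, 32, 6, 128]
ctx = ctx.permute(2, 0, 1, 3)                          # [6, 1, 32, 128]
ctx = ctx.reshape(sq, b, nh * hs)                      # [6, 1, 4096]
print(ctx.shape)                                       # torch.Size([6, 1, 4096])
```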

## self_attention

hidden_states: [s, b, h]
mixed_x_layer = Linear(hidden_states)  ->  [6, 1, 4608]  (4608 = 4096 + 256 + 256)

(query_layer, key_layer, value_layer) = mixed_x_layer.split  ->  [6, 1, 4096], [6, 1, 256], [6, 1, 256]
query_layer = query_layer.view  ->  [6, 1, 32, 128]
key_layer = key_layer.view  ->  [6, 1, 2, 128]
value_layer = value_layer.view  ->  [6, 1, 2, 128]

query_layer = self.apply_rotary_pos_emb(query_layer, rotary_pos_emb)
key_layer = self.apply_rotary_pos_emb(key_layer, rotary_pos_emb)

key_layer = key_layer.unsqueeze(-2)  ->  [6, 1, 2, 1, 128]
key_layer = key_layer.expand  ->  [6, 1, 2, 16, 128]
key_layer = key_layer.contiguous().view  ->  [6, 1, 32, 128]

value_layer = value_layer.unsqueeze(-2)  ->  [6, 1, 2, 1, 128]
value_layer = value_layer.expand  ->  [6, 1, 2, 16, 128]
value_layer = value_layer.contiguous().view  ->  [6, 1, 32, 128]

context_layer = self.core_attention(query_layer, key_layer, value_layer)  ->  [6, 1, 4096]
return Linear(context_layer)  ->  [6, 1, 4096]
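A minimal sketch of the QKV split and the key/value expansion: 2 KV heads are shared by 32 query heads, so each KV head is repeated 16 times (dimensions are the ones noted above; variable names are illustrative):

```python
import torch

sq, b = 6, 1
mixed_x_layer = torch.randn(sq, b, 4608)                 # 4608 = 4096 + 256 + 256

q, k, v = mixed_x_layer.split([4096, 256, 256], dim=-1)
q = q.view(sq, b, 32, 128)                               # 32 query heads
k = k.view(sq, b, 2, 128)                                # 2 key heads
v = v.view(sq, b, 2, 128)                                # 2 value heads

# each KV head is expanded to serve 16 query heads
k = k.unsqueeze(-2).expand(sq, b, 2, 16, 128).contiguous().view(sq, b, 32, 128)
v = v.unsqueeze(-2).expand(sq, b, 2, 16, 128).contiguous().view(sq, b, 32, 128)
print(q.shape, k.shape, v.shape)                         # all torch.Size([6, 1, 32, 128])
```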

## GLMBlock

 input
 |   \
 |   RMSNorm
 |   self_attention
 |   dropout
 |   /
 Add
 |  \
 |  RMSNorm
 |  mlp
 |  dropout
 |  /
 Add

Every output shape is [6, 1, 4096]; 6: sequence_length, 1: batch_num, 4096: hidden_size.
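A sketch of the residual structure in the diagram, reusing the RMSNorm and MLP sketches above and assuming a SelfAttention module with the interface shown in the self_attention section (the dropout probability is illustrative):

```python
import torch

class GLMBlock(torch.nn.Module):
    # pre-norm residual block: x + dropout(attn(norm(x))), then x + dropout(mlp(norm(x)))
    def __init__(self, hidden=4096, p_drop=0.0):
        super().__init__()
        self.input_layernorm = RMSNorm(hidden)            # RMSNorm sketch above
        self.post_attention_layernorm = RMSNorm(hidden)
        self.self_attention = SelfAttention()             # assumed: (x, rope) -> [6, 1, 4096]
        self.mlp = MLP(hidden)                            # MLP sketch above
        self.dropout = torch.nn.Dropout(p_drop)

    def forward(self, hidden_states, rotary_pos_emb):     # [6, 1, 4096]
        attn_out = self.self_attention(self.input_layernorm(hidden_states), rotary_pos_emb)
        hidden_states = hidden_states + self.dropout(attn_out)
        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        return hidden_states + self.dropout(mlp_out)      # [6, 1, 4096]
```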

@@ -4,6 +4,7 @@ import copy
 import os
 import gc
 import json
+import hashlib
 
 import torch
 import torch.utils.checkpoint

@@ -148,28 +149,20 @@ class SelfAttention(torch.nn.Module):
             dtype=config.torch_dtype,
         )
 
-    def apply_rotary_pos_emb(
-        self, x: torch.Tensor, rope_cache: torch.Tensor
-    ) -> torch.Tensor:
+    def apply_rotary_pos_emb(self, x: torch.Tensor, rope: torch.Tensor) -> torch.Tensor:
         # x: [sq, b, np, hn]
         sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3)
-        rot_dim = rope_cache.shape[-2] * 2
-        x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
-        # truncate to support variable sizes
-        rope_cache = rope_cache[:sq]
-        xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2)
-        rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2)
-        x_out2 = torch.stack(
-            [
-                xshaped[..., 0] * rope_cache[..., 0]
-                - xshaped[..., 1] * rope_cache[..., 1],
-                xshaped[..., 1] * rope_cache[..., 0]
-                + xshaped[..., 0] * rope_cache[..., 1],
-            ],
-            -1,
-        )
-        x_out2 = x_out2.flatten(3)
-        return torch.cat((x_out2, x_pass), dim=-1)
+        if rope.size(0) != sq:
+            raise ValueError("Error rotary_pos_emb size")
+        x_rope = x[..., : hn // 2]
+        x_pass = x[..., hn // 2 :]
+        x_rope = x_rope.reshape(sq, -1, np, hn // 4, 1, 2)
+        rope = rope.view(sq, -1, 1, hn // 4, 1, 2)
+        roped1 = x_rope[..., 0] * rope[..., 0] - x_rope[..., 1] * rope[..., 1]
+        roped2 = x_rope[..., 1] * rope[..., 0] + x_rope[..., 0] * rope[..., 1]
+        x_out = torch.cat((roped1, roped2), -1)
+        x_out = x_out.flatten(3)
+        return torch.cat((x_out, x_pass), dim=-1)
 
     def forward(self, hidden_states, rotary_pos_emb):
         # hidden_states: [sq, b, h]
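A standalone copy of the rewritten method for a quick shape check; the dummy rope tensor's [sq, b, hn//4, 2] layout is inferred from the view call (in the real model it comes from the rotary embedding cache):

```python
import torch

def apply_rotary_pos_emb(x: torch.Tensor, rope: torch.Tensor) -> torch.Tensor:
    # same logic as the method above, with self removed
    sq, b, np_, hn = x.size(0), x.size(1), x.size(2), x.size(3)
    if rope.size(0) != sq:
        raise ValueError("Error rotary_pos_emb size")
    x_rope, x_pass = x[..., : hn // 2], x[..., hn // 2 :]
    x_rope = x_rope.reshape(sq, -1, np_, hn // 4, 1, 2)
    rope = rope.view(sq, -1, 1, hn // 4, 1, 2)
    # rotate (even, odd) pairs: rope[..., 0] plays the role of cos, rope[..., 1] of sin
    roped1 = x_rope[..., 0] * rope[..., 0] - x_rope[..., 1] * rope[..., 1]
    roped2 = x_rope[..., 1] * rope[..., 0] + x_rope[..., 0] * rope[..., 1]
    x_out = torch.cat((roped1, roped2), -1).flatten(3)
    return torch.cat((x_out, x_pass), dim=-1)

q = torch.randn(6, 1, 32, 128)               # [sq, b, np, hn]
rope = torch.randn(6, 1, 32, 2)              # dummy cache: [sq, b, hn//4, 2]
print(apply_rotary_pos_emb(q, rope).shape)   # torch.Size([6, 1, 32, 128])
```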

demo.py | 12

@@ -1,11 +1,15 @@
 import json
+import torch
 
 from chatglm import ChatGLMForConditionalGeneration
 from chatglm import ChatGLMTokenizer
 
 from transformers import AutoConfig
 
+seed = 1234
+torch.manual_seed(seed)
+torch.cuda.manual_seed_all(seed)
 
 pretrained_model_name_or_path = "../ZhipuAI/chatglm3-6b"
 config, kwargs = AutoConfig.from_pretrained(
     pretrained_model_name_or_path,

@@ -38,9 +42,15 @@ glm = glm.eval()
 query = "colin"
 response, history = glm.chat(tokenizer, query, history=[])
 print(response)
+if response[1:] != " Hello! How can I assist you today":
+    raise AssertionError()
+
 query = "你好"
 response, history = glm.chat(tokenizer, query, history=history)
 print(response)
+if response[1:] != " 你好!有什么我可以帮助你的吗":
+    raise AssertionError()
+
 # response, history = glm.chat(tokenizer, "你是一个心理学专家,请问晚上睡不着应该怎么办", history=history)
 # print(response)
 

tensor.py | 13

@@ -30,7 +30,20 @@ print()
 print(x.unsqueeze(1).shape)
 print(x.unsqueeze(1).squeeze(1).shape)
 
+
 x = torch.tensor([[1, 2], [3, 4]]).to(float)
 print(x.mean(1))
 print(x.mean(0))
 print(x.mean(0, keepdim=True))
+
+print()
+print()
+x = torch.tensor([[1, 2], [3, 4]])
+print(x.flatten(0))
+
+x = torch.tensor([[1, 2], [3, 4]])
+print(torch.stack((x, x), 1))
+print(torch.cat((x, x), 1))
+# So if A and B are of shape (3, 4):
+# torch.cat([A, B], dim=0) will be of shape (6, 4)
+# torch.stack([A, B], dim=0) will be of shape (2, 3, 4)
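For reference, the expected outputs of the newly added lines, worked out by hand (worth confirming by running the script):

```python
# x = torch.tensor([[1, 2], [3, 4]])
# x.flatten(0)           -> tensor([1, 2, 3, 4])                  shape (4,)
# torch.stack((x, x), 1) -> tensor([[[1, 2], [1, 2]],
#                                   [[3, 4], [3, 4]]])            shape (2, 2, 2)
# torch.cat((x, x), 1)   -> tensor([[1, 2, 1, 2],
#                                   [3, 4, 3, 4]])                shape (2, 4)
```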