From 79573867af7fffd73fef4e6d412d93617cd3c5b0 Mon Sep 17 00:00:00 2001
From: Colin
Date: Sun, 21 Jan 2024 16:46:00 +0800
Subject: [PATCH] Refine qwen to modular format.

---
 qwen/modeling_qwen.py      | 70 +++++++++++++++++++-----------------
 qwen/research_attention.py | 73 ++++++++++++++++++++++++++++++++++++++
 2 files changed, 111 insertions(+), 32 deletions(-)
 create mode 100644 qwen/research_attention.py

diff --git a/qwen/modeling_qwen.py b/qwen/modeling_qwen.py
index 958a200..8da5842 100644
--- a/qwen/modeling_qwen.py
+++ b/qwen/modeling_qwen.py
@@ -256,59 +256,65 @@ class QwenRunner:
         history.append((query, response))
         return response, history, decoded
 
-    def forwardAttention(
+    def _rotate_half(self, x):
+        x = rearrange(x, "... (j d) -> ... j d", j=2)
+        x1, x2 = x.unbind(dim=-2)
+        return torch.cat((-x2, x1), dim=-1)
+
+    def apply_rotary_pos_emb(self, t, freqs):
+        rot_dim = freqs[0].shape[-1]
+        cos, sin = freqs
+        t_float = t.float()
+        t_rot, t_pass = t_float[..., :rot_dim], t_float[..., rot_dim:]
+        t_rot = (t_rot * cos) + (self._rotate_half(t_rot) * sin)
+        return torch.cat((t_rot, t_pass), dim=-1).type_as(t)
+
+    def split_heads(
         self,
         attention,
         hidden_states: Optional[Tuple[torch.FloatTensor]],
-        rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
     ):
-        def apply_rotary_pos_emb(t, freqs):
-            def _rotate_half(x):
-                x = rearrange(x, "... (j d) -> ... j d", j=2)
-                x1, x2 = x.unbind(dim=-2)
-                return torch.cat((-x2, x1), dim=-1)
-
-            rot_dim = freqs[0].shape[-1]
-            cos, sin = freqs
-            t_float = t.float()
-            t_rot, t_pass = t_float[..., :rot_dim], t_float[..., rot_dim:]
-            t_rot = (t_rot * cos) + (_rotate_half(t_rot) * sin)
-            return torch.cat((t_rot, t_pass), dim=-1).type_as(t)
-
         atten = attention
         mixed_x_layer = atten.c_attn(hidden_states)
         query, key, value = mixed_x_layer.split(atten.split_size, dim=2)
         query = atten._split_heads(query, atten.num_heads, atten.head_dim)
         key = atten._split_heads(key, atten.num_heads, atten.head_dim)
         value = atten._split_heads(value, atten.num_heads, atten.head_dim)
+        return query, key, value
 
+    def pos_emb(self, query, key, rotary_pos_emb_list):
         rotary_pos_emb = rotary_pos_emb_list[0]
         rotary_pos_emb = [i[:, -query.shape[1] :, :, :] for i in rotary_pos_emb]
         rotary_pos_emb = (rotary_pos_emb,) * 2
-        query = apply_rotary_pos_emb(query, rotary_pos_emb[0])
-        key = apply_rotary_pos_emb(key, rotary_pos_emb[1])
+        query = self.apply_rotary_pos_emb(query, rotary_pos_emb[0])
+        key = self.apply_rotary_pos_emb(key, rotary_pos_emb[1])
+        return query, key
 
-        key_size = key.size(1)
-        causal_mask = torch.tril(torch.ones((key_size, key_size), dtype=torch.bool, device=query.device)).view(
-            1, 1, key_size, key_size
-        )
+    def attention(self, attention, query, key, value, causal_mask):
         query = query.permute(0, 2, 1, 3)
         key = key.permute(0, 2, 1, 3)
         value = value.permute(0, 2, 1, 3)
-
-        # qk = query @ key.transpose(-2, -1)
-        # qk = qk[0]
-        # prePath = "../generated/query_matmul_key/img/"
-        # show.DumpTensorToImage(
-        #     qk, prePath + "q_matmul_k_sequence_" + str(key_size) + "_layer_" + str(self.index) + ".png"
-        # )
-
         attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=causal_mask).transpose(1, 2)
-        context_layer = atten._merge_heads(attn_output, atten.num_heads, atten.head_dim)
-        attn_output = atten.c_proj(context_layer)
-
+        context_layer = attention._merge_heads(attn_output, attention.num_heads, attention.head_dim)
+        attn_output = attention.c_proj(context_layer)
         return attn_output
 
+    def build_mask(self, query):
+        size = query.size(1)
+        causal_mask = torch.tril(torch.ones((size, size), dtype=torch.bool, device=query.device)).view(1, 1, size, size)
+        return causal_mask
+
+    def forwardAttention(
+        self,
+        attention,
+        hidden_states: Optional[Tuple[torch.FloatTensor]],
+        rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
+    ):
+        query, key, value = self.split_heads(attention, hidden_states)
+        query, key = self.pos_emb(query, key, rotary_pos_emb_list)
+        causal_mask = self.build_mask(query)
+        return self.attention(attention, query, key, value, causal_mask)
+
     def forwardQWenBlock(
         self,
         block,
diff --git a/qwen/research_attention.py b/qwen/research_attention.py
new file mode 100644
index 0000000..5c7d768
--- /dev/null
+++ b/qwen/research_attention.py
@@ -0,0 +1,73 @@
+import torch
+import sys
+from modelscope import snapshot_download
+from transformers import AutoTokenizer
+from transformers import AutoConfig
+
+from modeling_qwen import QWenLMHeadModel
+from modeling_qwen import QwenRunner
+
+sys.path.append("..")
+from tools import show
+from tools import mem_tracker
+
+seed = 4321
+torch.manual_seed(seed)
+torch.cuda.manual_seed_all(seed)
+
+model_dir = snapshot_download("qwen/Qwen-1_8B-Chat")
+# model_dir = "/home/colin/.cache/modelscope/hub/qwen/Qwen-1_8B-Chat"
+
+config, kwargs = AutoConfig.from_pretrained(
+    "./",
+    return_unused_kwargs=True,
+    trust_remote_code=True,
+    code_revision=None,
+    _commit_hash=None,
+)
+model = QWenLMHeadModel(config)
+
+print(model)
+
+tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
+model = model.from_pretrained(model_dir).cuda()
+
+model = model.eval()
+
+
+class ResearchRunner(QwenRunner):
+    def forwardAttention(
+        self,
+        attention,
+        hidden_states,
+        rotary_pos_emb_list=None,
+    ):
+        query, key, value = self.split_heads(attention, hidden_states)
+        query, key = self.pos_emb(query, key, rotary_pos_emb_list)
+        causal_mask = self.build_mask(query)
+
+        q = query.permute(0, 2, 1, 3)
+        k = key.permute(0, 2, 1, 3)
+        size = query.shape[1]
+        qk = q @ k.transpose(-2, -1)  # per-head raw query-key scores, before softmax
+        qk = qk[0]
+        prePath = "./img/"
+        show.DumpTensorToImage(qk, prePath + "q@k_seq_" + str(size) + "_layer_" + str(attention.index) + ".png")
+        return self.attention(attention, query, key, value, causal_mask)
+
+
+runner = ResearchRunner(model)
+
+# First round of chat
+response, history, decode_tokens = runner.Chat(tokenizer, "东南亚国家日本的首都是什么市", "")
+print(decode_tokens)
+# <|im_start|>system
+# You are a helpful assistant.<|im_end|>
+# <|im_start|>user
+# 东南亚国家日本的首都是什么市<|im_end|>
+# <|im_start|>assistant
+# 日本的首都东京。<|im_end|>
+# <|endoftext|>
+
+if decode_tokens.split("\n")[-2] != """日本的首都东京。<|im_end|>""":
+    raise ValueError("Unexpected response: " + decode_tokens)