Refine qwen to module fomater.
This commit is contained in:
parent
9d28280cb1
commit
dab1c94bc6
|
@ -1,3 +1,4 @@
|
||||||
__pycache__
|
__pycache__
|
||||||
.vscode
|
.vscode
|
||||||
*.txt
|
*.txt
|
||||||
|
temp
|
|
@ -256,59 +256,65 @@ class QwenRunner:
|
||||||
history.append((query, response))
|
history.append((query, response))
|
||||||
return response, history, decoded
|
return response, history, decoded
|
||||||
|
|
||||||
def forwardAttention(
|
def _rotate_half(self, x):
|
||||||
|
x = rearrange(x, "... (j d) -> ... j d", j=2)
|
||||||
|
x1, x2 = x.unbind(dim=-2)
|
||||||
|
return torch.cat((-x2, x1), dim=-1)
|
||||||
|
|
||||||
|
def apply_rotary_pos_emb(self, t, freqs):
|
||||||
|
rot_dim = freqs[0].shape[-1]
|
||||||
|
cos, sin = freqs
|
||||||
|
t_float = t.float()
|
||||||
|
t_rot, t_pass = t_float[..., :rot_dim], t_float[..., rot_dim:]
|
||||||
|
t_rot = (t_rot * cos) + (self._rotate_half(t_rot) * sin)
|
||||||
|
return torch.cat((t_rot, t_pass), dim=-1).type_as(t)
|
||||||
|
|
||||||
|
def split_heads(
|
||||||
self,
|
self,
|
||||||
attention,
|
attention,
|
||||||
hidden_states: Optional[Tuple[torch.FloatTensor]],
|
hidden_states: Optional[Tuple[torch.FloatTensor]],
|
||||||
rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
|
|
||||||
):
|
):
|
||||||
def apply_rotary_pos_emb(t, freqs):
|
|
||||||
def _rotate_half(x):
|
|
||||||
x = rearrange(x, "... (j d) -> ... j d", j=2)
|
|
||||||
x1, x2 = x.unbind(dim=-2)
|
|
||||||
return torch.cat((-x2, x1), dim=-1)
|
|
||||||
|
|
||||||
rot_dim = freqs[0].shape[-1]
|
|
||||||
cos, sin = freqs
|
|
||||||
t_float = t.float()
|
|
||||||
t_rot, t_pass = t_float[..., :rot_dim], t_float[..., rot_dim:]
|
|
||||||
t_rot = (t_rot * cos) + (_rotate_half(t_rot) * sin)
|
|
||||||
return torch.cat((t_rot, t_pass), dim=-1).type_as(t)
|
|
||||||
|
|
||||||
atten = attention
|
atten = attention
|
||||||
mixed_x_layer = atten.c_attn(hidden_states)
|
mixed_x_layer = atten.c_attn(hidden_states)
|
||||||
query, key, value = mixed_x_layer.split(atten.split_size, dim=2)
|
query, key, value = mixed_x_layer.split(atten.split_size, dim=2)
|
||||||
query = atten._split_heads(query, atten.num_heads, atten.head_dim)
|
query = atten._split_heads(query, atten.num_heads, atten.head_dim)
|
||||||
key = atten._split_heads(key, atten.num_heads, atten.head_dim)
|
key = atten._split_heads(key, atten.num_heads, atten.head_dim)
|
||||||
value = atten._split_heads(value, atten.num_heads, atten.head_dim)
|
value = atten._split_heads(value, atten.num_heads, atten.head_dim)
|
||||||
|
return query, key, value
|
||||||
|
|
||||||
|
def pos_emb(self, query, key, rotary_pos_emb_list):
|
||||||
rotary_pos_emb = rotary_pos_emb_list[0]
|
rotary_pos_emb = rotary_pos_emb_list[0]
|
||||||
rotary_pos_emb = [i[:, -query.shape[1] :, :, :] for i in rotary_pos_emb]
|
rotary_pos_emb = [i[:, -query.shape[1] :, :, :] for i in rotary_pos_emb]
|
||||||
rotary_pos_emb = (rotary_pos_emb,) * 2
|
rotary_pos_emb = (rotary_pos_emb,) * 2
|
||||||
query = apply_rotary_pos_emb(query, rotary_pos_emb[0])
|
query = self.apply_rotary_pos_emb(query, rotary_pos_emb[0])
|
||||||
key = apply_rotary_pos_emb(key, rotary_pos_emb[1])
|
key = self.apply_rotary_pos_emb(key, rotary_pos_emb[1])
|
||||||
|
return query, key
|
||||||
|
|
||||||
key_size = key.size(1)
|
def attention(self, attention, query, key, value, causal_mask):
|
||||||
causal_mask = torch.tril(torch.ones((key_size, key_size), dtype=torch.bool, device=query.device)).view(
|
|
||||||
1, 1, key_size, key_size
|
|
||||||
)
|
|
||||||
query = query.permute(0, 2, 1, 3)
|
query = query.permute(0, 2, 1, 3)
|
||||||
key = key.permute(0, 2, 1, 3)
|
key = key.permute(0, 2, 1, 3)
|
||||||
value = value.permute(0, 2, 1, 3)
|
value = value.permute(0, 2, 1, 3)
|
||||||
|
|
||||||
# qk = query @ key.transpose(-2, -1)
|
|
||||||
# qk = qk[0]
|
|
||||||
# prePath = "../generated/query_matmul_key/img/"
|
|
||||||
# show.DumpTensorToImage(
|
|
||||||
# qk, prePath + "q_matmul_k_sequence_" + str(key_size) + "_layer_" + str(self.index) + ".png"
|
|
||||||
# )
|
|
||||||
|
|
||||||
attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=causal_mask).transpose(1, 2)
|
attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=causal_mask).transpose(1, 2)
|
||||||
context_layer = atten._merge_heads(attn_output, atten.num_heads, atten.head_dim)
|
context_layer = attention._merge_heads(attn_output, attention.num_heads, attention.head_dim)
|
||||||
attn_output = atten.c_proj(context_layer)
|
attn_output = attention.c_proj(context_layer)
|
||||||
|
|
||||||
return attn_output
|
return attn_output
|
||||||
|
|
||||||
|
def build_mask(self, query):
|
||||||
|
size = query.size(1)
|
||||||
|
causal_mask = torch.tril(torch.ones((size, size), dtype=torch.bool, device=query.device)).view(1, 1, size, size)
|
||||||
|
return causal_mask
|
||||||
|
|
||||||
|
def forwardAttention(
|
||||||
|
self,
|
||||||
|
attention,
|
||||||
|
hidden_states: Optional[Tuple[torch.FloatTensor]],
|
||||||
|
rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
|
||||||
|
):
|
||||||
|
query, key, value = self.split_heads(attention, hidden_states)
|
||||||
|
query, key = self.pos_emb(query, key, rotary_pos_emb_list)
|
||||||
|
causal_mask = self.build_mask(query)
|
||||||
|
return self.attention(attention, query, key, value, causal_mask)
|
||||||
|
|
||||||
def forwardQWenBlock(
|
def forwardQWenBlock(
|
||||||
self,
|
self,
|
||||||
block,
|
block,
|
||||||
|
|
|
@ -0,0 +1,73 @@
|
||||||
|
import torch
|
||||||
|
import sys
|
||||||
|
from modelscope import snapshot_download
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
from transformers import AutoConfig
|
||||||
|
|
||||||
|
from modeling_qwen import QWenLMHeadModel
|
||||||
|
from modeling_qwen import QwenRunner
|
||||||
|
|
||||||
|
sys.path.append("..")
|
||||||
|
from tools import show
|
||||||
|
from tools import mem_tracker
|
||||||
|
|
||||||
|
seed = 4321
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
torch.cuda.manual_seed_all(seed)
|
||||||
|
|
||||||
|
model_dir = snapshot_download("qwen/Qwen-1_8B-Chat")
|
||||||
|
# model_dir = "/home/colin/.cache/modelscope/hub/qwen/Qwen-1_8B-Chat"
|
||||||
|
|
||||||
|
config, kwargs = AutoConfig.from_pretrained(
|
||||||
|
"./",
|
||||||
|
return_unused_kwargs=True,
|
||||||
|
trust_remote_code=True,
|
||||||
|
code_revision=None,
|
||||||
|
_commit_hash=None,
|
||||||
|
)
|
||||||
|
model = QWenLMHeadModel(config)
|
||||||
|
|
||||||
|
print(model)
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
|
||||||
|
model = model.from_pretrained(model_dir).cuda()
|
||||||
|
|
||||||
|
model = model.eval()
|
||||||
|
|
||||||
|
|
||||||
|
class ResearchRunner(QwenRunner):
|
||||||
|
def forwardAttention(
|
||||||
|
self,
|
||||||
|
attention,
|
||||||
|
hidden_states,
|
||||||
|
rotary_pos_emb_list=None,
|
||||||
|
):
|
||||||
|
query, key, value = self.split_heads(attention, hidden_states)
|
||||||
|
query, key = self.pos_emb(query, key, rotary_pos_emb_list)
|
||||||
|
causal_mask = self.build_mask(query)
|
||||||
|
|
||||||
|
q = query.permute(0, 2, 1, 3)
|
||||||
|
k = key.permute(0, 2, 1, 3)
|
||||||
|
size = query.shape[1]
|
||||||
|
qk = q @ k.transpose(-2, -1)
|
||||||
|
qk = qk[0]
|
||||||
|
prePath = "./temp/"
|
||||||
|
show.DumpTensorToImage(qk, prePath + "q@k_seq_" + str(size) + "_layer_" + str(attention.index) + ".png")
|
||||||
|
return self.attention(attention, query, key, value, causal_mask)
|
||||||
|
|
||||||
|
|
||||||
|
runner = ResearchRunner(model)
|
||||||
|
|
||||||
|
# 第一轮对话
|
||||||
|
response, history, decode_tokens = runner.Chat(tokenizer, "东南亚国家日本的首都是什么市", "")
|
||||||
|
print(decode_tokens)
|
||||||
|
# <|im_start|>system
|
||||||
|
# You are a helpful assistant.<|im_end|>
|
||||||
|
# <|im_start|>user
|
||||||
|
# 东南亚国家日本的首都是什么市<|im_end|>
|
||||||
|
# <|im_start|>assistant
|
||||||
|
# 日本的首都东京。<|im_end|>
|
||||||
|
# <|endoftext|>
|
||||||
|
|
||||||
|
if decode_tokens.split("\n")[-2] != """日本的首都东京。<|im_end|>""":
|
||||||
|
raise ()
|
Loading…
Reference in New Issue