Refine research_attention.

Colin 2024-01-22 20:57:27 +08:00
parent 5dbac40925
commit 1811b9611a
5 changed files with 78 additions and 55 deletions

View File

@@ -47,9 +47,9 @@ def get_ltor_masks_and_position_ids(
         att_mask_batch = micro_batch_size
     else:
         att_mask_batch = 1
-    attention_mask = torch.tril(
-        torch.ones((att_mask_batch, seq_length, seq_length), device=data.device)
-    ).view(att_mask_batch, 1, seq_length, seq_length)
+    attention_mask = torch.tril(torch.ones((att_mask_batch, seq_length, seq_length), device=data.device)).view(
+        att_mask_batch, 1, seq_length, seq_length
+    )
     # Loss mask.
     loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)
@@ -66,7 +66,6 @@ def get_ltor_masks_and_position_ids(
     if reset_position_ids or reset_attention_mask:
         # Loop through the batches:
         for b in range(micro_batch_size):
             # Find indecies where EOD token is.
             eod_index = position_ids[b, data[b] == eod_token]
             # Detach indecies from positions if going to modify positions.
@ -105,13 +104,14 @@ def get_batch(context_tokens: torch.LongTensor, eod_id: int):
) )
return tokens, attention_mask, position_ids return tokens, attention_mask, position_ids
def make_context( def make_context(
tokenizer: PreTrainedTokenizer, tokenizer: PreTrainedTokenizer,
query: str, query: str,
query_assistant: str = "", query_assistant: str = "",
history: List[Tuple[str, str]] = None, history: List[Tuple[str, str]] = None,
system: str = "", system: str = "",
max_window_size: int = 6144 max_window_size: int = 6144,
): ):
if history is None: if history is None:
history = [] history = []
@@ -122,32 +122,26 @@ def make_context(
     nl_tokens = tokenizer.encode("\n")

     def _tokenize_str(role, content):
-        return f"{role}\n{content}", tokenizer.encode(
-            role, allowed_special=set()
-        ) + nl_tokens + tokenizer.encode(content, allowed_special=set())
+        return f"{role}\n{content}", tokenizer.encode(role, allowed_special=set()) + nl_tokens + tokenizer.encode(
+            content, allowed_special=set()
+        )

     system_text, system_tokens_part = _tokenize_str("system", system)
     system_tokens = im_start_tokens + system_tokens_part + im_end_tokens
     assistant_tokens = tokenizer.encode(query_assistant, allowed_special=set())
     raw_text = ""
     context_tokens = []

     for turn_query, turn_response in reversed(history):
         query_text, query_tokens_part = _tokenize_str("user", turn_query)
         query_tokens = im_start_tokens + query_tokens_part + im_end_tokens
-        response_text, response_tokens_part = _tokenize_str(
-            "assistant", turn_response
-        )
+        response_text, response_tokens_part = _tokenize_str("assistant", turn_response)
         response_tokens = im_start_tokens + response_tokens_part + im_end_tokens

         next_context_tokens = nl_tokens + query_tokens + nl_tokens + response_tokens
-        prev_chat = (
-            f"\n{im_start}{query_text}{im_end}\n{im_start}{response_text}{im_end}"
-        )
+        prev_chat = f"\n{im_start}{query_text}{im_end}\n{im_start}{response_text}{im_end}"

-        current_context_size = (
-            len(system_tokens) + len(next_context_tokens) + len(context_tokens)
-        )
+        current_context_size = len(system_tokens) + len(next_context_tokens) + len(context_tokens)
         if current_context_size < max_window_size:
             context_tokens = next_context_tokens + context_tokens
             raw_text = prev_chat + raw_text
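
A quick sketch of the prompt this assembles (an illustration, assuming the tail of make_context matches the upstream Qwen utilities): each turn is wrapped in im_start/im_end markers, giving a ChatML-style layout roughly like

    <|im_start|>system
    {system}<|im_end|>
    <|im_start|>user
    {query}<|im_end|>
    <|im_start|>assistant
    {query_assistant}

with context_tokens carrying the same content as token ids, and max_window_size bounding how much history is kept.
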
@@ -171,12 +165,13 @@ def make_context(
     return raw_text, context_tokens


 def decode_tokens(
     tokens: Union[torch.LongTensor, TokensType],
     tokenizer: PreTrainedTokenizer,
-    raw_text_len: int,
-    context_length: int,
-    errors: str="replace",
+    raw_text_len: int = 0,
+    context_length: int = 0,
+    errors: str = "replace",
 ) -> str:
     if torch.is_tensor(tokens):
         tokens = tokens.cpu().numpy().tolist()
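
With raw_text_len and context_length now defaulting to 0, the two length arguments are optional at call sites. A minimal sketch of the equivalence (tokens here is a hypothetical list of token ids):

    # the two calls behave identically under the new defaults
    decode_tokens(tokens, tokenizer)
    decode_tokens(tokens, tokenizer, raw_text_len=0, context_length=0, errors="replace")

This is what lets the new Dump_tokens_list helper in the second changed file decode single token ids without extra bookkeeping.
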
@@ -211,42 +206,26 @@ class StopWordsLogitsProcessor(LogitsProcessor):
     """

     def __init__(self, stop_words_ids: Iterable[Iterable[int]], eos_token_id: int):

         if not isinstance(stop_words_ids, List) or len(stop_words_ids) == 0:
-            raise ValueError(
-                f"`stop_words_ids` has to be a non-emtpy list, but is {stop_words_ids}."
-            )
+            raise ValueError(f"`stop_words_ids` has to be a non-emtpy list, but is {stop_words_ids}.")
         if any(not isinstance(bad_word_ids, list) for bad_word_ids in stop_words_ids):
-            raise ValueError(
-                f"`stop_words_ids` has to be a list of lists, but is {stop_words_ids}."
-            )
+            raise ValueError(f"`stop_words_ids` has to be a list of lists, but is {stop_words_ids}.")
         if any(
-            any(
-                (not isinstance(token_id, (int, np.integer)) or token_id < 0)
-                for token_id in stop_word_ids
-            )
+            any((not isinstance(token_id, (int, np.integer)) or token_id < 0) for token_id in stop_word_ids)
             for stop_word_ids in stop_words_ids
         ):
             raise ValueError(
                 f"Each list in `stop_words_ids` has to be a list of positive integers, but is {stop_words_ids}."
             )

-        self.stop_words_ids = list(
-            filter(
-                lambda bad_token_seq: bad_token_seq != [eos_token_id], stop_words_ids
-            )
-        )
+        self.stop_words_ids = list(filter(lambda bad_token_seq: bad_token_seq != [eos_token_id], stop_words_ids))
         self.eos_token_id = eos_token_id
         for stop_token_seq in self.stop_words_ids:
-            assert (
-                len(stop_token_seq) > 0
-            ), "Stop words token sequences {} cannot have an empty list".format(
+            assert len(stop_token_seq) > 0, "Stop words token sequences {} cannot have an empty list".format(
                 stop_words_ids
             )

-    def __call__(
-        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
-    ) -> torch.FloatTensor:
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         stopped_samples = self._calc_stopped_samples(input_ids)
         for i, should_stop in enumerate(stopped_samples):
             if should_stop:
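
For context, a processor like this is normally handed to generate() through a LogitsProcessorList. A hedged sketch using the standard transformers API (im_end_id and eos_id are placeholders, not values taken from this repo):

    from transformers import LogitsProcessorList

    stop_processor = StopWordsLogitsProcessor(stop_words_ids=[[im_end_id]], eos_token_id=eos_id)
    outputs = model.generate(input_ids, logits_processor=LogitsProcessorList([stop_processor]))
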

View File

@@ -10,6 +10,11 @@ from modeling_qwen import QwenRunner
 import torch.nn.functional as F

+from qwen_generation_utils import (
+    make_context,
+    decode_tokens,
+)
+
 sys.path.append("..")
 from tools import show
@@ -32,10 +37,39 @@ model = QWenLMHeadModel(config)
 print(model)
 tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
-model = model.from_pretrained(model_dir).cuda()
+model = model.from_pretrained(model_dir)
 model = model.eval()


+def Dump_tokens_list(model):
+    tokens = []
+    for token in range(config.eos_token_id):
+        decoded, response, end_reason = decode_tokens(
+            [token],
+            tokenizer,
+            raw_text_len=0,
+            context_length=0,
+            errors="replace",
+        )
+        tokens.append(str(token) + ": " + decoded)
+    show.DumpListToFile(tokens, "./temp/qwen_token_list.txt")
+
+
+# Dump_tokens_list(model)
+
+
+def Dump_lm_head_weight(model):
+    weight = model.lm_head.weight.cpu()  # [151936,2048,]
+    weight = weight.reshape(64, -1, 2048)
+    for i in range(64):
+        sub = weight[i].reshape(-1, 64, 32)
+        show.DumpTensorToImage(sub, "./temp/lm_head_" + str(i) + "_2374_2048.png")
+
+
+# Dump_lm_head_weight(model)
+
+
 class ResearchRunner(QwenRunner):
     def attention(self, attention, query, key, value, causal_mask):
         query = query.permute(0, 2, 1, 3)
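
Shape bookkeeping behind the new Dump_lm_head_weight (a quick check, not part of the diff): 151936 = 64 * 2374, so each loop iteration images one 2374 x 2048 slice of lm_head.weight, folded into 2374 tiles of 64 x 32 for DumpTensorToImage, hence the "_2374_2048" suffix in the file names.

    # weight: [151936, 2048] -> [64, 2374, 2048]; weight[i]: [2374, 2048] -> [2374, 64, 32]
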
@@ -50,8 +84,8 @@ class ResearchRunner(QwenRunner):
         attn_mask.masked_fill_(causal_mask.logical_not(), float(0))
         qk = attn_weight * attn_mask
         qk = qk[0]
-        prePath = "./temp/"
-        show.DumpTensorToImage(qk, prePath + "q@k_seq_" + str(size) + "_layer_" + str(attention.index) + ".png")
+        prePath = "./temp/" + "q@k_seq_" + str(size) + "_layer_" + str(attention.index) + ".png"
+        show.DumpTensorToImage(qk, prePath, GridValue=255)

         attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=causal_mask).transpose(1, 2)
         context_layer = attention._merge_heads(attn_output, attention.num_heads, attention.head_dim)
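
What gets dumped here, assuming attn_weight has the usual [batch, heads, seq, seq] layout: qk[0] is a stack of per-head attention maps, which DumpTensorToImage tiles into a near-square grid, and GridValue=255 now draws one-pixel white separators between the heads in each q@k_seq_*_layer_*.png.

    # qk[0]: [num_heads, seq, seq] -> ceil(sqrt(num_heads))^2 grid of seq x seq head maps,
    # separated by 1-pixel lines of GridValue (255 = white)
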
@@ -71,6 +105,6 @@ print(decode_tokens)
 # <|im_start|>assistant
 # 日本的首都东京。<|im_end|>
 # <|endoftext|>
-# 日本的首都是东京。<|im_end|>
 if decode_tokens.split("\n")[-2] != """日本的首都东京。<|im_end|>""":
     raise ()

View File

@@ -1,8 +1,16 @@
 import torch
 import torch.nn.functional as F

-x = torch.tensor([[1, 2], [3, 4]])
+x1 = torch.tensor([[1, 2]], dtype=float)
+x2 = torch.tensor([[5, 6], [7, 8]], dtype=float)
+y = x1 @ x2  # torch.matmul(x1 , x2)
+
+x_inverse = torch.inverse(x2, out=None)
+y_inverse = y @ x_inverse
+y_inverse = y_inverse.permute(1, 0)
+
+x = torch.tensor([[1, 2], [3, 4]], dtype=float)

 print(x)
 print("x.tile((2)) -> ", x.tile((2)).shape)
 print(x.tile((2)))
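
The new lines exercise a matrix-inverse round trip; a worked check (values not printed by the file itself):

    # x1 = [[1., 2.]], x2 = [[5., 6.], [7., 8.]]   (det(x2) = -2, so x2 is invertible)
    # y = x1 @ x2 = [[1*5 + 2*7, 1*6 + 2*8]] = [[19., 22.]]
    # y @ inverse(x2) = x1 @ x2 @ inverse(x2) = [[1., 2.]], and .permute(1, 0) transposes it to [[1.], [2.]]
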
@@ -55,4 +63,3 @@ z = torch.ones([1, 32, 6, 128])
 att = torch.matmul(x, y)
 mm = torch.matmul(att, z)
 print(mm.shape)

View File

@@ -8,23 +8,24 @@ import numpy as np
 import os


-def DumpTensorToImage(tensor, name, forceSquare=True, scale=1.0, AutoContrast=True):
+def DumpTensorToImage(tensor, name, forceSquare=True, scale=1.0, AutoContrast=True, GridValue=0):
     if len(tensor.shape) != 2 and len(tensor.shape) != 1 and len(tensor.shape) != 3:
         raise ("Error input dims")

     if len(tensor.shape) == 3:
         channel = tensor.shape[0]
         x = math.ceil((channel) ** 0.5)
-        tensor = F.pad(tensor, (0, 1, 0, 1, 0, x * x - channel), mode="constant", value=0)
+        tensor = F.pad(tensor, (0, 0, 0, 0, 0, x * x - channel), mode="constant", value=0)
+        calc = tensor.reshape((x * x, tensor.shape[1] * tensor.shape[2]))
         if AutoContrast:
-            calc = tensor.reshape((x * x, tensor.shape[1] * tensor.shape[2]))
             tensormax = calc.max(1)[0]
             tensormin = calc.min(1)[0]
             calc = calc.transpose(1, 0)
             calc = ((calc - tensormin) / (tensormax - tensormin)) * 255
             calc = calc.transpose(1, 0)
-        tensor = calc.reshape((x, x, tensor.shape[1], tensor.shape[2]))
-        tensor = tensor.permute((0, 2, 1, 3))
+        calc = calc.reshape((x, x, tensor.shape[1], tensor.shape[2]))
+        calc = F.pad(calc, (0, 1, 0, 1, 0, 0), mode="constant", value=GridValue)
+        tensor = calc.permute((0, 2, 1, 3))
         tensor = tensor.reshape((x * tensor.shape[1], x * tensor.shape[3]))
         DumpTensorToImage(tensor, name, forceSquare=False, scale=scale, AutoContrast=False)
         return
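
The reworked padding separates two jobs: the first F.pad now only pads the channel count up to the square x * x, while the new F.pad(calc, (0, 1, 0, 1, 0, 0), value=GridValue) appends one column and one row to every tile. F.pad fills trailing dimensions first, so on the [x, x, H, W] tensor each tile grows from H x W to (H + 1) x (W + 1); after the permute and reshape those extra pixels form a one-pixel grid of GridValue between tiles (0 = black, 255 = white).

    # (0, 1, 0, 1, 0, 0) == (W_left, W_right, H_top, H_bottom, dim1_front, dim1_back):
    # each [H, W] tile becomes [H + 1, W + 1], with the new row/column filled with GridValue
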

View File

@@ -12,7 +12,9 @@ show.DumpTensorToImage(radata, "test.png")
 radata = torch.randn(3, 127, 127)
-show.DumpTensorToImage(radata, "test.png")
+show.DumpTensorToImage(radata, "test1.png", AutoContrast=True, GridValue=0)
+show.DumpTensorToImage(radata, "test2.png", AutoContrast=True, GridValue=255)
+show.DumpTensorToImage(radata, "test3.png", AutoContrast=False, GridValue=0)

 radata = torch.randn(127, 127)