Refine research_attention.

This commit is contained in:
Colin 2024-01-22 20:57:27 +08:00
parent 5dbac40925
commit 1811b9611a
5 changed files with 78 additions and 55 deletions

View File

@ -47,9 +47,9 @@ def get_ltor_masks_and_position_ids(
att_mask_batch = micro_batch_size
else:
att_mask_batch = 1
attention_mask = torch.tril(
torch.ones((att_mask_batch, seq_length, seq_length), device=data.device)
).view(att_mask_batch, 1, seq_length, seq_length)
attention_mask = torch.tril(torch.ones((att_mask_batch, seq_length, seq_length), device=data.device)).view(
att_mask_batch, 1, seq_length, seq_length
)
# Loss mask.
loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)
@ -66,7 +66,6 @@ def get_ltor_masks_and_position_ids(
if reset_position_ids or reset_attention_mask:
# Loop through the batches:
for b in range(micro_batch_size):
# Find indecies where EOD token is.
eod_index = position_ids[b, data[b] == eod_token]
# Detach indecies from positions if going to modify positions.
@ -105,13 +104,14 @@ def get_batch(context_tokens: torch.LongTensor, eod_id: int):
)
return tokens, attention_mask, position_ids
def make_context(
tokenizer: PreTrainedTokenizer,
query: str,
query_assistant: str = "",
history: List[Tuple[str, str]] = None,
system: str = "",
max_window_size: int = 6144
max_window_size: int = 6144,
):
if history is None:
history = []
@ -122,32 +122,26 @@ def make_context(
nl_tokens = tokenizer.encode("\n")
def _tokenize_str(role, content):
return f"{role}\n{content}", tokenizer.encode(
role, allowed_special=set()
) + nl_tokens + tokenizer.encode(content, allowed_special=set())
return f"{role}\n{content}", tokenizer.encode(role, allowed_special=set()) + nl_tokens + tokenizer.encode(
content, allowed_special=set()
)
system_text, system_tokens_part = _tokenize_str("system", system)
system_tokens = im_start_tokens + system_tokens_part + im_end_tokens
assistant_tokens = tokenizer.encode(query_assistant, allowed_special=set())
assistant_tokens = tokenizer.encode(query_assistant, allowed_special=set())
raw_text = ""
context_tokens = []
for turn_query, turn_response in reversed(history):
query_text, query_tokens_part = _tokenize_str("user", turn_query)
query_tokens = im_start_tokens + query_tokens_part + im_end_tokens
response_text, response_tokens_part = _tokenize_str(
"assistant", turn_response
)
response_text, response_tokens_part = _tokenize_str("assistant", turn_response)
response_tokens = im_start_tokens + response_tokens_part + im_end_tokens
next_context_tokens = nl_tokens + query_tokens + nl_tokens + response_tokens
prev_chat = (
f"\n{im_start}{query_text}{im_end}\n{im_start}{response_text}{im_end}"
)
prev_chat = f"\n{im_start}{query_text}{im_end}\n{im_start}{response_text}{im_end}"
current_context_size = (
len(system_tokens) + len(next_context_tokens) + len(context_tokens)
)
current_context_size = len(system_tokens) + len(next_context_tokens) + len(context_tokens)
if current_context_size < max_window_size:
context_tokens = next_context_tokens + context_tokens
raw_text = prev_chat + raw_text
@ -171,12 +165,13 @@ def make_context(
return raw_text, context_tokens
def decode_tokens(
tokens: Union[torch.LongTensor, TokensType],
tokenizer: PreTrainedTokenizer,
raw_text_len: int,
context_length: int,
errors: str="replace",
raw_text_len: int = 0,
context_length: int = 0,
errors: str = "replace",
) -> str:
if torch.is_tensor(tokens):
tokens = tokens.cpu().numpy().tolist()
@ -211,42 +206,26 @@ class StopWordsLogitsProcessor(LogitsProcessor):
"""
def __init__(self, stop_words_ids: Iterable[Iterable[int]], eos_token_id: int):
if not isinstance(stop_words_ids, List) or len(stop_words_ids) == 0:
raise ValueError(
f"`stop_words_ids` has to be a non-emtpy list, but is {stop_words_ids}."
)
raise ValueError(f"`stop_words_ids` has to be a non-emtpy list, but is {stop_words_ids}.")
if any(not isinstance(bad_word_ids, list) for bad_word_ids in stop_words_ids):
raise ValueError(
f"`stop_words_ids` has to be a list of lists, but is {stop_words_ids}."
)
raise ValueError(f"`stop_words_ids` has to be a list of lists, but is {stop_words_ids}.")
if any(
any(
(not isinstance(token_id, (int, np.integer)) or token_id < 0)
for token_id in stop_word_ids
)
any((not isinstance(token_id, (int, np.integer)) or token_id < 0) for token_id in stop_word_ids)
for stop_word_ids in stop_words_ids
):
raise ValueError(
f"Each list in `stop_words_ids` has to be a list of positive integers, but is {stop_words_ids}."
)
self.stop_words_ids = list(
filter(
lambda bad_token_seq: bad_token_seq != [eos_token_id], stop_words_ids
)
)
self.stop_words_ids = list(filter(lambda bad_token_seq: bad_token_seq != [eos_token_id], stop_words_ids))
self.eos_token_id = eos_token_id
for stop_token_seq in self.stop_words_ids:
assert (
len(stop_token_seq) > 0
), "Stop words token sequences {} cannot have an empty list".format(
assert len(stop_token_seq) > 0, "Stop words token sequences {} cannot have an empty list".format(
stop_words_ids
)
def __call__(
self, input_ids: torch.LongTensor, scores: torch.FloatTensor
) -> torch.FloatTensor:
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
stopped_samples = self._calc_stopped_samples(input_ids)
for i, should_stop in enumerate(stopped_samples):
if should_stop:

View File

@ -10,6 +10,11 @@ from modeling_qwen import QwenRunner
import torch.nn.functional as F
from qwen_generation_utils import (
make_context,
decode_tokens,
)
sys.path.append("..")
from tools import show
@ -32,10 +37,39 @@ model = QWenLMHeadModel(config)
print(model)
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
model = model.from_pretrained(model_dir).cuda()
model = model.from_pretrained(model_dir)
model = model.eval()
def Dump_tokens_list(model):
tokens = []
for token in range(config.eos_token_id):
decoded, response, end_reason = decode_tokens(
[token],
tokenizer,
raw_text_len=0,
context_length=0,
errors="replace",
)
tokens.append(str(token) + ": " + decoded)
show.DumpListToFile(tokens, "./temp/qwen_token_list.txt")
# Dump_tokens_list(model)
def Dump_lm_head_weight(model):
weight = model.lm_head.weight.cpu() # [151936,2048,]
weight = weight.reshape(64, -1, 2048)
for i in range(64):
sub = weight[i].reshape(-1, 64, 32)
show.DumpTensorToImage(sub, "./temp/lm_head_" + str(i) + "_2374_2048.png")
# Dump_lm_head_weight(model)
class ResearchRunner(QwenRunner):
def attention(self, attention, query, key, value, causal_mask):
query = query.permute(0, 2, 1, 3)
@ -50,8 +84,8 @@ class ResearchRunner(QwenRunner):
attn_mask.masked_fill_(causal_mask.logical_not(), float(0))
qk = attn_weight * attn_mask
qk = qk[0]
prePath = "./temp/"
show.DumpTensorToImage(qk, prePath + "q@k_seq_" + str(size) + "_layer_" + str(attention.index) + ".png")
prePath = "./temp/" + "q@k_seq_" + str(size) + "_layer_" + str(attention.index) + ".png"
show.DumpTensorToImage(qk, prePath, GridValue=255)
attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=causal_mask).transpose(1, 2)
context_layer = attention._merge_heads(attn_output, attention.num_heads, attention.head_dim)
@ -71,6 +105,6 @@ print(decode_tokens)
# <|im_start|>assistant
# 日本的首都东京。<|im_end|>
# <|endoftext|>
# 日本的首都是东京。<|im_end|>
if decode_tokens.split("\n")[-2] != """日本的首都东京。<|im_end|>""":
raise ()

View File

@ -1,8 +1,16 @@
import torch
import torch.nn.functional as F
x = torch.tensor([[1, 2], [3, 4]])
x1 = torch.tensor([[1, 2]], dtype=float)
x2 = torch.tensor([[5, 6], [7, 8]], dtype=float)
y = x1 @ x2 # torch.matmul(x1 , x2)
x_inverse = torch.inverse(x2, out=None)
y_inverse = y @ x_inverse
y_inverse = y_inverse.permute(1, 0)
x = torch.tensor([[1, 2], [3, 4]], dtype=float)
print(x)
print("x.tile((2)) -> ", x.tile((2)).shape)
print(x.tile((2)))
@ -55,4 +63,3 @@ z = torch.ones([1, 32, 6, 128])
att = torch.matmul(x, y)
mm = torch.matmul(att, z)
print(mm.shape)

View File

@ -8,23 +8,24 @@ import numpy as np
import os
def DumpTensorToImage(tensor, name, forceSquare=True, scale=1.0, AutoContrast=True):
def DumpTensorToImage(tensor, name, forceSquare=True, scale=1.0, AutoContrast=True, GridValue=0):
if len(tensor.shape) != 2 and len(tensor.shape) != 1 and len(tensor.shape) != 3:
raise ("Error input dims")
if len(tensor.shape) == 3:
channel = tensor.shape[0]
x = math.ceil((channel) ** 0.5)
tensor = F.pad(tensor, (0, 1, 0, 1, 0, x * x - channel), mode="constant", value=0)
tensor = F.pad(tensor, (0, 0, 0, 0, 0, x * x - channel), mode="constant", value=0)
calc = tensor.reshape((x * x, tensor.shape[1] * tensor.shape[2]))
if AutoContrast:
calc = tensor.reshape((x * x, tensor.shape[1] * tensor.shape[2]))
tensormax = calc.max(1)[0]
tensormin = calc.min(1)[0]
calc = calc.transpose(1, 0)
calc = ((calc - tensormin) / (tensormax - tensormin)) * 255
calc = calc.transpose(1, 0)
tensor = calc.reshape((x, x, tensor.shape[1], tensor.shape[2]))
tensor = tensor.permute((0, 2, 1, 3))
calc = calc.reshape((x, x, tensor.shape[1], tensor.shape[2]))
calc = F.pad(calc, (0, 1, 0, 1, 0, 0), mode="constant", value=GridValue)
tensor = calc.permute((0, 2, 1, 3))
tensor = tensor.reshape((x * tensor.shape[1], x * tensor.shape[3]))
DumpTensorToImage(tensor, name, forceSquare=False, scale=scale, AutoContrast=False)
return

View File

@ -12,7 +12,9 @@ show.DumpTensorToImage(radata, "test.png")
radata = torch.randn(3, 127, 127)
show.DumpTensorToImage(radata, "test.png")
show.DumpTensorToImage(radata, "test1.png", AutoContrast=True, GridValue=0)
show.DumpTensorToImage(radata, "test2.png", AutoContrast=True, GridValue=255)
show.DumpTensorToImage(radata, "test3.png", AutoContrast=False, GridValue=0)
radata = torch.randn(127, 127)