Refine research_attention.
This commit is contained in:
parent
5dbac40925
commit
1811b9611a
|
@ -47,9 +47,9 @@ def get_ltor_masks_and_position_ids(
|
|||
att_mask_batch = micro_batch_size
|
||||
else:
|
||||
att_mask_batch = 1
|
||||
attention_mask = torch.tril(
|
||||
torch.ones((att_mask_batch, seq_length, seq_length), device=data.device)
|
||||
).view(att_mask_batch, 1, seq_length, seq_length)
|
||||
attention_mask = torch.tril(torch.ones((att_mask_batch, seq_length, seq_length), device=data.device)).view(
|
||||
att_mask_batch, 1, seq_length, seq_length
|
||||
)
|
||||
|
||||
# Loss mask.
|
||||
loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)
|
||||
|
@ -66,7 +66,6 @@ def get_ltor_masks_and_position_ids(
|
|||
if reset_position_ids or reset_attention_mask:
|
||||
# Loop through the batches:
|
||||
for b in range(micro_batch_size):
|
||||
|
||||
# Find indecies where EOD token is.
|
||||
eod_index = position_ids[b, data[b] == eod_token]
|
||||
# Detach indecies from positions if going to modify positions.
|
||||
|
@ -105,13 +104,14 @@ def get_batch(context_tokens: torch.LongTensor, eod_id: int):
|
|||
)
|
||||
return tokens, attention_mask, position_ids
|
||||
|
||||
|
||||
def make_context(
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
query: str,
|
||||
query_assistant: str = "",
|
||||
history: List[Tuple[str, str]] = None,
|
||||
system: str = "",
|
||||
max_window_size: int = 6144
|
||||
max_window_size: int = 6144,
|
||||
):
|
||||
if history is None:
|
||||
history = []
|
||||
|
@ -122,32 +122,26 @@ def make_context(
|
|||
nl_tokens = tokenizer.encode("\n")
|
||||
|
||||
def _tokenize_str(role, content):
|
||||
return f"{role}\n{content}", tokenizer.encode(
|
||||
role, allowed_special=set()
|
||||
) + nl_tokens + tokenizer.encode(content, allowed_special=set())
|
||||
return f"{role}\n{content}", tokenizer.encode(role, allowed_special=set()) + nl_tokens + tokenizer.encode(
|
||||
content, allowed_special=set()
|
||||
)
|
||||
|
||||
system_text, system_tokens_part = _tokenize_str("system", system)
|
||||
system_tokens = im_start_tokens + system_tokens_part + im_end_tokens
|
||||
assistant_tokens = tokenizer.encode(query_assistant, allowed_special=set())
|
||||
assistant_tokens = tokenizer.encode(query_assistant, allowed_special=set())
|
||||
raw_text = ""
|
||||
context_tokens = []
|
||||
|
||||
for turn_query, turn_response in reversed(history):
|
||||
query_text, query_tokens_part = _tokenize_str("user", turn_query)
|
||||
query_tokens = im_start_tokens + query_tokens_part + im_end_tokens
|
||||
response_text, response_tokens_part = _tokenize_str(
|
||||
"assistant", turn_response
|
||||
)
|
||||
response_text, response_tokens_part = _tokenize_str("assistant", turn_response)
|
||||
response_tokens = im_start_tokens + response_tokens_part + im_end_tokens
|
||||
|
||||
next_context_tokens = nl_tokens + query_tokens + nl_tokens + response_tokens
|
||||
prev_chat = (
|
||||
f"\n{im_start}{query_text}{im_end}\n{im_start}{response_text}{im_end}"
|
||||
)
|
||||
prev_chat = f"\n{im_start}{query_text}{im_end}\n{im_start}{response_text}{im_end}"
|
||||
|
||||
current_context_size = (
|
||||
len(system_tokens) + len(next_context_tokens) + len(context_tokens)
|
||||
)
|
||||
current_context_size = len(system_tokens) + len(next_context_tokens) + len(context_tokens)
|
||||
if current_context_size < max_window_size:
|
||||
context_tokens = next_context_tokens + context_tokens
|
||||
raw_text = prev_chat + raw_text
|
||||
|
@ -171,12 +165,13 @@ def make_context(
|
|||
|
||||
return raw_text, context_tokens
|
||||
|
||||
|
||||
def decode_tokens(
|
||||
tokens: Union[torch.LongTensor, TokensType],
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
raw_text_len: int,
|
||||
context_length: int,
|
||||
errors: str="replace",
|
||||
raw_text_len: int = 0,
|
||||
context_length: int = 0,
|
||||
errors: str = "replace",
|
||||
) -> str:
|
||||
if torch.is_tensor(tokens):
|
||||
tokens = tokens.cpu().numpy().tolist()
|
||||
|
@ -211,42 +206,26 @@ class StopWordsLogitsProcessor(LogitsProcessor):
|
|||
"""
|
||||
|
||||
def __init__(self, stop_words_ids: Iterable[Iterable[int]], eos_token_id: int):
|
||||
|
||||
if not isinstance(stop_words_ids, List) or len(stop_words_ids) == 0:
|
||||
raise ValueError(
|
||||
f"`stop_words_ids` has to be a non-emtpy list, but is {stop_words_ids}."
|
||||
)
|
||||
raise ValueError(f"`stop_words_ids` has to be a non-emtpy list, but is {stop_words_ids}.")
|
||||
if any(not isinstance(bad_word_ids, list) for bad_word_ids in stop_words_ids):
|
||||
raise ValueError(
|
||||
f"`stop_words_ids` has to be a list of lists, but is {stop_words_ids}."
|
||||
)
|
||||
raise ValueError(f"`stop_words_ids` has to be a list of lists, but is {stop_words_ids}.")
|
||||
if any(
|
||||
any(
|
||||
(not isinstance(token_id, (int, np.integer)) or token_id < 0)
|
||||
for token_id in stop_word_ids
|
||||
)
|
||||
any((not isinstance(token_id, (int, np.integer)) or token_id < 0) for token_id in stop_word_ids)
|
||||
for stop_word_ids in stop_words_ids
|
||||
):
|
||||
raise ValueError(
|
||||
f"Each list in `stop_words_ids` has to be a list of positive integers, but is {stop_words_ids}."
|
||||
)
|
||||
|
||||
self.stop_words_ids = list(
|
||||
filter(
|
||||
lambda bad_token_seq: bad_token_seq != [eos_token_id], stop_words_ids
|
||||
)
|
||||
)
|
||||
self.stop_words_ids = list(filter(lambda bad_token_seq: bad_token_seq != [eos_token_id], stop_words_ids))
|
||||
self.eos_token_id = eos_token_id
|
||||
for stop_token_seq in self.stop_words_ids:
|
||||
assert (
|
||||
len(stop_token_seq) > 0
|
||||
), "Stop words token sequences {} cannot have an empty list".format(
|
||||
assert len(stop_token_seq) > 0, "Stop words token sequences {} cannot have an empty list".format(
|
||||
stop_words_ids
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self, input_ids: torch.LongTensor, scores: torch.FloatTensor
|
||||
) -> torch.FloatTensor:
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||
stopped_samples = self._calc_stopped_samples(input_ids)
|
||||
for i, should_stop in enumerate(stopped_samples):
|
||||
if should_stop:
|
||||
|
|
|
@ -10,6 +10,11 @@ from modeling_qwen import QwenRunner
|
|||
|
||||
import torch.nn.functional as F
|
||||
|
||||
from qwen_generation_utils import (
|
||||
make_context,
|
||||
decode_tokens,
|
||||
)
|
||||
|
||||
sys.path.append("..")
|
||||
from tools import show
|
||||
|
||||
|
@ -32,10 +37,39 @@ model = QWenLMHeadModel(config)
|
|||
print(model)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
|
||||
model = model.from_pretrained(model_dir).cuda()
|
||||
|
||||
model = model.from_pretrained(model_dir)
|
||||
model = model.eval()
|
||||
|
||||
|
||||
def Dump_tokens_list(model):
|
||||
tokens = []
|
||||
for token in range(config.eos_token_id):
|
||||
decoded, response, end_reason = decode_tokens(
|
||||
[token],
|
||||
tokenizer,
|
||||
raw_text_len=0,
|
||||
context_length=0,
|
||||
errors="replace",
|
||||
)
|
||||
tokens.append(str(token) + ": " + decoded)
|
||||
show.DumpListToFile(tokens, "./temp/qwen_token_list.txt")
|
||||
|
||||
|
||||
# Dump_tokens_list(model)
|
||||
|
||||
|
||||
def Dump_lm_head_weight(model):
|
||||
weight = model.lm_head.weight.cpu() # [151936,2048,]
|
||||
weight = weight.reshape(64, -1, 2048)
|
||||
for i in range(64):
|
||||
sub = weight[i].reshape(-1, 64, 32)
|
||||
show.DumpTensorToImage(sub, "./temp/lm_head_" + str(i) + "_2374_2048.png")
|
||||
|
||||
|
||||
# Dump_lm_head_weight(model)
|
||||
|
||||
|
||||
class ResearchRunner(QwenRunner):
|
||||
def attention(self, attention, query, key, value, causal_mask):
|
||||
query = query.permute(0, 2, 1, 3)
|
||||
|
@ -50,8 +84,8 @@ class ResearchRunner(QwenRunner):
|
|||
attn_mask.masked_fill_(causal_mask.logical_not(), float(0))
|
||||
qk = attn_weight * attn_mask
|
||||
qk = qk[0]
|
||||
prePath = "./temp/"
|
||||
show.DumpTensorToImage(qk, prePath + "q@k_seq_" + str(size) + "_layer_" + str(attention.index) + ".png")
|
||||
prePath = "./temp/" + "q@k_seq_" + str(size) + "_layer_" + str(attention.index) + ".png"
|
||||
show.DumpTensorToImage(qk, prePath, GridValue=255)
|
||||
|
||||
attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=causal_mask).transpose(1, 2)
|
||||
context_layer = attention._merge_heads(attn_output, attention.num_heads, attention.head_dim)
|
||||
|
@ -71,6 +105,6 @@ print(decode_tokens)
|
|||
# <|im_start|>assistant
|
||||
# 日本的首都东京。<|im_end|>
|
||||
# <|endoftext|>
|
||||
# 日本的首都是东京。<|im_end|>
|
||||
|
||||
if decode_tokens.split("\n")[-2] != """日本的首都东京。<|im_end|>""":
|
||||
raise ()
|
||||
|
|
|
@ -1,8 +1,16 @@
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
x = torch.tensor([[1, 2], [3, 4]])
|
||||
x1 = torch.tensor([[1, 2]], dtype=float)
|
||||
x2 = torch.tensor([[5, 6], [7, 8]], dtype=float)
|
||||
|
||||
y = x1 @ x2 # torch.matmul(x1 , x2)
|
||||
x_inverse = torch.inverse(x2, out=None)
|
||||
y_inverse = y @ x_inverse
|
||||
y_inverse = y_inverse.permute(1, 0)
|
||||
|
||||
|
||||
x = torch.tensor([[1, 2], [3, 4]], dtype=float)
|
||||
print(x)
|
||||
print("x.tile((2)) -> ", x.tile((2)).shape)
|
||||
print(x.tile((2)))
|
||||
|
@ -55,4 +63,3 @@ z = torch.ones([1, 32, 6, 128])
|
|||
att = torch.matmul(x, y)
|
||||
mm = torch.matmul(att, z)
|
||||
print(mm.shape)
|
||||
|
||||
|
|
|
@ -8,23 +8,24 @@ import numpy as np
|
|||
import os
|
||||
|
||||
|
||||
def DumpTensorToImage(tensor, name, forceSquare=True, scale=1.0, AutoContrast=True):
|
||||
def DumpTensorToImage(tensor, name, forceSquare=True, scale=1.0, AutoContrast=True, GridValue=0):
|
||||
if len(tensor.shape) != 2 and len(tensor.shape) != 1 and len(tensor.shape) != 3:
|
||||
raise ("Error input dims")
|
||||
|
||||
if len(tensor.shape) == 3:
|
||||
channel = tensor.shape[0]
|
||||
x = math.ceil((channel) ** 0.5)
|
||||
tensor = F.pad(tensor, (0, 1, 0, 1, 0, x * x - channel), mode="constant", value=0)
|
||||
tensor = F.pad(tensor, (0, 0, 0, 0, 0, x * x - channel), mode="constant", value=0)
|
||||
calc = tensor.reshape((x * x, tensor.shape[1] * tensor.shape[2]))
|
||||
if AutoContrast:
|
||||
calc = tensor.reshape((x * x, tensor.shape[1] * tensor.shape[2]))
|
||||
tensormax = calc.max(1)[0]
|
||||
tensormin = calc.min(1)[0]
|
||||
calc = calc.transpose(1, 0)
|
||||
calc = ((calc - tensormin) / (tensormax - tensormin)) * 255
|
||||
calc = calc.transpose(1, 0)
|
||||
tensor = calc.reshape((x, x, tensor.shape[1], tensor.shape[2]))
|
||||
tensor = tensor.permute((0, 2, 1, 3))
|
||||
calc = calc.reshape((x, x, tensor.shape[1], tensor.shape[2]))
|
||||
calc = F.pad(calc, (0, 1, 0, 1, 0, 0), mode="constant", value=GridValue)
|
||||
tensor = calc.permute((0, 2, 1, 3))
|
||||
tensor = tensor.reshape((x * tensor.shape[1], x * tensor.shape[3]))
|
||||
DumpTensorToImage(tensor, name, forceSquare=False, scale=scale, AutoContrast=False)
|
||||
return
|
||||
|
|
|
@ -12,7 +12,9 @@ show.DumpTensorToImage(radata, "test.png")
|
|||
|
||||
|
||||
radata = torch.randn(3, 127, 127)
|
||||
show.DumpTensorToImage(radata, "test.png")
|
||||
show.DumpTensorToImage(radata, "test1.png", AutoContrast=True, GridValue=0)
|
||||
show.DumpTensorToImage(radata, "test2.png", AutoContrast=True, GridValue=255)
|
||||
show.DumpTensorToImage(radata, "test3.png", AutoContrast=False, GridValue=0)
|
||||
|
||||
|
||||
radata = torch.randn(127, 127)
|
||||
|
|
Loading…
Reference in New Issue