Refine research_attention.

parent 5dbac40925
commit 1811b9611a
qwen_generation_utils.py
@@ -47,9 +47,9 @@ def get_ltor_masks_and_position_ids(
         att_mask_batch = micro_batch_size
     else:
         att_mask_batch = 1
-    attention_mask = torch.tril(
-        torch.ones((att_mask_batch, seq_length, seq_length), device=data.device)
-    ).view(att_mask_batch, 1, seq_length, seq_length)
+    attention_mask = torch.tril(torch.ones((att_mask_batch, seq_length, seq_length), device=data.device)).view(
+        att_mask_batch, 1, seq_length, seq_length
+    )
 
     # Loss mask.
     loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)
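
For reference, a minimal sketch (not part of the commit) of what the reflowed expression builds: a lower-triangular causal mask of shape [att_mask_batch, 1, seq_length, seq_length], in which position i may attend only to positions <= i. The shapes below are toy values; the names mirror the hunk above.

    import torch

    micro_batch_size, seq_length = 2, 4
    data = torch.zeros(micro_batch_size, seq_length, dtype=torch.long)
    att_mask_batch = 1  # the `else:` branch above

    attention_mask = torch.tril(torch.ones((att_mask_batch, seq_length, seq_length), device=data.device)).view(
        att_mask_batch, 1, seq_length, seq_length
    )
    print(attention_mask[0, 0])  # row i holds ones through column i, zeros after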
@@ -66,7 +66,6 @@ def get_ltor_masks_and_position_ids(
     if reset_position_ids or reset_attention_mask:
         # Loop through the batches:
         for b in range(micro_batch_size):
-
             # Find indecies where EOD token is.
             eod_index = position_ids[b, data[b] == eod_token]
             # Detach indecies from positions if going to modify positions.
@@ -105,13 +104,14 @@ def get_batch(context_tokens: torch.LongTensor, eod_id: int):
     )
     return tokens, attention_mask, position_ids
 
+
 def make_context(
     tokenizer: PreTrainedTokenizer,
     query: str,
     query_assistant: str = "",
     history: List[Tuple[str, str]] = None,
     system: str = "",
-    max_window_size: int = 6144
+    max_window_size: int = 6144,
 ):
     if history is None:
         history = []
@@ -122,9 +122,9 @@ def make_context(
     nl_tokens = tokenizer.encode("\n")
 
     def _tokenize_str(role, content):
-        return f"{role}\n{content}", tokenizer.encode(
-            role, allowed_special=set()
-        ) + nl_tokens + tokenizer.encode(content, allowed_special=set())
+        return f"{role}\n{content}", tokenizer.encode(role, allowed_special=set()) + nl_tokens + tokenizer.encode(
+            content, allowed_special=set()
+        )
 
     system_text, system_tokens_part = _tokenize_str("system", system)
     system_tokens = im_start_tokens + system_tokens_part + im_end_tokens
@@ -135,19 +135,13 @@ def make_context(
     for turn_query, turn_response in reversed(history):
         query_text, query_tokens_part = _tokenize_str("user", turn_query)
         query_tokens = im_start_tokens + query_tokens_part + im_end_tokens
-        response_text, response_tokens_part = _tokenize_str(
-            "assistant", turn_response
-        )
+        response_text, response_tokens_part = _tokenize_str("assistant", turn_response)
         response_tokens = im_start_tokens + response_tokens_part + im_end_tokens
 
         next_context_tokens = nl_tokens + query_tokens + nl_tokens + response_tokens
-        prev_chat = (
-            f"\n{im_start}{query_text}{im_end}\n{im_start}{response_text}{im_end}"
-        )
+        prev_chat = f"\n{im_start}{query_text}{im_end}\n{im_start}{response_text}{im_end}"
 
-        current_context_size = (
-            len(system_tokens) + len(next_context_tokens) + len(context_tokens)
-        )
+        current_context_size = len(system_tokens) + len(next_context_tokens) + len(context_tokens)
         if current_context_size < max_window_size:
            context_tokens = next_context_tokens + context_tokens
            raw_text = prev_chat + raw_text
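
A reading aid, not part of the commit: make_context assembles a ChatML-style prompt. Each role/content pair is tokenized by _tokenize_str, wrapped in im_start/im_end marker tokens, and history turns are prepended newest-first until another turn would push the context past max_window_size. Assuming Qwen's usual "<|im_start|>"/"<|im_end|>" markers (an assumption; the hunk only shows the variables), one history turn contributes text shaped like this:

    # Hypothetical values; im_start/im_end are assumed ChatML markers.
    im_start, im_end = "<|im_start|>", "<|im_end|>"
    query_text = "user\n日本的首都是哪里？"        # "What is the capital of Japan?"
    response_text = "assistant\n日本的首都东京。"  # "Japan's capital, Tokyo."
    prev_chat = f"\n{im_start}{query_text}{im_end}\n{im_start}{response_text}{im_end}"
    print(prev_chat)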
@@ -171,11 +165,12 @@ def make_context(
 
     return raw_text, context_tokens
 
+
 def decode_tokens(
     tokens: Union[torch.LongTensor, TokensType],
     tokenizer: PreTrainedTokenizer,
-    raw_text_len: int,
-    context_length: int,
+    raw_text_len: int = 0,
+    context_length: int = 0,
     errors: str = "replace",
 ) -> str:
     if torch.is_tensor(tokens):
@@ -211,42 +206,26 @@ class StopWordsLogitsProcessor(LogitsProcessor):
     """
 
     def __init__(self, stop_words_ids: Iterable[Iterable[int]], eos_token_id: int):
 
         if not isinstance(stop_words_ids, List) or len(stop_words_ids) == 0:
-            raise ValueError(
-                f"`stop_words_ids` has to be a non-emtpy list, but is {stop_words_ids}."
-            )
+            raise ValueError(f"`stop_words_ids` has to be a non-emtpy list, but is {stop_words_ids}.")
         if any(not isinstance(bad_word_ids, list) for bad_word_ids in stop_words_ids):
-            raise ValueError(
-                f"`stop_words_ids` has to be a list of lists, but is {stop_words_ids}."
-            )
+            raise ValueError(f"`stop_words_ids` has to be a list of lists, but is {stop_words_ids}.")
         if any(
-            any(
-                (not isinstance(token_id, (int, np.integer)) or token_id < 0)
-                for token_id in stop_word_ids
-            )
+            any((not isinstance(token_id, (int, np.integer)) or token_id < 0) for token_id in stop_word_ids)
             for stop_word_ids in stop_words_ids
         ):
             raise ValueError(
                 f"Each list in `stop_words_ids` has to be a list of positive integers, but is {stop_words_ids}."
             )
 
-        self.stop_words_ids = list(
-            filter(
-                lambda bad_token_seq: bad_token_seq != [eos_token_id], stop_words_ids
-            )
-        )
+        self.stop_words_ids = list(filter(lambda bad_token_seq: bad_token_seq != [eos_token_id], stop_words_ids))
         self.eos_token_id = eos_token_id
         for stop_token_seq in self.stop_words_ids:
-            assert (
-                len(stop_token_seq) > 0
-            ), "Stop words token sequences {} cannot have an empty list".format(
-                stop_words_ids
-            )
+            assert len(stop_token_seq) > 0, "Stop words token sequences {} cannot have an empty list".format(
+                stop_words_ids
+            )
 
-    def __call__(
-        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
-    ) -> torch.FloatTensor:
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         stopped_samples = self._calc_stopped_samples(input_ids)
         for i, should_stop in enumerate(stopped_samples):
             if should_stop:
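
A hedged usage sketch for StopWordsLogitsProcessor, whose validation the hunks above reflow. The token IDs are made-up placeholders; LogitsProcessorList is the standard transformers container:

    from transformers import LogitsProcessorList

    stop_words_ids = [[151645]]  # hypothetical encoded stop word
    processor = StopWordsLogitsProcessor(stop_words_ids=stop_words_ids, eos_token_id=151643)
    # Passed as model.generate(..., logits_processor=LogitsProcessorList([processor])),
    # it forces EOS once a generated suffix matches any stop-word sequence.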

research_attention.py
@@ -10,6 +10,11 @@ from modeling_qwen import QwenRunner
 
 import torch.nn.functional as F
 
+from qwen_generation_utils import (
+    make_context,
+    decode_tokens,
+)
+
 sys.path.append("..")
 from tools import show
 
@@ -32,10 +37,39 @@ model = QWenLMHeadModel(config)
 print(model)
 
 tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
-model = model.from_pretrained(model_dir).cuda()
+model = model.from_pretrained(model_dir)
 model = model.eval()
 
 
+
+def Dump_tokens_list(model):
+    tokens = []
+    for token in range(config.eos_token_id):
+        decoded, response, end_reason = decode_tokens(
+            [token],
+            tokenizer,
+            raw_text_len=0,
+            context_length=0,
+            errors="replace",
+        )
+        tokens.append(str(token) + ": " + decoded)
+    show.DumpListToFile(tokens, "./temp/qwen_token_list.txt")
+
+
+# Dump_tokens_list(model)
+
+
+def Dump_lm_head_weight(model):
+    weight = model.lm_head.weight.cpu()  # [151936,2048,]
+    weight = weight.reshape(64, -1, 2048)
+    for i in range(64):
+        sub = weight[i].reshape(-1, 64, 32)
+        show.DumpTensorToImage(sub, "./temp/lm_head_" + str(i) + "_2374_2048.png")
+
+
+# Dump_lm_head_weight(model)
+
+
 class ResearchRunner(QwenRunner):
     def attention(self, attention, query, key, value, causal_mask):
         query = query.permute(0, 2, 1, 3)
@@ -50,8 +84,8 @@ class ResearchRunner(QwenRunner):
         attn_mask.masked_fill_(causal_mask.logical_not(), float(0))
         qk = attn_weight * attn_mask
         qk = qk[0]
-        prePath = "./temp/"
-        show.DumpTensorToImage(qk, prePath + "q@k_seq_" + str(size) + "_layer_" + str(attention.index) + ".png")
+        prePath = "./temp/" + "q@k_seq_" + str(size) + "_layer_" + str(attention.index) + ".png"
+        show.DumpTensorToImage(qk, prePath, GridValue=255)
 
         attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=causal_mask).transpose(1, 2)
         context_layer = attention._merge_heads(attn_output, attention.num_heads, attention.head_dim)
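
Context for the q@k dump above (a sketch, not part of the commit): qk is the scaled attention score matrix with its non-causal entries zeroed for visualization, roughly the quantity F.scaled_dot_product_attention masks internally. Toy shapes below; the variable names mirror the hunk:

    import math
    import torch

    query = torch.randn(1, 2, 6, 8)  # [batch, heads, seq, head_dim]
    key = torch.randn(1, 2, 6, 8)
    causal_mask = torch.ones(6, 6, dtype=torch.bool).tril()

    attn_weight = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))
    attn_mask = torch.ones(6, 6)
    attn_mask.masked_fill_(causal_mask.logical_not(), float(0))
    qk = attn_weight * attn_mask  # qk[0] is what gets written to the per-layer image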
@@ -71,6 +105,6 @@ print(decode_tokens)
 # <|im_start|>assistant
 # 日本的首都东京。<|im_end|>
 # <|endoftext|>
-# 日本的首都是东京。<|im_end|>
+
 if decode_tokens.split("\n")[-2] != """日本的首都东京。<|im_end|>""":
     raise ()

@@ -1,8 +1,16 @@
 import torch
 import torch.nn.functional as F
 
-x = torch.tensor([[1, 2], [3, 4]])
-
+x1 = torch.tensor([[1, 2]], dtype=float)
+x2 = torch.tensor([[5, 6], [7, 8]], dtype=float)
+
+y = x1 @ x2  # torch.matmul(x1 , x2)
+x_inverse = torch.inverse(x2, out=None)
+y_inverse = y @ x_inverse
+y_inverse = y_inverse.permute(1, 0)
+
+
+x = torch.tensor([[1, 2], [3, 4]], dtype=float)
 print(x)
 print("x.tile((2)) -> ", x.tile((2)).shape)
 print(x.tile((2)))
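
A quick check (not in the commit) of what the added matmul/inverse lines compute: since y = x1 @ x2, right-multiplying y by inverse(x2) recovers x1.

    import torch

    x1 = torch.tensor([[1, 2]], dtype=torch.float64)
    x2 = torch.tensor([[5, 6], [7, 8]], dtype=torch.float64)

    y = x1 @ x2                        # [[19., 22.]]
    y_inverse = y @ torch.inverse(x2)  # [[1., 2.]] up to rounding
    assert torch.allclose(y_inverse, x1)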
@@ -55,4 +63,3 @@ z = torch.ones([1, 32, 6, 128])
 att = torch.matmul(x, y)
 mm = torch.matmul(att, z)
 print(mm.shape)
-

tools/show.py
@@ -8,23 +8,24 @@ import numpy as np
 import os
 
 
-def DumpTensorToImage(tensor, name, forceSquare=True, scale=1.0, AutoContrast=True):
+def DumpTensorToImage(tensor, name, forceSquare=True, scale=1.0, AutoContrast=True, GridValue=0):
     if len(tensor.shape) != 2 and len(tensor.shape) != 1 and len(tensor.shape) != 3:
         raise ("Error input dims")
 
     if len(tensor.shape) == 3:
         channel = tensor.shape[0]
         x = math.ceil((channel) ** 0.5)
-        tensor = F.pad(tensor, (0, 1, 0, 1, 0, x * x - channel), mode="constant", value=0)
+        tensor = F.pad(tensor, (0, 0, 0, 0, 0, x * x - channel), mode="constant", value=0)
-        if AutoContrast:
-            calc = tensor.reshape((x * x, tensor.shape[1] * tensor.shape[2]))
+        calc = tensor.reshape((x * x, tensor.shape[1] * tensor.shape[2]))
+        if AutoContrast:
             tensormax = calc.max(1)[0]
             tensormin = calc.min(1)[0]
             calc = calc.transpose(1, 0)
             calc = ((calc - tensormin) / (tensormax - tensormin)) * 255
             calc = calc.transpose(1, 0)
-        tensor = calc.reshape((x, x, tensor.shape[1], tensor.shape[2]))
-        tensor = tensor.permute((0, 2, 1, 3))
+        calc = calc.reshape((x, x, tensor.shape[1], tensor.shape[2]))
+        calc = F.pad(calc, (0, 1, 0, 1, 0, 0), mode="constant", value=GridValue)
+        tensor = calc.permute((0, 2, 1, 3))
         tensor = tensor.reshape((x * tensor.shape[1], x * tensor.shape[3]))
         DumpTensorToImage(tensor, name, forceSquare=False, scale=scale, AutoContrast=False)
         return
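
What the new GridValue parameter does, in brief: the per-tile F.pad that used to run before auto-contrast now runs after it, adding one extra column and row to each tile filled with GridValue, so the stitched output image separates channels with a 1-pixel grid line (black at 0, white at 255). A minimal sketch of that padding step, with placeholder tile values:

    import torch
    import torch.nn.functional as F

    x, H, W = 2, 3, 3
    calc = torch.full((x, x, H, W), 128.0)  # [grid_y, grid_x, tile_h, tile_w]
    GridValue = 255
    # Pad one column (right) and one row (bottom) of every HxW tile.
    calc = F.pad(calc, (0, 1, 0, 1, 0, 0), mode="constant", value=GridValue)
    print(calc.shape)  # torch.Size([2, 2, 4, 4]) -- each tile gains a grid edge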

@@ -12,7 +12,9 @@ show.DumpTensorToImage(radata, "test.png")
 
 
 radata = torch.randn(3, 127, 127)
-show.DumpTensorToImage(radata, "test.png")
+show.DumpTensorToImage(radata, "test1.png", AutoContrast=True, GridValue=0)
+show.DumpTensorToImage(radata, "test2.png", AutoContrast=True, GridValue=255)
+show.DumpTensorToImage(radata, "test3.png", AutoContrast=False, GridValue=0)
 
 
 radata = torch.randn(127, 127)