Add dump tool.
This commit is contained in:
parent
a451def299
commit
68417fdc12
|
@ -502,13 +502,9 @@ class ChatGLMModel(nn.Module):
|
||||||
# Rotary positional embeddings
|
# Rotary positional embeddings
|
||||||
rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
|
rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
|
||||||
|
|
||||||
|
from tools import show
|
||||||
|
|
||||||
import plotly_express as px
|
show.DumpTensorToImage(rotary_pos_emb[:, :, 0], "plot.png", scale=0.1)
|
||||||
img = px.imshow((rotary_pos_emb[:,:,0]*256).byte().cpu())
|
|
||||||
img.write_image("plot.png")
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if position_ids is not None:
|
if position_ids is not None:
|
||||||
rotary_pos_emb = rotary_pos_emb[position_ids]
|
rotary_pos_emb = rotary_pos_emb[position_ids]
|
||||||
|
@ -709,10 +705,10 @@ class ChatGLMForConditionalGeneration(nn.Module):
|
||||||
input_ids,
|
input_ids,
|
||||||
pad_token_id=generation_config.pad_token_id,
|
pad_token_id=generation_config.pad_token_id,
|
||||||
eos_token_id=generation_config.eos_token_id,
|
eos_token_id=generation_config.eos_token_id,
|
||||||
output_hidden_states = generation_config.output_hidden_states,
|
output_hidden_states=generation_config.output_hidden_states,
|
||||||
use_cache = generation_config.use_cache
|
use_cache=generation_config.use_cache,
|
||||||
)
|
)
|
||||||
|
|
||||||
outputs = outputs.tolist()[0][len(inputs["input_ids"][0]) : -1]
|
outputs = outputs.tolist()[0][len(inputs["input_ids"][0]) : -1]
|
||||||
response = tokenizer.decode(outputs)
|
response = tokenizer.decode(outputs)
|
||||||
history.append({"role": role, "content": query})
|
history.append({"role": role, "content": query})
|
||||||
|
@ -724,7 +720,7 @@ class ChatGLMForConditionalGeneration(nn.Module):
|
||||||
pad_token_id: Optional[int] = None,
|
pad_token_id: Optional[int] = None,
|
||||||
eos_token_id: Optional[Union[int, List[int]]] = None,
|
eos_token_id: Optional[Union[int, List[int]]] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
use_cache: Optional[bool] = None
|
use_cache: Optional[bool] = None,
|
||||||
):
|
):
|
||||||
if isinstance(eos_token_id, int):
|
if isinstance(eos_token_id, int):
|
||||||
eos_token_id = [eos_token_id]
|
eos_token_id = [eos_token_id]
|
||||||
|
|
BIN
plot.png
BIN
plot.png
Binary file not shown.
Before Width: | Height: | Size: 26 KiB After Width: | Height: | Size: 262 KiB |
|
@ -0,0 +1,52 @@
|
||||||
|
import plotly_express as px
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import torchvision.transforms.functional as Vision
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
|
||||||
|
def DumpTensorToImage(tensor, name, autoPad=True, scale=1.0):
|
||||||
|
if len(tensor.shape) != 2:
|
||||||
|
raise ("Error input dims")
|
||||||
|
tensor = tensor.float()
|
||||||
|
maxv = torch.max(tensor)
|
||||||
|
minv = torch.min(tensor)
|
||||||
|
tensor = (((tensor - minv) / (maxv - minv)) * 256).byte().cpu()
|
||||||
|
img = tensor.numpy()
|
||||||
|
srp = img.shape
|
||||||
|
if autoPad and (max(srp) / min(srp) > 16):
|
||||||
|
img = cv2.resize(img,[max(srp),max(srp)])
|
||||||
|
srp = img.shape
|
||||||
|
if scale != 1.0:
|
||||||
|
img = cv2.resize(img, [int(srp[0] * scale), int(srp[1] * scale)])
|
||||||
|
srp = img.shape
|
||||||
|
cv2.imwrite(name, img)
|
||||||
|
|
||||||
|
|
||||||
|
# def DumpTensorToImage(tensor, name, autoPad=True, scale=1.0):
|
||||||
|
# if len(tensor.shape) != 2:
|
||||||
|
# raise ("Error input dims")
|
||||||
|
# tensor = tensor.float()
|
||||||
|
# maxv = torch.max(tensor)
|
||||||
|
# minv = torch.min(tensor)
|
||||||
|
# tensor = (((tensor - minv) / (maxv - minv)) * 256).byte().cpu()
|
||||||
|
# srp = tensor.shape
|
||||||
|
# if autoPad and (max(srp) / min(srp) > 16):
|
||||||
|
# if srp[0] == min(srp):
|
||||||
|
# tensor = F.pad(tensor, [max(srp) - min(srp), 0], "replicate")
|
||||||
|
# else:
|
||||||
|
# tensor = F.pad(tensor, [0, max(srp) - min(srp)], "replicate")
|
||||||
|
# srp = tensor.shape
|
||||||
|
|
||||||
|
# tensor = tensor.unsqueeze(0)
|
||||||
|
# if scale != 1.0:
|
||||||
|
# tensor = Vision.resize(tensor, [int(srp[0] * scale), int(srp[1] * scale)])
|
||||||
|
# tensor = tensor.view([int(srp[0] * scale), int(srp[1] * scale)])
|
||||||
|
# srp = tensor.shape
|
||||||
|
|
||||||
|
# w = 1024 if max(srp) > 1024 else max(srp)
|
||||||
|
# scale = max(srp) / w
|
||||||
|
# # img = px.imshow(tensor)
|
||||||
|
# # img.write_image(name)
|
||||||
|
# cv2.imwrite(name, tensor.numpy())
|
||||||
|
# cv2.CreateMat(name, tensor.numpy())
|
|
@ -0,0 +1,6 @@
|
||||||
|
import show
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
radata = torch.randn(8192, 128)
|
||||||
|
show.DumpTensorToImage(radata, "test.png", autoPad=True,scale=0.2)
|
Loading…
Reference in New Issue