Refine research_attention.
This commit is contained in:
		
							parent
							
								
									5dbac40925
								
							
						
					
					
						commit
						1811b9611a
					
				| 
						 | 
					@ -47,9 +47,9 @@ def get_ltor_masks_and_position_ids(
 | 
				
			||||||
        att_mask_batch = micro_batch_size
 | 
					        att_mask_batch = micro_batch_size
 | 
				
			||||||
    else:
 | 
					    else:
 | 
				
			||||||
        att_mask_batch = 1
 | 
					        att_mask_batch = 1
 | 
				
			||||||
    attention_mask = torch.tril(
 | 
					    attention_mask = torch.tril(torch.ones((att_mask_batch, seq_length, seq_length), device=data.device)).view(
 | 
				
			||||||
        torch.ones((att_mask_batch, seq_length, seq_length), device=data.device)
 | 
					        att_mask_batch, 1, seq_length, seq_length
 | 
				
			||||||
    ).view(att_mask_batch, 1, seq_length, seq_length)
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Loss mask.
 | 
					    # Loss mask.
 | 
				
			||||||
    loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)
 | 
					    loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)
 | 
				
			||||||
| 
						 | 
					@ -66,7 +66,6 @@ def get_ltor_masks_and_position_ids(
 | 
				
			||||||
    if reset_position_ids or reset_attention_mask:
 | 
					    if reset_position_ids or reset_attention_mask:
 | 
				
			||||||
        # Loop through the batches:
 | 
					        # Loop through the batches:
 | 
				
			||||||
        for b in range(micro_batch_size):
 | 
					        for b in range(micro_batch_size):
 | 
				
			||||||
 | 
					 | 
				
			||||||
            # Find indecies where EOD token is.
 | 
					            # Find indecies where EOD token is.
 | 
				
			||||||
            eod_index = position_ids[b, data[b] == eod_token]
 | 
					            eod_index = position_ids[b, data[b] == eod_token]
 | 
				
			||||||
            # Detach indecies from positions if going to modify positions.
 | 
					            # Detach indecies from positions if going to modify positions.
 | 
				
			||||||
| 
						 | 
					@ -105,13 +104,14 @@ def get_batch(context_tokens: torch.LongTensor, eod_id: int):
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    return tokens, attention_mask, position_ids
 | 
					    return tokens, attention_mask, position_ids
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def make_context(
 | 
					def make_context(
 | 
				
			||||||
    tokenizer: PreTrainedTokenizer,
 | 
					    tokenizer: PreTrainedTokenizer,
 | 
				
			||||||
    query: str,
 | 
					    query: str,
 | 
				
			||||||
    query_assistant: str = "",
 | 
					    query_assistant: str = "",
 | 
				
			||||||
    history: List[Tuple[str, str]] = None,
 | 
					    history: List[Tuple[str, str]] = None,
 | 
				
			||||||
    system: str = "",
 | 
					    system: str = "",
 | 
				
			||||||
    max_window_size: int = 6144
 | 
					    max_window_size: int = 6144,
 | 
				
			||||||
):
 | 
					):
 | 
				
			||||||
    if history is None:
 | 
					    if history is None:
 | 
				
			||||||
        history = []
 | 
					        history = []
 | 
				
			||||||
| 
						 | 
					@ -122,9 +122,9 @@ def make_context(
 | 
				
			||||||
    nl_tokens = tokenizer.encode("\n")
 | 
					    nl_tokens = tokenizer.encode("\n")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _tokenize_str(role, content):
 | 
					    def _tokenize_str(role, content):
 | 
				
			||||||
        return f"{role}\n{content}", tokenizer.encode(
 | 
					        return f"{role}\n{content}", tokenizer.encode(role, allowed_special=set()) + nl_tokens + tokenizer.encode(
 | 
				
			||||||
            role, allowed_special=set()
 | 
					            content, allowed_special=set()
 | 
				
			||||||
        ) + nl_tokens + tokenizer.encode(content, allowed_special=set())
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    system_text, system_tokens_part = _tokenize_str("system", system)
 | 
					    system_text, system_tokens_part = _tokenize_str("system", system)
 | 
				
			||||||
    system_tokens = im_start_tokens + system_tokens_part + im_end_tokens
 | 
					    system_tokens = im_start_tokens + system_tokens_part + im_end_tokens
 | 
				
			||||||
| 
						 | 
					@ -135,19 +135,13 @@ def make_context(
 | 
				
			||||||
    for turn_query, turn_response in reversed(history):
 | 
					    for turn_query, turn_response in reversed(history):
 | 
				
			||||||
        query_text, query_tokens_part = _tokenize_str("user", turn_query)
 | 
					        query_text, query_tokens_part = _tokenize_str("user", turn_query)
 | 
				
			||||||
        query_tokens = im_start_tokens + query_tokens_part + im_end_tokens
 | 
					        query_tokens = im_start_tokens + query_tokens_part + im_end_tokens
 | 
				
			||||||
        response_text, response_tokens_part = _tokenize_str(
 | 
					        response_text, response_tokens_part = _tokenize_str("assistant", turn_response)
 | 
				
			||||||
            "assistant", turn_response
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        response_tokens = im_start_tokens + response_tokens_part + im_end_tokens
 | 
					        response_tokens = im_start_tokens + response_tokens_part + im_end_tokens
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        next_context_tokens = nl_tokens + query_tokens + nl_tokens + response_tokens
 | 
					        next_context_tokens = nl_tokens + query_tokens + nl_tokens + response_tokens
 | 
				
			||||||
        prev_chat = (
 | 
					        prev_chat = f"\n{im_start}{query_text}{im_end}\n{im_start}{response_text}{im_end}"
 | 
				
			||||||
            f"\n{im_start}{query_text}{im_end}\n{im_start}{response_text}{im_end}"
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        current_context_size = (
 | 
					        current_context_size = len(system_tokens) + len(next_context_tokens) + len(context_tokens)
 | 
				
			||||||
            len(system_tokens) + len(next_context_tokens) + len(context_tokens)
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        if current_context_size < max_window_size:
 | 
					        if current_context_size < max_window_size:
 | 
				
			||||||
            context_tokens = next_context_tokens + context_tokens
 | 
					            context_tokens = next_context_tokens + context_tokens
 | 
				
			||||||
            raw_text = prev_chat + raw_text
 | 
					            raw_text = prev_chat + raw_text
 | 
				
			||||||
| 
						 | 
					@ -171,12 +165,13 @@ def make_context(
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    return raw_text, context_tokens
 | 
					    return raw_text, context_tokens
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def decode_tokens(
 | 
					def decode_tokens(
 | 
				
			||||||
    tokens: Union[torch.LongTensor, TokensType],
 | 
					    tokens: Union[torch.LongTensor, TokensType],
 | 
				
			||||||
    tokenizer: PreTrainedTokenizer,
 | 
					    tokenizer: PreTrainedTokenizer,
 | 
				
			||||||
    raw_text_len: int,
 | 
					    raw_text_len: int = 0,
 | 
				
			||||||
    context_length: int,
 | 
					    context_length: int = 0,
 | 
				
			||||||
    errors: str="replace",
 | 
					    errors: str = "replace",
 | 
				
			||||||
) -> str:
 | 
					) -> str:
 | 
				
			||||||
    if torch.is_tensor(tokens):
 | 
					    if torch.is_tensor(tokens):
 | 
				
			||||||
        tokens = tokens.cpu().numpy().tolist()
 | 
					        tokens = tokens.cpu().numpy().tolist()
 | 
				
			||||||
| 
						 | 
					@ -211,42 +206,26 @@ class StopWordsLogitsProcessor(LogitsProcessor):
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __init__(self, stop_words_ids: Iterable[Iterable[int]], eos_token_id: int):
 | 
					    def __init__(self, stop_words_ids: Iterable[Iterable[int]], eos_token_id: int):
 | 
				
			||||||
 | 
					 | 
				
			||||||
        if not isinstance(stop_words_ids, List) or len(stop_words_ids) == 0:
 | 
					        if not isinstance(stop_words_ids, List) or len(stop_words_ids) == 0:
 | 
				
			||||||
            raise ValueError(
 | 
					            raise ValueError(f"`stop_words_ids` has to be a non-emtpy list, but is {stop_words_ids}.")
 | 
				
			||||||
                f"`stop_words_ids` has to be a non-emtpy list, but is {stop_words_ids}."
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
        if any(not isinstance(bad_word_ids, list) for bad_word_ids in stop_words_ids):
 | 
					        if any(not isinstance(bad_word_ids, list) for bad_word_ids in stop_words_ids):
 | 
				
			||||||
            raise ValueError(
 | 
					            raise ValueError(f"`stop_words_ids` has to be a list of lists, but is {stop_words_ids}.")
 | 
				
			||||||
                f"`stop_words_ids` has to be a list of lists, but is {stop_words_ids}."
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
        if any(
 | 
					        if any(
 | 
				
			||||||
            any(
 | 
					            any((not isinstance(token_id, (int, np.integer)) or token_id < 0) for token_id in stop_word_ids)
 | 
				
			||||||
                (not isinstance(token_id, (int, np.integer)) or token_id < 0)
 | 
					 | 
				
			||||||
                for token_id in stop_word_ids
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
            for stop_word_ids in stop_words_ids
 | 
					            for stop_word_ids in stop_words_ids
 | 
				
			||||||
        ):
 | 
					        ):
 | 
				
			||||||
            raise ValueError(
 | 
					            raise ValueError(
 | 
				
			||||||
                f"Each list in `stop_words_ids` has to be a list of positive integers, but is {stop_words_ids}."
 | 
					                f"Each list in `stop_words_ids` has to be a list of positive integers, but is {stop_words_ids}."
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.stop_words_ids = list(
 | 
					        self.stop_words_ids = list(filter(lambda bad_token_seq: bad_token_seq != [eos_token_id], stop_words_ids))
 | 
				
			||||||
            filter(
 | 
					 | 
				
			||||||
                lambda bad_token_seq: bad_token_seq != [eos_token_id], stop_words_ids
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        self.eos_token_id = eos_token_id
 | 
					        self.eos_token_id = eos_token_id
 | 
				
			||||||
        for stop_token_seq in self.stop_words_ids:
 | 
					        for stop_token_seq in self.stop_words_ids:
 | 
				
			||||||
            assert (
 | 
					            assert len(stop_token_seq) > 0, "Stop words token sequences {} cannot have an empty list".format(
 | 
				
			||||||
                len(stop_token_seq) > 0
 | 
					 | 
				
			||||||
            ), "Stop words token sequences {} cannot have an empty list".format(
 | 
					 | 
				
			||||||
                stop_words_ids
 | 
					                stop_words_ids
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __call__(
 | 
					    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
 | 
				
			||||||
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
 | 
					 | 
				
			||||||
    ) -> torch.FloatTensor:
 | 
					 | 
				
			||||||
        stopped_samples = self._calc_stopped_samples(input_ids)
 | 
					        stopped_samples = self._calc_stopped_samples(input_ids)
 | 
				
			||||||
        for i, should_stop in enumerate(stopped_samples):
 | 
					        for i, should_stop in enumerate(stopped_samples):
 | 
				
			||||||
            if should_stop:
 | 
					            if should_stop:
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -10,6 +10,11 @@ from modeling_qwen import QwenRunner
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import torch.nn.functional as F
 | 
					import torch.nn.functional as F
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from qwen_generation_utils import (
 | 
				
			||||||
 | 
					    make_context,
 | 
				
			||||||
 | 
					    decode_tokens,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
sys.path.append("..")
 | 
					sys.path.append("..")
 | 
				
			||||||
from tools import show
 | 
					from tools import show
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -32,10 +37,39 @@ model = QWenLMHeadModel(config)
 | 
				
			||||||
print(model)
 | 
					print(model)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
 | 
					tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
 | 
				
			||||||
model = model.from_pretrained(model_dir).cuda()
 | 
					
 | 
				
			||||||
 | 
					model = model.from_pretrained(model_dir)
 | 
				
			||||||
model = model.eval()
 | 
					model = model.eval()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def Dump_tokens_list(model):
 | 
				
			||||||
 | 
					    tokens = []
 | 
				
			||||||
 | 
					    for token in range(config.eos_token_id):
 | 
				
			||||||
 | 
					        decoded, response, end_reason = decode_tokens(
 | 
				
			||||||
 | 
					            [token],
 | 
				
			||||||
 | 
					            tokenizer,
 | 
				
			||||||
 | 
					            raw_text_len=0,
 | 
				
			||||||
 | 
					            context_length=0,
 | 
				
			||||||
 | 
					            errors="replace",
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        tokens.append(str(token) + ": " + decoded)
 | 
				
			||||||
 | 
					    show.DumpListToFile(tokens, "./temp/qwen_token_list.txt")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Dump_tokens_list(model)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def Dump_lm_head_weight(model):
 | 
				
			||||||
 | 
					    weight = model.lm_head.weight.cpu()  # [151936,2048,]
 | 
				
			||||||
 | 
					    weight = weight.reshape(64, -1, 2048)
 | 
				
			||||||
 | 
					    for i in range(64):
 | 
				
			||||||
 | 
					        sub = weight[i].reshape(-1, 64, 32)
 | 
				
			||||||
 | 
					        show.DumpTensorToImage(sub, "./temp/lm_head_" + str(i) + "_2374_2048.png")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Dump_lm_head_weight(model)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class ResearchRunner(QwenRunner):
 | 
					class ResearchRunner(QwenRunner):
 | 
				
			||||||
    def attention(self, attention, query, key, value, causal_mask):
 | 
					    def attention(self, attention, query, key, value, causal_mask):
 | 
				
			||||||
        query = query.permute(0, 2, 1, 3)
 | 
					        query = query.permute(0, 2, 1, 3)
 | 
				
			||||||
| 
						 | 
					@ -50,8 +84,8 @@ class ResearchRunner(QwenRunner):
 | 
				
			||||||
        attn_mask.masked_fill_(causal_mask.logical_not(), float(0))
 | 
					        attn_mask.masked_fill_(causal_mask.logical_not(), float(0))
 | 
				
			||||||
        qk = attn_weight * attn_mask
 | 
					        qk = attn_weight * attn_mask
 | 
				
			||||||
        qk = qk[0]
 | 
					        qk = qk[0]
 | 
				
			||||||
        prePath = "./temp/"
 | 
					        prePath = "./temp/" + "q@k_seq_" + str(size) + "_layer_" + str(attention.index) + ".png"
 | 
				
			||||||
        show.DumpTensorToImage(qk, prePath + "q@k_seq_" + str(size) + "_layer_" + str(attention.index) + ".png")
 | 
					        show.DumpTensorToImage(qk, prePath, GridValue=255)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=causal_mask).transpose(1, 2)
 | 
					        attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=causal_mask).transpose(1, 2)
 | 
				
			||||||
        context_layer = attention._merge_heads(attn_output, attention.num_heads, attention.head_dim)
 | 
					        context_layer = attention._merge_heads(attn_output, attention.num_heads, attention.head_dim)
 | 
				
			||||||
| 
						 | 
					@ -71,6 +105,6 @@ print(decode_tokens)
 | 
				
			||||||
# <|im_start|>assistant
 | 
					# <|im_start|>assistant
 | 
				
			||||||
# 日本的首都东京。<|im_end|>
 | 
					# 日本的首都东京。<|im_end|>
 | 
				
			||||||
# <|endoftext|>
 | 
					# <|endoftext|>
 | 
				
			||||||
                                    #    日本的首都是东京。<|im_end|>
 | 
					
 | 
				
			||||||
if decode_tokens.split("\n")[-2] != """日本的首都东京。<|im_end|>""":
 | 
					if decode_tokens.split("\n")[-2] != """日本的首都东京。<|im_end|>""":
 | 
				
			||||||
    raise ()
 | 
					    raise ()
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,8 +1,16 @@
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
import torch.nn.functional as F
 | 
					import torch.nn.functional as F
 | 
				
			||||||
 | 
					
 | 
				
			||||||
x = torch.tensor([[1, 2], [3, 4]])
 | 
					x1 = torch.tensor([[1, 2]], dtype=float)
 | 
				
			||||||
 | 
					x2 = torch.tensor([[5, 6], [7, 8]], dtype=float)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					y = x1 @ x2  # torch.matmul(x1 , x2)
 | 
				
			||||||
 | 
					x_inverse = torch.inverse(x2, out=None)
 | 
				
			||||||
 | 
					y_inverse = y @ x_inverse
 | 
				
			||||||
 | 
					y_inverse = y_inverse.permute(1, 0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					x = torch.tensor([[1, 2], [3, 4]], dtype=float)
 | 
				
			||||||
print(x)
 | 
					print(x)
 | 
				
			||||||
print("x.tile((2)) -> ", x.tile((2)).shape)
 | 
					print("x.tile((2)) -> ", x.tile((2)).shape)
 | 
				
			||||||
print(x.tile((2)))
 | 
					print(x.tile((2)))
 | 
				
			||||||
| 
						 | 
					@ -55,4 +63,3 @@ z = torch.ones([1, 32, 6, 128])
 | 
				
			||||||
att = torch.matmul(x, y)
 | 
					att = torch.matmul(x, y)
 | 
				
			||||||
mm = torch.matmul(att, z)
 | 
					mm = torch.matmul(att, z)
 | 
				
			||||||
print(mm.shape)
 | 
					print(mm.shape)
 | 
				
			||||||
 | 
					 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -8,23 +8,24 @@ import numpy as np
 | 
				
			||||||
import os
 | 
					import os
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def DumpTensorToImage(tensor, name, forceSquare=True, scale=1.0, AutoContrast=True):
 | 
					def DumpTensorToImage(tensor, name, forceSquare=True, scale=1.0, AutoContrast=True, GridValue=0):
 | 
				
			||||||
    if len(tensor.shape) != 2 and len(tensor.shape) != 1 and len(tensor.shape) != 3:
 | 
					    if len(tensor.shape) != 2 and len(tensor.shape) != 1 and len(tensor.shape) != 3:
 | 
				
			||||||
        raise ("Error input dims")
 | 
					        raise ("Error input dims")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if len(tensor.shape) == 3:
 | 
					    if len(tensor.shape) == 3:
 | 
				
			||||||
        channel = tensor.shape[0]
 | 
					        channel = tensor.shape[0]
 | 
				
			||||||
        x = math.ceil((channel) ** 0.5)
 | 
					        x = math.ceil((channel) ** 0.5)
 | 
				
			||||||
        tensor = F.pad(tensor, (0, 1, 0, 1, 0, x * x - channel), mode="constant", value=0)
 | 
					        tensor = F.pad(tensor, (0, 0, 0, 0, 0, x * x - channel), mode="constant", value=0)
 | 
				
			||||||
        if AutoContrast:
 | 
					 | 
				
			||||||
        calc = tensor.reshape((x * x, tensor.shape[1] * tensor.shape[2]))
 | 
					        calc = tensor.reshape((x * x, tensor.shape[1] * tensor.shape[2]))
 | 
				
			||||||
 | 
					        if AutoContrast:
 | 
				
			||||||
            tensormax = calc.max(1)[0]
 | 
					            tensormax = calc.max(1)[0]
 | 
				
			||||||
            tensormin = calc.min(1)[0]
 | 
					            tensormin = calc.min(1)[0]
 | 
				
			||||||
            calc = calc.transpose(1, 0)
 | 
					            calc = calc.transpose(1, 0)
 | 
				
			||||||
            calc = ((calc - tensormin) / (tensormax - tensormin)) * 255
 | 
					            calc = ((calc - tensormin) / (tensormax - tensormin)) * 255
 | 
				
			||||||
            calc = calc.transpose(1, 0)
 | 
					            calc = calc.transpose(1, 0)
 | 
				
			||||||
        tensor = calc.reshape((x, x, tensor.shape[1], tensor.shape[2]))
 | 
					        calc = calc.reshape((x, x, tensor.shape[1], tensor.shape[2]))
 | 
				
			||||||
        tensor = tensor.permute((0, 2, 1, 3))
 | 
					        calc = F.pad(calc, (0, 1, 0, 1, 0, 0), mode="constant", value=GridValue)
 | 
				
			||||||
 | 
					        tensor = calc.permute((0, 2, 1, 3))
 | 
				
			||||||
        tensor = tensor.reshape((x * tensor.shape[1], x * tensor.shape[3]))
 | 
					        tensor = tensor.reshape((x * tensor.shape[1], x * tensor.shape[3]))
 | 
				
			||||||
        DumpTensorToImage(tensor, name, forceSquare=False, scale=scale, AutoContrast=False)
 | 
					        DumpTensorToImage(tensor, name, forceSquare=False, scale=scale, AutoContrast=False)
 | 
				
			||||||
        return
 | 
					        return
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -12,7 +12,9 @@ show.DumpTensorToImage(radata, "test.png")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
radata = torch.randn(3, 127, 127)
 | 
					radata = torch.randn(3, 127, 127)
 | 
				
			||||||
show.DumpTensorToImage(radata, "test.png")
 | 
					show.DumpTensorToImage(radata, "test1.png", AutoContrast=True, GridValue=0)
 | 
				
			||||||
 | 
					show.DumpTensorToImage(radata, "test2.png", AutoContrast=True, GridValue=255)
 | 
				
			||||||
 | 
					show.DumpTensorToImage(radata, "test3.png", AutoContrast=False, GridValue=0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
radata = torch.randn(127, 127)
 | 
					radata = torch.randn(127, 127)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue