Compare commits
No commits in common. "2ad977a072b90fe3e9378717a9f4fd7af324eb00" and "5b2cd4da61f0d2ab19f8a18393edabe1eb19e593" have entirely different histories.
2ad977a072
...
5b2cd4da61
|
@ -78,12 +78,11 @@ def DumpTensorToImage(tensor, name, forceSquare=False, scale=1.0, Contrast=None,
|
||||||
|
|
||||||
|
|
||||||
def DumpTensorToLog(tensor, name="log"):
|
def DumpTensorToLog(tensor, name="log"):
|
||||||
|
shape = tensor.shape
|
||||||
tensor_mean = torch.mean(tensor).cpu().detach().numpy()
|
tensor_mean = torch.mean(tensor).cpu().detach().numpy()
|
||||||
tensor_abs_mean = torch.mean(torch.abs(tensor)).cpu().detach().numpy()
|
|
||||||
tensor_range = (torch.max(tensor) - torch.min(tensor)).cpu().detach().numpy()
|
tensor_range = (torch.max(tensor) - torch.min(tensor)).cpu().detach().numpy()
|
||||||
f = open(name, "w")
|
f = open(name, "w")
|
||||||
f.writelines("tensor mean: %s" % tensor_mean + os.linesep)
|
f.writelines("tensor mean: %s" % tensor_mean + os.linesep)
|
||||||
f.writelines("tensor abs mean: %s" % tensor_abs_mean + os.linesep)
|
|
||||||
f.writelines("tensor range: %s" % tensor_range + os.linesep)
|
f.writelines("tensor range: %s" % tensor_range + os.linesep)
|
||||||
data = tensor.reshape([-1]).float().cpu().detach().numpy().tolist()
|
data = tensor.reshape([-1]).float().cpu().detach().numpy().tolist()
|
||||||
for d in data:
|
for d in data:
|
||||||
|
|
|
@ -20,8 +20,8 @@ np.random.seed(seed)
|
||||||
random.seed(seed)
|
random.seed(seed)
|
||||||
|
|
||||||
|
|
||||||
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
device = torch.device("cpu")
|
# device = torch.device("cpu")
|
||||||
# device = torch.device("mps")
|
# device = torch.device("mps")
|
||||||
|
|
||||||
num_epochs = 1
|
num_epochs = 1
|
||||||
|
@ -128,8 +128,8 @@ for epoch in range(epochs):
|
||||||
all = range(0, sample.shape[0])
|
all = range(0, sample.shape[0])
|
||||||
# ratio_max = abs / mean
|
# ratio_max = abs / mean
|
||||||
# ratio_nor = (max - abs) / max
|
# ratio_nor = (max - abs) / max
|
||||||
# ratio_nor[all, max_index] = ratio_max[all, max_index].clone()
|
|
||||||
ratio_nor = torch.pow(abs / mean, 4)
|
ratio_nor = torch.pow(abs / mean, 4)
|
||||||
|
# ratio_nor[all, max_index] = ratio_max[all, max_index].clone()
|
||||||
ratio_nor = torch.where(torch.isnan(ratio_nor), 1.0, ratio_nor)
|
ratio_nor = torch.where(torch.isnan(ratio_nor), 1.0, ratio_nor)
|
||||||
label = sample * ratio_nor
|
label = sample * ratio_nor
|
||||||
|
|
||||||
|
@ -197,6 +197,7 @@ with torch.no_grad():
|
||||||
labels = labels.to(device)
|
labels = labels.to(device)
|
||||||
outputs = model(images)
|
outputs = model(images)
|
||||||
|
|
||||||
|
# max returns (value ,index)
|
||||||
_, predicted = torch.max(outputs.data, 1)
|
_, predicted = torch.max(outputs.data, 1)
|
||||||
n_samples += labels.size(0)
|
n_samples += labels.size(0)
|
||||||
n_correct += (predicted == labels).sum().item()
|
n_correct += (predicted == labels).sum().item()
|
||||||
|
|
Loading…
Reference in New Issue