Refine minist unsuper.

This commit is contained in:
Colin 2024-10-28 18:52:28 +08:00
parent 5b2cd4da61
commit 59449df047
1 changed files with 3 additions and 4 deletions

View File

@ -20,8 +20,8 @@ np.random.seed(seed)
random.seed(seed) random.seed(seed)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu") device = torch.device("cpu")
# device = torch.device("mps") # device = torch.device("mps")
num_epochs = 1 num_epochs = 1
@ -128,8 +128,8 @@ for epoch in range(epochs):
all = range(0, sample.shape[0]) all = range(0, sample.shape[0])
# ratio_max = abs / mean # ratio_max = abs / mean
# ratio_nor = (max - abs) / max # ratio_nor = (max - abs) / max
ratio_nor = torch.pow(abs / mean, 4)
# ratio_nor[all, max_index] = ratio_max[all, max_index].clone() # ratio_nor[all, max_index] = ratio_max[all, max_index].clone()
ratio_nor = torch.pow(abs / mean, 4)
ratio_nor = torch.where(torch.isnan(ratio_nor), 1.0, ratio_nor) ratio_nor = torch.where(torch.isnan(ratio_nor), 1.0, ratio_nor)
label = sample * ratio_nor label = sample * ratio_nor
@ -197,7 +197,6 @@ with torch.no_grad():
labels = labels.to(device) labels = labels.to(device)
outputs = model(images) outputs = model(images)
# max returns (value ,index)
_, predicted = torch.max(outputs.data, 1) _, predicted = torch.max(outputs.data, 1)
n_samples += labels.size(0) n_samples += labels.size(0)
n_correct += (predicted == labels).sum().item() n_correct += (predicted == labels).sum().item()