From 59449df047882be52f5f6fcbe319cb2a76226ddd Mon Sep 17 00:00:00 2001 From: Colin Date: Mon, 28 Oct 2024 18:52:28 +0800 Subject: [PATCH] Refine minist unsuper. --- unsuper/minist.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/unsuper/minist.py b/unsuper/minist.py index 024fa69..33e9158 100644 --- a/unsuper/minist.py +++ b/unsuper/minist.py @@ -20,8 +20,8 @@ np.random.seed(seed) random.seed(seed) -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") -# device = torch.device("cpu") +# device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +device = torch.device("cpu") # device = torch.device("mps") num_epochs = 1 @@ -128,8 +128,8 @@ for epoch in range(epochs): all = range(0, sample.shape[0]) # ratio_max = abs / mean # ratio_nor = (max - abs) / max - ratio_nor = torch.pow(abs / mean, 4) # ratio_nor[all, max_index] = ratio_max[all, max_index].clone() + ratio_nor = torch.pow(abs / mean, 4) ratio_nor = torch.where(torch.isnan(ratio_nor), 1.0, ratio_nor) label = sample * ratio_nor @@ -197,7 +197,6 @@ with torch.no_grad(): labels = labels.to(device) outputs = model(images) - # max returns (value ,index) _, predicted = torch.max(outputs.data, 1) n_samples += labels.size(0) n_correct += (predicted == labels).sum().item()