diff --git a/unsuper/minist.py b/unsuper/minist.py index 024fa69..33e9158 100644 --- a/unsuper/minist.py +++ b/unsuper/minist.py @@ -20,8 +20,8 @@ np.random.seed(seed) random.seed(seed) -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") -# device = torch.device("cpu") +# device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +device = torch.device("cpu") # device = torch.device("mps") num_epochs = 1 @@ -128,8 +128,8 @@ for epoch in range(epochs): all = range(0, sample.shape[0]) # ratio_max = abs / mean # ratio_nor = (max - abs) / max - ratio_nor = torch.pow(abs / mean, 4) # ratio_nor[all, max_index] = ratio_max[all, max_index].clone() + ratio_nor = torch.pow(abs / mean, 4) ratio_nor = torch.where(torch.isnan(ratio_nor), 1.0, ratio_nor) label = sample * ratio_nor @@ -197,7 +197,6 @@ with torch.no_grad(): labels = labels.to(device) outputs = model(images) - # max returns (value ,index) _, predicted = torch.max(outputs.data, 1) n_samples += labels.size(0) n_correct += (predicted == labels).sum().item()