Refine minist unsuper.
This commit is contained in:
		
							parent
							
								
									5b2cd4da61
								
							
						
					
					
						commit
						59449df047
					
				|  | @ -20,8 +20,8 @@ np.random.seed(seed) | |||
| random.seed(seed) | ||||
| 
 | ||||
| 
 | ||||
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | ||||
| # device = torch.device("cpu") | ||||
| # device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | ||||
| device = torch.device("cpu") | ||||
| # device = torch.device("mps") | ||||
| 
 | ||||
| num_epochs = 1 | ||||
|  | @ -128,8 +128,8 @@ for epoch in range(epochs): | |||
|         all = range(0, sample.shape[0]) | ||||
|         # ratio_max = abs / mean | ||||
|         # ratio_nor = (max - abs) / max | ||||
|         ratio_nor = torch.pow(abs / mean, 4) | ||||
|         # ratio_nor[all, max_index] = ratio_max[all, max_index].clone() | ||||
|         ratio_nor = torch.pow(abs / mean, 4) | ||||
|         ratio_nor = torch.where(torch.isnan(ratio_nor), 1.0, ratio_nor) | ||||
|         label = sample * ratio_nor | ||||
| 
 | ||||
|  | @ -197,7 +197,6 @@ with torch.no_grad(): | |||
|         labels = labels.to(device) | ||||
|         outputs = model(images) | ||||
| 
 | ||||
|         # max returns (value ,index) | ||||
|         _, predicted = torch.max(outputs.data, 1) | ||||
|         n_samples += labels.size(0) | ||||
|         n_correct += (predicted == labels).sum().item() | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue