Refine minist unsuper.
This commit is contained in:
		
							parent
							
								
									5b2cd4da61
								
							
						
					
					
						commit
						59449df047
					
				| 
						 | 
					@ -20,8 +20,8 @@ np.random.seed(seed)
 | 
				
			||||||
random.seed(seed)
 | 
					random.seed(seed)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 | 
					# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 | 
				
			||||||
# device = torch.device("cpu")
 | 
					device = torch.device("cpu")
 | 
				
			||||||
# device = torch.device("mps")
 | 
					# device = torch.device("mps")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
num_epochs = 1
 | 
					num_epochs = 1
 | 
				
			||||||
| 
						 | 
					@ -128,8 +128,8 @@ for epoch in range(epochs):
 | 
				
			||||||
        all = range(0, sample.shape[0])
 | 
					        all = range(0, sample.shape[0])
 | 
				
			||||||
        # ratio_max = abs / mean
 | 
					        # ratio_max = abs / mean
 | 
				
			||||||
        # ratio_nor = (max - abs) / max
 | 
					        # ratio_nor = (max - abs) / max
 | 
				
			||||||
        ratio_nor = torch.pow(abs / mean, 4)
 | 
					 | 
				
			||||||
        # ratio_nor[all, max_index] = ratio_max[all, max_index].clone()
 | 
					        # ratio_nor[all, max_index] = ratio_max[all, max_index].clone()
 | 
				
			||||||
 | 
					        ratio_nor = torch.pow(abs / mean, 4)
 | 
				
			||||||
        ratio_nor = torch.where(torch.isnan(ratio_nor), 1.0, ratio_nor)
 | 
					        ratio_nor = torch.where(torch.isnan(ratio_nor), 1.0, ratio_nor)
 | 
				
			||||||
        label = sample * ratio_nor
 | 
					        label = sample * ratio_nor
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -197,7 +197,6 @@ with torch.no_grad():
 | 
				
			||||||
        labels = labels.to(device)
 | 
					        labels = labels.to(device)
 | 
				
			||||||
        outputs = model(images)
 | 
					        outputs = model(images)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # max returns (value ,index)
 | 
					 | 
				
			||||||
        _, predicted = torch.max(outputs.data, 1)
 | 
					        _, predicted = torch.max(outputs.data, 1)
 | 
				
			||||||
        n_samples += labels.size(0)
 | 
					        n_samples += labels.size(0)
 | 
				
			||||||
        n_correct += (predicted == labels).sum().item()
 | 
					        n_correct += (predicted == labels).sum().item()
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue