Update unsuper.
parent 45d5701835
commit 22464e7724
@@ -115,43 +115,27 @@ for epoch in range(epochs):
         images = images.to(device)
         outputs = model.forward_unsuper(images)
 
-        # outputs = outputs.permute(0, 2, 3, 1)  # 64 8 24 24 -> 64 24 24 8
-        # sample = outputs.reshape(-1, outputs.shape[3])  # -> 36864 8
-        # abs = torch.abs(sample)
-        # max, max_index = torch.max(abs, dim=1)
-        # min, min_index = torch.min(abs, dim=1)
-        # label = sample * 0.9
-        # all = range(0, label.shape[0])
-        # label[all, max_index] = label[all, max_index]*1.1
-        # loss = F.l1_loss(sample, label)
-        # model.conv1.weight.grad = None
-        # loss.backward()
         outputs = outputs.permute(0, 2, 3, 1)  # 64 8 24 24 -> 64 24 24 8
-        sample = outputs.reshape(outputs.shape[0], -1, outputs.shape[3])  # -> 64 24x24 8
+        sample = outputs.reshape(-1, outputs.shape[3])  # -> 36864 8
         abs = torch.abs(sample)
-        sum = torch.sum(abs, dim=1, keepdim=False)
-        max, max_index = torch.max(sum, dim=1)
+        max, max_index = torch.max(abs, dim=1)
         label = sample * 0.9
         all = range(0, label.shape[0])
-        all_wh = range(0, 24 * 24)
-        label[all, :, max_index] = label[all, :, max_index] * 1.1
+        label[all, max_index] = label[all, max_index] * 1.1
         loss = F.l1_loss(sample, label)
         model.conv1.weight.grad = None
         loss.backward()
 
-        # show.DumpTensorToImage(images.view(-1, images.shape[2], images.shape[3]), "input_image.png", Contrast=[0, 1.0])
-        # w = model.conv1.weight.data
-        # show.DumpTensorToImage(w.view(-1, w.shape[2], w.shape[3]), "conv1_weight.png", Contrast=[-1.0, 1.0])
-        # w = model.conv1.weight.grad
-        # show.DumpTensorToImage(w.view(-1, w.shape[2], w.shape[3]).cpu(), "conv1_weight_grad.png")
-        model.conv1.weight.data = model.conv1.weight.data - model.conv1.weight.grad * 1000
-        # w = model.conv1.weight.data
-        # show.DumpTensorToImage(w.view(-1, w.shape[2], w.shape[3]), "conv1_weight_update.png", Contrast=[-1.0, 1.0])
+        model.conv1.weight.data = model.conv1.weight.data - model.conv1.weight.grad * 100
 
         if (i + 1) % 100 == 0:
             print(f"Epoch [{epoch+1}/{epochs}], Step [{i+1}/{n_total_steps}], Loss: {loss.item():.8f}")
 
+w = model.conv1.weight.grad
+show.DumpTensorToImage(w.view(-1, w.shape[2], w.shape[3]).cpu(), "conv1_weight_grad.png")
+w = model.conv1.weight.data
+show.DumpTensorToImage(w.view(-1, w.shape[2], w.shape[3]), "conv1_weight_update.png", Contrast=[-1.0, 1.0])
 
 # Train the model
 model.conv1.weight.requires_grad = False
 model.conv2.weight.requires_grad = True
			||||||
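For reference, the revised loop body reduces to the step sketched below. This is a minimal, self-contained sketch, not the repository's code: a single nn.Conv2d(1, 8, 5) is assumed as a stand-in for model.conv1 / model.forward_unsuper, a random batch stands in for the data loader, and the epoch/step loop is omitted.

import torch
import torch.nn as nn
import torch.nn.functional as F

conv1 = nn.Conv2d(1, 8, kernel_size=5)            # assumed stand-in for model.conv1
images = torch.randn(64, 1, 28, 28)               # assumed stand-in batch: 64 x 1 x 28 x 28

outputs = conv1(images)                           # 64 x 8 x 24 x 24
outputs = outputs.permute(0, 2, 3, 1)             # -> 64 x 24 x 24 x 8
sample = outputs.reshape(-1, outputs.shape[3])    # -> 36864 x 8, one row per spatial position

# Build the target: shrink every channel by 10%, then boost the channel with the
# largest |activation| in each row by 10% (winner ends at 0.99x, the rest at 0.9x).
abs_act = torch.abs(sample)
_, max_index = torch.max(abs_act, dim=1)
label = sample * 0.9
rows = torch.arange(label.shape[0])               # the diff uses range(0, label.shape[0]); equivalent here
label[rows, max_index] = label[rows, max_index] * 1.1

# L1 loss against that target. The target is not detached, so gradients also flow
# through it; the net effect is an L1 pull toward zero that is 10x weaker for the
# winning channel of each position.
loss = F.l1_loss(sample, label)

# Manual update of conv1 only (no optimizer), with the step factor of 100 from this commit.
conv1.weight.grad = None
loss.backward()
conv1.weight.data = conv1.weight.data - conv1.weight.grad * 100

After the training loop, the commit dumps conv1's gradient and updated weights once, to conv1_weight_grad.png and conv1_weight_update.png, via the repo's show.DumpTensorToImage helper rather than doing so inside the loop.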