"""Two-phase MNIST experiment.

Phase 1 trains conv1 without labels, driving each pixel's filter responses
towards a hand-built target. Phase 2 freezes conv1's weight and trains
conv2 + fc1 with cross-entropy. Afterwards, per-layer tensors for two test
images are dumped as images and test accuracy is printed.
"""
import os
import random
import sys

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

# `tools` is project-local; ".." is resolved against the current working
# directory, so run the script from its own folder.
sys.path.append("..")
from tools import show  # noqa: E402

SEED = 42

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# DEVICE = torch.device("cpu")
# DEVICE = torch.device("mps")

NUM_EPOCHS = 1  # supervised epochs (conv2 + fc1)
UNSUPER_EPOCHS = 3  # unsupervised epochs (conv1)
UNSUPER_LR = 0.01
SUPER_LR = 0.2
BATCH_SIZE = 64


def _seed_everything(seed):
    """Seed torch (CPU and, if available, CUDA), NumPy and `random`."""
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)


class ConvNet(nn.Module):
    """Small MNIST classifier whose first conv layer can be trained without labels.

    Layout for 1x28x28 inputs: conv1 (1->8, 5x5) -> 2x2 max-pool ->
    conv2 (8->1, 5x5) -> 2x2 max-pool -> flatten (1*4*4) -> fc1 (-> 10 logits).

    conv1 is always applied through `normal_conv1_weight`, via plain
    `torch.conv2d`, so `conv1.bias` is created but never used.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 8, 5, 1, 0)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(8, 1, 5, 1, 0)
        self.fc1 = nn.Linear(1 * 4 * 4, 10)

    def forward(self, x):
        """Return class logits of shape (batch, 10)."""
        x = self.pool(self.forward_unsuper(x))
        x = self.pool(self.conv2(x))
        return self.fc1(torch.flatten(x, 1))

    def normal_conv1_weight(self):
        """Return conv1's weight with every output filter normalised.

        Each filter (all in_channels*kh*kw entries) is shifted to zero mean
        and then divided by its L1 norm. The result is differentiable
        w.r.t. `conv1.weight`.

        NOTE: a filter whose entries are all equal has zero L1 norm after
        centring, which would produce NaNs.
        """
        weight = self.conv1.weight
        flat = weight.reshape(weight.shape[0], -1)  # (out_channels, fan_in)
        flat = flat - flat.mean(dim=1, keepdim=True)
        flat = flat / flat.abs().sum(dim=1, keepdim=True)
        return flat.reshape(weight.shape)

    def forward_unsuper(self, x):
        """Apply conv1 with the normalised weight (stride 1, no bias, no padding)."""
        return torch.conv2d(x, self.normal_conv1_weight(), stride=1)

    def printFector(self, x, label, dir=""):
        """Dump the tensors of one forward/backward pass as images.

        Uses `show.DumpTensorToImage` to write PNGs under `dir` for:
        - the input;
        - conv1/conv2 weights and outputs;
        - the pooled map;
        - fc1 weights and logits;
        - the gradient of every weight that still has `requires_grad` set.

        Args:
            x: input batch. Assumed (1, 1, 28, 28) as fed by the batch-size-1
                test loader (TODO confirm for other callers); the logits of
                the whole batch are dumped as one flat vector.
            label: target class indices for `x`, used for the cross-entropy loss.
            dir: output directory prefix, joined to file names with "/".
        """
        # Start from clean gradients. Otherwise the *_grad.png images would
        # include whatever training (or a previous dump) left behind.
        self.zero_grad()

        def dump(t, name, **kwargs):
            show.DumpTensorToImage(t.detach().cpu(), dir + "/" + name, **kwargs)

        dump(x.view(-1, x.shape[2], x.shape[3]), "input_image.png", Contrast=[0, 1.0])

        w = self.normal_conv1_weight()
        x = torch.conv2d(x, w)
        dump(w.view(-1, w.shape[2], w.shape[3]), "conv1_weight.png")
        dump(x.view(-1, x.shape[2], x.shape[3]), "conv1_output.png")

        x = self.conv2(self.pool(x))
        w = self.conv2.weight
        dump(w.view(-1, w.shape[2], w.shape[3]), "conv2_weight.png")
        dump(x.view(-1, x.shape[2], x.shape[3]), "conv2_output.png")

        x = self.pool(x)
        dump(x.view(-1, x.shape[2], x.shape[3]), "pool_output.png")
        pool_h, pool_w = x.shape[2], x.shape[3]

        x = self.fc1(torch.flatten(x, 1))
        # Each row of fc1.weight is shown as a pool_h x pool_w map.
        dump(self.fc1.weight.view(-1, pool_h, pool_w), "fc_weight.png", Contrast=[-1.0, 1.0])
        dump(x.view(-1), "fc_output.png")

        loss = nn.CrossEntropyLoss()(x, label)
        loss.backward()
        if self.conv1.weight.requires_grad:
            g = self.conv1.weight.grad
            dump(g.view(-1, g.shape[2], g.shape[3]), "conv1_weight_grad.png")
        if self.conv2.weight.requires_grad:
            g = self.conv2.weight.grad
            dump(g.view(-1, g.shape[2], g.shape[3]), "conv2_weight_grad.png")
        if self.fc1.weight.requires_grad:
            dump(self.fc1.weight.grad.view(-1, pool_h, pool_w), "fc_weight_grad.png")


def _unsupervised_target(act_abs):
    """Build the regression target for conv1's unsupervised L1 loss.

    Args:
        act_abs: detached (N, C) tensor of absolute conv1 responses, one row
            per pixel and one column per filter.

    Returns:
        An (N, C) tensor. For each row with mean m, let
        k = 1 / sqrt(sum((a - m)^2)), with k = 1 when that sum is 0. Then
        t = (a / m * k) * a, where NaNs (e.g. all-zero rows) become 0.
        Finally t is shifted so that its row mean equals m.
    """
    mean = act_abs.mean(dim=1, keepdim=True)
    inv_dev = 1 / (act_abs - mean).pow(2).sum(dim=1, keepdim=True)
    inv_dev = torch.where(torch.isinf(inv_dev), 1.0, inv_dev).pow(0.5)
    ratio = act_abs / mean * inv_dev
    ratio = torch.where(torch.isnan(ratio), 0.0, ratio)
    target = ratio * act_abs
    return target - target.mean(dim=1, keepdim=True) + mean


def train_unsupervised(model, loader, epochs, lr=UNSUPER_LR):
    """Train `model.conv1.weight` by plain gradient descent against `_unsupervised_target`.

    After every step the weight is re-normalised with `normal_conv1_weight`.
    Labels from `loader` are ignored.

    Returns:
        The last image batch seen (on DEVICE), or None if no batch was drawn.
    """
    weight = model.conv1.weight
    n_total_steps = len(loader)
    last_images = None
    for epoch in range(epochs):
        for i, (images, _) in enumerate(loader):
            images = images.to(DEVICE)
            last_images = images
            outputs = model.forward_unsuper(images)
            # (B, C, H, W) -> (B*H*W, C): one row of filter responses per pixel.
            sample = outputs.permute(0, 2, 3, 1).reshape(-1, outputs.shape[1])
            sample_abs = sample.abs().detach()
            target = _unsupervised_target(sample_abs)
            active = sample_abs > 0  # pixels where at least one response is non-zero
            loss = F.l1_loss(sample.abs()[active], target[active])

            weight.grad = None
            loss.backward()
            with torch.no_grad():
                weight -= weight.grad * lr
                weight.copy_(model.normal_conv1_weight())

            if (i + 1) % 100 == 0:
                print(f"Epoch [{epoch+1}/{epochs}], Step [{i+1}/{n_total_steps}], Loss: {loss.item():.8f}")
    return last_images


def train_supervised(model, loader, epochs, lr=SUPER_LR):
    """Freeze conv1's weight and train the remaining parameters with SGD + cross-entropy."""
    model.conv1.weight.requires_grad = False
    model.conv2.weight.requires_grad = True
    model.fc1.weight.requires_grad = True
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=lr)
    n_total_steps = len(loader)
    for epoch in range(epochs):
        for i, (images, labels) in enumerate(loader):
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            loss = criterion(model(images), labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if (i + 1) % 100 == 0:
                print(f"Epoch [{epoch+1}/{epochs}], Step [{i+1}/{n_total_steps}], Loss: {loss.item():.4f}")


def evaluate(model, loader):
    """Return top-1 accuracy (percent) of `model` over every batch in `loader`."""
    model.eval()
    n_correct = 0
    n_samples = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            predicted = model(images).argmax(dim=1)
            n_samples += labels.size(0)
            n_correct += (predicted == labels).sum().item()
    return 100.0 * n_correct / n_samples


def main():
    _seed_everything(SEED)

    transform = transforms.Compose([transforms.ToTensor()])
    train_dataset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
    test_dataset = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

    model = ConvNet().to(DEVICE)
    model.train()

    # Phase 1: unsupervised training of conv1, then snapshot the last batch,
    # the last gradient and the resulting weight.
    images = train_unsupervised(model, train_loader, UNSUPER_EPOCHS)
    if images is not None:
        show.DumpTensorToImage(images.view(-1, images.shape[2], images.shape[3]), "input_image.png", Contrast=[0, 1.0])
        g = model.conv1.weight.grad
        show.DumpTensorToImage(g.view(-1, g.shape[2], g.shape[3]).cpu(), "conv1_weight_grad.png", Value2Log=True)
        w = model.conv1.weight.data
        show.DumpTensorToImage(w.view(-1, w.shape[2], w.shape[3]), "conv1_weight_update.png", Value2Log=True)

    # Phase 2: supervised training of conv2 + fc1.
    train_supervised(model, train_loader, NUM_EPOCHS)
    print("Finished Training")

    # Visualise the first two test images. This uses its own iterator so the
    # accuracy pass below still covers the entire test set.
    dump_iter = iter(torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False))
    for dump_dir in ("dump1", "dump2"):
        images, labels = next(dump_iter)
        model.printFector(images.to(DEVICE), labels.to(DEVICE), dump_dir)

    acc = evaluate(model, test_loader)
    print(f"Accuracy of the network on the {len(test_dataset)} test images: {acc} %")


if __name__ == "__main__":
    main()