diff --git a/unsuper/dump1/conv1_output.png b/unsuper/dump1/conv1_output.png
index e41e881..034f0a7 100644
Binary files a/unsuper/dump1/conv1_output.png and b/unsuper/dump1/conv1_output.png differ
diff --git a/unsuper/dump1/conv1_weight.png b/unsuper/dump1/conv1_weight.png
index e8681f8..4ee741a 100644
Binary files a/unsuper/dump1/conv1_weight.png and b/unsuper/dump1/conv1_weight.png differ
diff --git a/unsuper/dump1/conv1_weight_grad.png b/unsuper/dump1/conv1_weight_grad.png
index 3e65f94..2ef2fe4 100644
Binary files a/unsuper/dump1/conv1_weight_grad.png and b/unsuper/dump1/conv1_weight_grad.png differ
diff --git a/unsuper/dump1/conv2_output.png b/unsuper/dump1/conv2_output.png
index c720b6e..fb9737a 100644
Binary files a/unsuper/dump1/conv2_output.png and b/unsuper/dump1/conv2_output.png differ
diff --git a/unsuper/dump1/conv2_weight.png b/unsuper/dump1/conv2_weight.png
index 53f2703..f9bd915 100644
Binary files a/unsuper/dump1/conv2_weight.png and b/unsuper/dump1/conv2_weight.png differ
diff --git a/unsuper/dump1/conv2_weight_grad.png b/unsuper/dump1/conv2_weight_grad.png
index c1a5f4d..63e0b78 100644
Binary files a/unsuper/dump1/conv2_weight_grad.png and b/unsuper/dump1/conv2_weight_grad.png differ
diff --git a/unsuper/dump1/fc_output.png b/unsuper/dump1/fc_output.png
index bdb90d2..8a0200f 100644
Binary files a/unsuper/dump1/fc_output.png and b/unsuper/dump1/fc_output.png differ
diff --git a/unsuper/dump1/fc_weight.png b/unsuper/dump1/fc_weight.png
index e6f3b9c..3d66203 100644
Binary files a/unsuper/dump1/fc_weight.png and b/unsuper/dump1/fc_weight.png differ
diff --git a/unsuper/dump1/fc_weight_grad.png b/unsuper/dump1/fc_weight_grad.png
index 645ea9e..f86c1dc 100644
Binary files a/unsuper/dump1/fc_weight_grad.png and b/unsuper/dump1/fc_weight_grad.png differ
diff --git a/unsuper/dump1/pool_output.png b/unsuper/dump1/pool_output.png
index 19bb070..49eb371 100644
Binary files a/unsuper/dump1/pool_output.png and b/unsuper/dump1/pool_output.png differ
diff --git a/unsuper/dump2/conv1_output.png b/unsuper/dump2/conv1_output.png
index 052c356..d13915c 100644
Binary files a/unsuper/dump2/conv1_output.png and b/unsuper/dump2/conv1_output.png differ
diff --git a/unsuper/dump2/conv1_weight.png b/unsuper/dump2/conv1_weight.png
index e8681f8..4ee741a 100644
Binary files a/unsuper/dump2/conv1_weight.png and b/unsuper/dump2/conv1_weight.png differ
diff --git a/unsuper/dump2/conv1_weight_grad.png b/unsuper/dump2/conv1_weight_grad.png
index 3e65f94..2ef2fe4 100644
Binary files a/unsuper/dump2/conv1_weight_grad.png and b/unsuper/dump2/conv1_weight_grad.png differ
diff --git a/unsuper/dump2/conv2_output.png b/unsuper/dump2/conv2_output.png
index 421d456..58df023 100644
Binary files a/unsuper/dump2/conv2_output.png and b/unsuper/dump2/conv2_output.png differ
diff --git a/unsuper/dump2/conv2_weight.png b/unsuper/dump2/conv2_weight.png
index 53f2703..f9bd915 100644
Binary files a/unsuper/dump2/conv2_weight.png and b/unsuper/dump2/conv2_weight.png differ
diff --git a/unsuper/dump2/conv2_weight_grad.png b/unsuper/dump2/conv2_weight_grad.png
index f0f5e4a..0a9ae3a 100644
Binary files a/unsuper/dump2/conv2_weight_grad.png and b/unsuper/dump2/conv2_weight_grad.png differ
diff --git a/unsuper/dump2/fc_output.png b/unsuper/dump2/fc_output.png
index bdb90d2..d078719 100644
Binary files a/unsuper/dump2/fc_output.png and b/unsuper/dump2/fc_output.png differ
diff --git a/unsuper/dump2/fc_weight.png b/unsuper/dump2/fc_weight.png
index e6f3b9c..3d66203 100644
Binary files a/unsuper/dump2/fc_weight.png and b/unsuper/dump2/fc_weight.png differ
diff --git a/unsuper/dump2/fc_weight_grad.png b/unsuper/dump2/fc_weight_grad.png
index 645ea9e..f7041e8 100644
Binary files a/unsuper/dump2/fc_weight_grad.png and b/unsuper/dump2/fc_weight_grad.png differ
diff --git a/unsuper/dump2/pool_output.png b/unsuper/dump2/pool_output.png
index 19bb070..c08cd8d 100644
Binary files a/unsuper/dump2/pool_output.png and b/unsuper/dump2/pool_output.png differ
diff --git a/unsuper/minist.py b/unsuper/minist.py
index 3a98061..8a1dec2 100644
--- a/unsuper/minist.py
+++ b/unsuper/minist.py
@@ -16,10 +16,8 @@
 torch.cuda.manual_seed_all(seed)
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 # device = torch.device("mps")
-# Hyper-parameters
 num_epochs = 1
 batch_size = 64
-learning_rate = 0.2
 
 transform = transforms.Compose([transforms.ToTensor()])
@@ -36,93 +34,138 @@ class ConvNet(nn.Module):
         self.pool = nn.MaxPool2d(2, 2)
         self.conv2 = nn.Conv2d(8, 1, 5, 1, 0)
         self.fc1 = nn.Linear(1 * 4 * 4, 10)
-        # self.fc2 = nn.Linear(120, 84)
-        # self.fc3 = nn.Linear(84, 10)
 
     def forward(self, x):
-        x = self.pool(F.relu(self.conv1(x)))
-        x = self.pool(F.relu(self.conv2(x)))
+        x = self.pool(self.conv1(x))
+        x = self.pool(self.conv2(x))
         x = x.view(x.shape[0], -1)
-        # x = F.relu(self.fc1(x))
-        # x = F.relu(self.fc2(x))
-        # x = self.fc3(x)
-
         x = self.fc1(x)
         return x
 
-    def printFector(self, x, label):
-        show.DumpTensorToImage(x.view(-1, x.shape[2], x.shape[3]), "input_image.png", Contrast=[0, 1.0])
+    def forward_unsuper(self, x):
+        x = self.pool(self.conv1(x))
+        return x
+
+    def forward_finetune(self, x):
+        x = self.pool(self.conv1(x))
+        x = self.pool(self.conv2(x))
+        x = x.view(x.shape[0], -1)
+        x = self.fc1(x)
+        return x
+
+    def printFector(self, x, label, dir=""):
+        show.DumpTensorToImage(x.view(-1, x.shape[2], x.shape[3]), dir + "/input_image.png", Contrast=[0, 1.0])
         # show.DumpTensorToLog(x, "input_image.log")
         x = self.conv1(x)
         w = self.conv1.weight
-        show.DumpTensorToImage(w.view(-1, w.shape[2], w.shape[3]), "conv1_weight.png", Contrast=[-1.0, 1.0])
+        show.DumpTensorToImage(w.view(-1, w.shape[2], w.shape[3]), dir + "/conv1_weight.png", Contrast=[-1.0, 1.0])
         # show.DumpTensorToLog(w, "conv1_weight.log")
-        show.DumpTensorToImage(x.view(-1, x.shape[2], x.shape[3]), "conv1_output.png", Contrast=[-1.0, 1.0])
+        show.DumpTensorToImage(x.view(-1, x.shape[2], x.shape[3]), dir + "/conv1_output.png", Contrast=[-1.0, 1.0])
         # show.DumpTensorToLog(x, "conv1_output.png")
 
         x = self.pool(F.relu(x))
 
         x = self.conv2(x)
         w = self.conv2.weight
-        show.DumpTensorToImage(w.view(-1, w.shape[2], w.shape[3]).cpu(), "conv2_weight.png", Contrast=[-1.0, 1.0])
+        show.DumpTensorToImage(
+            w.view(-1, w.shape[2], w.shape[3]).cpu(), dir + "/conv2_weight.png", Contrast=[-1.0, 1.0]
+        )
 
-        show.DumpTensorToImage(x.view(-1, x.shape[2], x.shape[3]).cpu(), "conv2_output.png", Contrast=[-1.0, 1.0])
+        show.DumpTensorToImage(
+            x.view(-1, x.shape[2], x.shape[3]).cpu(), dir + "/conv2_output.png", Contrast=[-1.0, 1.0]
+        )
 
         x = self.pool(F.relu(x))
-        show.DumpTensorToImage(x.view(-1, x.shape[2], x.shape[3]).cpu(), "pool_output.png", Contrast=[-1.0, 1.0])
+        show.DumpTensorToImage(x.view(-1, x.shape[2], x.shape[3]).cpu(), dir + "/pool_output.png", Contrast=[-1.0, 1.0])
 
         pool_shape = x.shape
         x = x.view(x.shape[0], -1)
         x = self.fc1(x)
         show.DumpTensorToImage(
-            self.fc1.weight.view(-1, pool_shape[2], pool_shape[3]), "fc_weight.png", Contrast=[-1.0, 1.0]
+            self.fc1.weight.view(-1, pool_shape[2], pool_shape[3]), dir + "/fc_weight.png", Contrast=[-1.0, 1.0]
         )
-        show.DumpTensorToImage(x.view(-1).cpu(), "fc_output.png")
+        show.DumpTensorToImage(x.view(-1).cpu(), dir + "/fc_output.png")
 
         criterion = nn.CrossEntropyLoss()
         loss = criterion(x, label)
-        optimizer.zero_grad()
         loss.backward()
 
-        w = self.conv1.weight.grad
-        show.DumpTensorToImage(w.view(-1, w.shape[2], w.shape[3]).cpu(), "conv1_weight_grad.png")
-        w = self.conv2.weight.grad
-        show.DumpTensorToImage(w.view(-1, w.shape[2], w.shape[3]), "conv2_weight_grad.png")
-        show.DumpTensorToImage(self.fc1.weight.grad.view(-1, pool_shape[2], pool_shape[3]), "fc_weight_grad.png")
+        if self.conv1.weight.requires_grad:
+            w = self.conv1.weight.grad
+            show.DumpTensorToImage(w.view(-1, w.shape[2], w.shape[3]).cpu(), dir + "/conv1_weight_grad.png")
+        if self.conv2.weight.requires_grad:
+            w = self.conv2.weight.grad
+            show.DumpTensorToImage(w.view(-1, w.shape[2], w.shape[3]), dir + "/conv2_weight_grad.png")
+        if self.fc1.weight.requires_grad:
+            show.DumpTensorToImage(
+                self.fc1.weight.grad.view(-1, pool_shape[2], pool_shape[3]), dir + "/fc_weight_grad.png"
+            )
 
 
 model = ConvNet().to(device)
+model.train()
 
-criterion = nn.CrossEntropyLoss()
-optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
-
+# Train the model unsupervised: only conv1 is trainable in this stage
+epochs = 10
+model.conv1.weight.requires_grad = True
+model.conv2.weight.requires_grad = False
+model.fc1.weight.requires_grad = False
+optimizer_unsuper = torch.optim.SGD(model.parameters(), lr=0.1)
+n_total_steps = len(train_loader)
+for epoch in range(epochs):
+    for i, (images, labels) in enumerate(train_loader):
+        images = images.to(device)
+        outputs = model.forward_unsuper(images)
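+        # Flatten each sample's feature maps, center them on the sample mean,
+        # and normalize by the mean absolute deviation; diff_ratio_mean is the
+        # mean squared normalized deviation. The target is half the statistic,
+        # so the L1 loss equals 0.5 * diff_ratio_mean and training shrinks it.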
+        sample = outputs.view(outputs.shape[0], -1)
+        sample_mean = torch.mean(sample, dim=1, keepdim=True)
+        diff_mean = torch.mean(torch.abs(sample - sample_mean), dim=1, keepdim=True)
+        diff_ratio = (sample - sample_mean) / diff_mean
+        diff_ratio_mean = torch.mean(diff_ratio * diff_ratio, dim=1)
+        label = diff_ratio_mean * 0.5
+        loss = F.l1_loss(diff_ratio_mean, label)
+        optimizer_unsuper.zero_grad()
+        loss.backward()
+        optimizer_unsuper.step()
+        if (i + 1) % 100 == 0:
+            print(f"Epoch [{epoch+1}/{epochs}], Step [{i+1}/{n_total_steps}], Loss: {loss.item():.8f}")
 
 # Train the model
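+# Supervised fine-tuning: conv1 stays frozen; only conv2 and fc1 are updated,
+# and the optimizer is built over just the trainable parameters.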
+model.conv1.weight.requires_grad = False
+model.conv2.weight.requires_grad = True
+model.fc1.weight.requires_grad = True
+criterion = nn.CrossEntropyLoss()
+optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=0.6)
 n_total_steps = len(train_loader)
 for epoch in range(num_epochs):
     for i, (images, labels) in enumerate(train_loader):
         images = images.to(device)
         labels = labels.to(device)
-
-        # Forward pass
-        outputs = model(images)
+        outputs = model.forward_finetune(images)
         loss = criterion(outputs, labels)
-
-        # Backward and optimize
         optimizer.zero_grad()
         loss.backward()
         optimizer.step()
-
         if (i + 1) % 100 == 0:
             print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{n_total_steps}], Loss: {loss.item():.4f}")
 
-test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False)
-for images, labels in test_loader:
-    images = images.to(device)
-    labels = labels.to(device)
-    model.printFector(images, labels)
-    break
-
 print("Finished Training")
 
+test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False)
+test_loader = iter(test_loader)
+images, labels = next(test_loader)
+images = images.to(device)
+labels = labels.to(device)
+model.printFector(images, labels, "dump1")
+
+images, labels = next(test_loader)
+images = images.to(device)
+labels = labels.to(device)
+model.printFector(images, labels, "dump2")
+
 # Test the model
 with torch.no_grad():
     n_correct = 0