diff --git a/unsuper/minist.py b/unsuper/minist.py
index 08d80ca..19a13cb 100644
--- a/unsuper/minist.py
+++ b/unsuper/minist.py
@@ -15,12 +15,15 @@ learning_rate = 0.001
 
 # Dataset has PILImage images of range [0, 1].
 # We transform them to Tensors of normalized range [-1, 1]
-transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
+transform = transforms.Compose([transforms.ToTensor()])
 
 # CIFAR10: 60000 32x32 color images in 10 classes, with 6000 images per class
-train_dataset = torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
+# train_dataset = torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
+train_dataset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
 
-test_dataset = torchvision.datasets.CIFAR10(root="./data", train=False, download=True, transform=transform)
+
+# test_dataset = torchvision.datasets.CIFAR10(root="./data", train=False, download=True, transform=transform)
+test_dataset = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform)
 
 train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
 
@@ -30,22 +33,22 @@ test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, s
 class ConvNet(nn.Module):
     def __init__(self):
         super(ConvNet, self).__init__()
-        self.conv1 = nn.Conv2d(3, 6, 5)
+        self.conv1 = nn.Conv2d(1, 6, 3, 1, 1)
         self.pool = nn.MaxPool2d(2, 2)
         self.conv2 = nn.Conv2d(6, 16, 5)
-        self.fc1 = nn.Linear(16 * 5 * 5, 120)
-        self.fc2 = nn.Linear(120, 84)
-        self.fc3 = nn.Linear(84, 10)
+        self.fc1 = nn.Linear(16 * 5 * 5, 10)
+        # self.fc2 = nn.Linear(120, 84)
+        # self.fc3 = nn.Linear(84, 10)
 
     def forward(self, x):
         x = self.pool(F.relu(self.conv1(x)))
         x = self.pool(F.relu(self.conv2(x)))
         x = x.view(-1, 16 * 5 * 5)
         # x = F.relu(self.fc1(x))
-        x = self.fc1(x)
         # x = F.relu(self.fc2(x))
-        x = self.fc2(x)
-        x = self.fc3(x)
+        # x = self.fc3(x)
+
+        x = self.fc1(x)
         return x
 
 
@@ -54,6 +57,7 @@ model = ConvNet().to(device)
 criterion = nn.CrossEntropyLoss()
 optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
 
+
 # Train the model
 n_total_steps = len(train_loader)
 for epoch in range(num_epochs):
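
Side note (not part of the patch): with Normalize removed, the MNIST tensors stay in [0, 1] rather than [-1, 1], so the comment above the transform is now slightly stale. The layer-size arithmetic in the new ConvNet does work out for 28x28 MNIST input: conv1 (kernel 3, stride 1, padding 1) keeps 28x28, the pool halves it to 14x14, conv2 (kernel 5, no padding) gives 10x10, and the second pool gives 5x5, so the flattened feature is 16 * 5 * 5 = 400, which matches fc1. The following is a minimal standalone sketch to sanity-check those shapes; the ConvNet below simply mirrors the post-patch class and is not itself part of the diff.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ConvNet(nn.Module):
    # Mirrors the post-patch model: 1-channel input, single fully connected layer.
    def __init__(self):
        super(ConvNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 3, 1, 1)  # 1x28x28 -> 6x28x28 (padding preserves size)
        self.pool = nn.MaxPool2d(2, 2)         # halves spatial dimensions
        self.conv2 = nn.Conv2d(6, 16, 5)       # 6x14x14 -> 16x10x10
        self.fc1 = nn.Linear(16 * 5 * 5, 10)   # 400 features -> 10 classes

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))   # -> 6x14x14
        x = self.pool(F.relu(self.conv2(x)))   # -> 16x5x5
        x = x.view(-1, 16 * 5 * 5)             # flatten to 400
        x = self.fc1(x)
        return x


if __name__ == "__main__":
    model = ConvNet()
    dummy = torch.randn(4, 1, 28, 28)          # batch of 4 fake MNIST-sized images
    out = model(dummy)
    print(out.shape)                           # expected: torch.Size([4, 10])
```

If the printed shape is torch.Size([4, 10]), the fc1 input size of 16 * 5 * 5 is consistent with the new conv1 settings.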