refine code
This commit is contained in:
		
							parent
							
								
									d50cb798b6
								
							
						
					
					
						commit
						b860d794a6
					
				|  | @ -15,12 +15,15 @@ learning_rate = 0.001 | ||||||
| 
 | 
 | ||||||
| # Dataset has PILImage images of range [0, 1]. | # Dataset has PILImage images of range [0, 1]. | ||||||
| # We transform them to Tensors of normalized range [-1, 1] | # We transform them to Tensors of normalized range [-1, 1] | ||||||
| transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) | transform = transforms.Compose([transforms.ToTensor()]) | ||||||
| 
 | 
 | ||||||
| # CIFAR10: 60000 32x32 color images in 10 classes, with 6000 images per class | # CIFAR10: 60000 32x32 color images in 10 classes, with 6000 images per class | ||||||
| train_dataset = torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform=transform) | # train_dataset = torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform=transform) | ||||||
|  | train_dataset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform) | ||||||
| 
 | 
 | ||||||
| test_dataset = torchvision.datasets.CIFAR10(root="./data", train=False, download=True, transform=transform) | 
 | ||||||
|  | # test_dataset = torchvision.datasets.CIFAR10(root="./data", train=False, download=True, transform=transform) | ||||||
|  | test_dataset = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform) | ||||||
| 
 | 
 | ||||||
| train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True) | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True) | ||||||
| 
 | 
 | ||||||
|  | @ -30,22 +33,22 @@ test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, s | ||||||
| class ConvNet(nn.Module): | class ConvNet(nn.Module): | ||||||
|     def __init__(self): |     def __init__(self): | ||||||
|         super(ConvNet, self).__init__() |         super(ConvNet, self).__init__() | ||||||
|         self.conv1 = nn.Conv2d(3, 6, 5) |         self.conv1 = nn.Conv2d(1, 6, 3, 1, 1) | ||||||
|         self.pool = nn.MaxPool2d(2, 2) |         self.pool = nn.MaxPool2d(2, 2) | ||||||
|         self.conv2 = nn.Conv2d(6, 16, 5) |         self.conv2 = nn.Conv2d(6, 16, 5) | ||||||
|         self.fc1 = nn.Linear(16 * 5 * 5, 120) |         self.fc1 = nn.Linear(16 * 5 * 5, 10) | ||||||
|         self.fc2 = nn.Linear(120, 84) |         # self.fc2 = nn.Linear(120, 84) | ||||||
|         self.fc3 = nn.Linear(84, 10) |         # self.fc3 = nn.Linear(84, 10) | ||||||
| 
 | 
 | ||||||
|     def forward(self, x): |     def forward(self, x): | ||||||
|         x = self.pool(F.relu(self.conv1(x))) |         x = self.pool(F.relu(self.conv1(x))) | ||||||
|         x = self.pool(F.relu(self.conv2(x))) |         x = self.pool(F.relu(self.conv2(x))) | ||||||
|         x = x.view(-1, 16 * 5 * 5) |         x = x.view(-1, 16 * 5 * 5) | ||||||
|         # x = F.relu(self.fc1(x)) |         # x = F.relu(self.fc1(x)) | ||||||
|         x = self.fc1(x) |  | ||||||
|         # x = F.relu(self.fc2(x)) |         # x = F.relu(self.fc2(x)) | ||||||
|         x = self.fc2(x) |         # x = self.fc3(x) | ||||||
|         x = self.fc3(x) | 
 | ||||||
|  |         x = self.fc1(x) | ||||||
|         return x |         return x | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @ -54,6 +57,7 @@ model = ConvNet().to(device) | ||||||
| criterion = nn.CrossEntropyLoss() | criterion = nn.CrossEntropyLoss() | ||||||
| optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate) | optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate) | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
| # Train the model | # Train the model | ||||||
| n_total_steps = len(train_loader) | n_total_steps = len(train_loader) | ||||||
| for epoch in range(num_epochs): | for epoch in range(num_epochs): | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue