import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.models as models
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms

import utils as utils


def train(model, train_loader, optimizer, epoch=0):
    """Run one training epoch of ``model`` over ``train_loader``.

    Each batch is moved with ``utils.SetDevice``, passed through the model,
    scored with ``F.nll_loss`` and back-propagated, after which ``optimizer``
    takes a step. Progress is printed roughly five times per epoch.

    Args:
        model: Module whose output is fed to ``F.nll_loss``; it presumably
            returns log-probabilities (e.g. ends in ``log_softmax``) -- confirm.
        train_loader: DataLoader yielding ``(data, target)`` batches.
        optimizer: Optimizer over ``model``'s parameters.
        epoch: Epoch number, used only in the progress message.
    """
    model.train()
    # Log about five times per epoch. len(train_loader) is the number of
    # batches and works for every sampler; ``sampler.num_samples`` only exists
    # on some samplers (e.g. RandomSampler), so deriving the interval from it
    # crashed for non-shuffled loaders.
    log_interval = len(train_loader) // 5 + 1
    for batch_idx, (data, target) in enumerate(train_loader):
        # NOTE(review): utils.SetDevice presumably moves the tensor to the
        # configured device -- confirm.
        data = utils.SetDevice(data)
        target = utils.SetDevice(target)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0 and batch_idx > 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'
                  .format(epoch, batch_idx * len(data), len(train_loader.dataset),
                          100. * batch_idx / len(train_loader), loss.item()))


def test(model, test_loader):
    """Evaluate ``model`` on ``test_loader`` and return the accuracy in percent.

    Prints the average negative log-likelihood loss and the accuracy over the
    samples actually yielded by the loader.

    Args:
        model: Module whose output is fed to ``F.nll_loss``; it presumably
            returns log-probabilities -- confirm.
        test_loader: DataLoader yielding ``(data, target)`` batches.

    Returns:
        Accuracy as a percentage in ``[0, 100]``.

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    model.eval()
    test_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data = utils.SetDevice(data)
            target = utils.SetDevice(target)
            output = model(data)
            # Sum (not mean) per batch so the final division is a true
            # per-sample average even when the last batch is smaller.
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            # The index of the max log-probability is the predicted class.
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
            total += len(target)
    if total == 0:
        raise ValueError('test_loader yielded no samples')
    # Divide by the samples actually evaluated rather than len(dataset):
    # the two differ with subset samplers or drop_last=True.
    test_loss /= total
    accu = 100. * correct / total
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'
          .format(test_loss, correct, total, accu))
    return accu