# witnn/tools/Train.py
from torch.utils.data import Dataset, DataLoader
import numpy as np
import torchvision.models as models
from torchvision import datasets, transforms
import torchvision
import torch.optim as optim
import torch.nn.functional as F
import torch.nn as nn
import torch
import os
import utils as utils
def train(model, train_loader, optimizer, epoch=0):
    """Run one training epoch over ``train_loader``.

    Args:
        model: torch.nn.Module whose forward returns log-probabilities
            (loss is ``F.nll_loss``, so the model is expected to end in
            log-softmax — TODO confirm against the model definition).
        train_loader: DataLoader yielding ``(data, target)`` batches.
        optimizer: torch optimizer stepping ``model``'s parameters.
        epoch: epoch index, used only in the progress printout.
    """
    model.train()
    # Print progress roughly 5 times per epoch.  Derive the interval from
    # len(train_loader) (number of batches) instead of
    # train_loader.sampler.num_samples: the latter only exists on
    # RandomSampler and raises AttributeError with the default
    # SequentialSampler (shuffle=False).  The +1 guards against a zero
    # interval (and zero-division) for very short loaders.
    log_interval = len(train_loader) // 5 + 1
    for batch_idx, (data, target) in enumerate(train_loader):
        data = utils.SetDevice(data)
        target = utils.SetDevice(target)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        # Skip batch 0 so the first (uninformative) loss is not printed.
        if batch_idx % log_interval == 0 and batch_idx > 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'
                  .format(epoch, batch_idx * len(data),
                          len(train_loader.dataset),
                          100. * batch_idx / len(train_loader),
                          loss.item()))
def test(model, test_loader):
    """Evaluate ``model`` on ``test_loader``.

    Accumulates the summed NLL loss and the count of correct top-1
    predictions over the whole test set, prints the average loss and
    accuracy, and returns the accuracy as a percentage.
    """
    model.eval()
    total_loss = 0
    num_correct = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = utils.SetDevice(inputs)
            labels = utils.SetDevice(labels)
            logits = model(inputs)
            # reduction='sum' so the per-batch losses add up exactly;
            # divided by the dataset size below.
            total_loss += F.nll_loss(logits, labels, reduction='sum').item()
            # Top-1 prediction: index of the max log-probability per row.
            predictions = logits.argmax(dim=1, keepdim=True)
            num_correct += predictions.eq(labels.view_as(predictions)).sum().item()
    dataset_size = len(test_loader.dataset)
    total_loss /= dataset_size
    accu = 100. * num_correct / dataset_size
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'
          .format(total_loss, num_correct, dataset_size, accu))
    return accu