witnn/tools/Train.py

import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm

import utils


def train(model, train_loader, optimizer, epoch=0):
    """Run one training epoch over train_loader."""
    model.train()
    totalsize = len(train_loader.dataset)
    # Refresh the progress-bar description roughly five times per epoch.
    log_interval = int(totalsize / train_loader.batch_size / 5) + 1
    pbar = tqdm(total=totalsize)
    for batch_idx, (data, target) in enumerate(train_loader):
        data = utils.SetDevice(data)
        target = utils.SetDevice(target)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        pbar.update(train_loader.batch_size)
        if batch_idx % log_interval == 0 and batch_idx > 0:
            pbar.set_description("Loss: {:.4f}".format(loss.item()))
    pbar.close()
def test(model, test_loader):
    """Evaluate model on test_loader and return accuracy in percent."""
    with torch.no_grad():
        model.eval()
        test_loss = 0
        correct = 0
        for data, target in test_loader:
            data = utils.SetDevice(data)
            target = utils.SetDevice(target)
            output = model(data)
            # Sum up the batch loss.
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            # Get the index of the max log-probability.
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
        test_loss /= len(test_loader.dataset)
        accu = 100. * correct / len(test_loader.dataset)
        # print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'
        #       .format(test_loss, correct, len(test_loader.dataset), accu))
        return accu
def TrainEpochs(model, traindata, optimizer, testdata, epoch=100, testepoch=10, line=None):
    """Train for `epoch` epochs, evaluating every `testepoch` epochs.

    If `line` is given, each test accuracy is appended to it via line.AppendData().
    """
    epochbar = tqdm(total=epoch)
    for i in range(epoch):
        train(model, traindata, optimizer, epoch=i)
        if line and i % testepoch == 0 and i > 0:
            line.AppendData(test(model, testdata))
        epochbar.update(1)
    epochbar.close()
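

# --- Usage sketch (not part of the original file) ----------------------------
# A minimal example of how these helpers might be wired together, assuming:
#   * utils.SetDevice(obj) moves a tensor/module to the project's device and
#     returns it (its real signature is not shown in this file);
#   * the model ends in log_softmax, since train()/test() use F.nll_loss.
# The dataset, network, and hyperparameters below are illustrative only.
if __name__ == "__main__":
    class SmallNet(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = nn.Linear(28 * 28, 128)
            self.fc2 = nn.Linear(128, 10)

        def forward(self, x):
            x = x.view(x.size(0), -1)
            x = F.relu(self.fc1(x))
            return F.log_softmax(self.fc2(x), dim=1)

    transform = transforms.ToTensor()
    trainset = datasets.MNIST("./data", train=True, download=True, transform=transform)
    testset = datasets.MNIST("./data", train=False, download=True, transform=transform)
    trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
    testloader = DataLoader(testset, batch_size=1000, shuffle=False)

    model = utils.SetDevice(SmallNet())  # assumed to return the module on-device
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

    TrainEpochs(model, trainloader, optimizer, testloader, epoch=5, testepoch=2)
    print("final accuracy: {:.2f}%".format(test(model, testloader)))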