63 lines
		
	
	
		
			2.2 KiB
		
	
	
	
		
			Python
		
	
	
	
			
		
		
	
	
			63 lines
		
	
	
		
			2.2 KiB
		
	
	
	
		
			Python
		
	
	
	
from torch.utils.data import Dataset, DataLoader
 | 
						|
import numpy as np
 | 
						|
import torchvision.models as models
 | 
						|
from torchvision import datasets, transforms
 | 
						|
import torchvision
 | 
						|
import torch.optim as optim
 | 
						|
import torch.nn.functional as F
 | 
						|
import torch.nn as nn
 | 
						|
import torch
 | 
						|
import os
 | 
						|
import utils as utils
 | 
						|
from tqdm import tqdm
 | 
						|
 | 
						|
 | 
						|
def train(model, train_loader, optimizer, epoch=0):
    """Run one training epoch of ``model`` over ``train_loader``.

    Each batch is moved with ``utils.SetDevice``, passed through the model,
    scored with ``F.nll_loss`` and used for a single ``optimizer`` step.
    A tqdm bar counts processed samples. Roughly five times per epoch it
    shows the most recent batch loss.

    Args:
        model: Network to train; it is switched to training mode.
            Presumably it outputs log-probabilities, since the loss is
            ``F.nll_loss`` -- confirm against the model definition.
        train_loader: ``DataLoader`` yielding ``(data, target)`` batches.
            Its batch sampler's sampler must support ``len()``.
        optimizer: Optimizer over ``model``'s parameters.
        epoch: Index of the current epoch. Unused here; kept so existing
            callers that pass it keep working.
    """
    model.train()
    # len(sampler) equals RandomSampler.num_samples, but unlike that
    # attribute it also works for SequentialSampler, SubsetRandomSampler, ...
    total_samples = len(train_loader.batch_sampler.sampler)
    # Update the loss shown on the bar about five times per epoch (never 0).
    log_interval = len(train_loader) // 5 + 1
    # `total=` must be passed by keyword: the first positional argument of
    # tqdm is the iterable, not the expected count.
    with tqdm(total=total_samples) as pbar:
        for batch_idx, (data, target) in enumerate(train_loader):
            data = utils.SetDevice(data)
            target = utils.SetDevice(target)
            optimizer.zero_grad()
            output = model(data)
            loss = F.nll_loss(output, target)
            loss.backward()
            optimizer.step()
            # Advance by the real batch length so a smaller final batch
            # does not push the bar past its total.
            pbar.update(len(target))
            if batch_idx > 0 and batch_idx % log_interval == 0:
                pbar.set_description("Loss:" + str(loss.item()))
 | 
						|
 | 
						|
def test(model, test_loader):
    """Return the classification accuracy of ``model`` on ``test_loader``.

    The model is put in evaluation mode and run without gradient tracking.
    A prediction is the arg-max of the model output over dimension 1.

    Args:
        model: Network to evaluate.
        test_loader: ``DataLoader`` yielding ``(data, target)`` batches.

    Returns:
        Accuracy in percent (``100 * correct / n``) as a float, where ``n``
        is the number of samples the loader actually yielded. Returns
        ``0.0`` when the loader yields no samples.
    """
    model.eval()
    correct = 0
    seen = 0
    with torch.no_grad():
        for data, target in test_loader:
            data = utils.SetDevice(data)
            target = utils.SetDevice(target)
            output = model(data)
            # Index of the highest score per sample (first one on ties).
            pred = output.argmax(dim=1)
            correct += pred.eq(target.view_as(pred)).sum().item()
            # Count evaluated samples instead of using len(dataset), so
            # subset samplers / drop_last do not skew the percentage.
            seen += pred.size(0)
    if seen == 0:
        return 0.0
    return 100. * correct / seen
 | 
						|
 | 
						|
 | 
						|
def TrainEpochs(model, traindata, optimizer, testdata, epoch=100, testepoch=10, line=None):
    """Train ``model`` for ``epoch`` epochs, periodically recording test accuracy.

    Each epoch calls ``train(model, traindata, optimizer, epoch=i)``.
    Evaluation runs after every ``testepoch``-th completed epoch (epochs
    ``testepoch``, ``2*testepoch``, ...), but only when ``line`` is given.
    Its result, ``test(model, testdata)``, is passed to ``line.AppendData``.
    A tqdm bar counts finished epochs.

    Args:
        model: Network to train and evaluate.
        traindata: Training ``DataLoader``.
        optimizer: Optimizer over ``model``'s parameters.
        testdata: Evaluation ``DataLoader``.
        epoch: Number of epochs to train.
        testepoch: Evaluate every this many completed epochs. Values
            ``<= 0`` disable evaluation.
        line: Optional recorder exposing ``AppendData(value)``. If ``None``,
            no evaluation is performed.
    """
    with tqdm(total=epoch) as epochbar:
        for i in range(epoch):
            train(model, traindata, optimizer, epoch=i)
            epochs_done = i + 1
            # `is not None` rather than truthiness: a recorder that defines
            # __len__ would otherwise be skipped while it is still empty.
            if line is not None and testepoch > 0 and epochs_done % testepoch == 0:
                line.AppendData(test(model, testdata))
            epochbar.update(1)
 |