from __future__ import print_function import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim import torchvision from torchvision import datasets, transforms import torchvision.models as models import matplotlib.pyplot as plt import numpy as np class ModuleBase(nn.Module): def __init__(self): super(ModuleBase, self).__init__() def SetConvRequiresGrad(self, layer, requiregrad): self.features[layer].weight.requires_grad = requiregrad b = self.features[layer].bias if b != None: self.features[layer].bias.requires_grad = requiregrad def ForwardLayer(self, x, layer=0): if layer < 0: return x layers = self.features[0:layer+1] x = layers(x) return x def PrintLayer(self): names = [] i = 0 for l in self.features: names.append(str(i)+' : '+str(l)) i = i+1 return names