from __future__ import print_function import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim import torchvision from torchvision import datasets, transforms import torchvision.models as models import matplotlib.pyplot as plt import numpy as np class ModuleBase(nn.Module): def __init__(self): super(ModuleBase, self).__init__() def SetConvRequiresGrad(self, layer, requiregrad): self.features[layer].weight.requires_grad = requiregrad b = self.features[layer].bias if b != None: self.features[layer].bias.requires_grad = requiregrad def forwardLayer(self, x, layer=0): layers = self.features[0:layer+1] x = layers(x)