27 lines
732 B
Python
27 lines
732 B
Python
|
from __future__ import print_function
|
||
|
import torch
|
||
|
import torch.nn as nn
|
||
|
import torch.nn.functional as F
|
||
|
import torch.optim as optim
|
||
|
import torchvision
|
||
|
from torchvision import datasets, transforms
|
||
|
import torchvision.models as models
|
||
|
import matplotlib.pyplot as plt
|
||
|
import numpy as np
|
||
|
|
||
|
|
||
|
class ModuleBase(nn.Module):
|
||
|
def __init__(self):
    """Initialise the ``nn.Module`` base class; ModuleBase adds no state of its own."""
    # Explicit two-argument super(), in keeping with the file's
    # Python 2-compatible style (``from __future__ import print_function``).
    base = super(ModuleBase, self)
    base.__init__()
|
||
|
|
||
|
def SetConvRequiresGrad(self, layer, requiregrad):
    """Enable or disable gradient tracking for one layer's parameters.

    Sets ``requires_grad`` on the ``weight`` of ``self.features[layer]`` and,
    when that module's ``bias`` is not ``None``, on its ``bias`` as well.

    ``self.features`` is not created by ``ModuleBase`` itself; presumably a
    subclass assigns an indexable container of modules (e.g. an
    ``nn.Sequential``) -- TODO confirm against the subclasses.

    Args:
        layer: Index into ``self.features`` selecting the module to update.
        requiregrad: Value assigned to ``requires_grad`` (``False`` freezes
            the parameters, ``True`` unfreezes them).

    Raises:
        AttributeError: if the selected module has no ``weight`` or no
            ``bias`` attribute.
    """
    module = self.features[layer]  # look the module up once, not three times
    module.weight.requires_grad = requiregrad
    # Identity check: ``!=`` on a tensor invokes PyTorch's overloaded
    # comparison operator, whereas ``is not None`` is a plain None test.
    if module.bias is not None:
        module.bias.requires_grad = requiregrad
|
||
|
|
||
|
def forwardLayer(self, x, layer=0):
|
||
|
layers = self.features[0:layer+1]
|
||
|
x = layers(x)
|
||
|
|