witnn/tools/UniModule.py

37 lines
974 B
Python

from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
import torchvision.models as models
import matplotlib.pyplot as plt
import numpy as np
class ModuleBase(nn.Module):
    """Base class for networks whose layers live in ``self.features``.

    This class does not create ``self.features`` itself; subclasses (or the
    caller) must assign it. The methods below index and slice it, so it is
    presumably an ``nn.Sequential`` or similar container -- TODO confirm
    against subclasses.
    """

    def __init__(self):
        # Two-argument super() kept deliberately: the file imports
        # print_function from __future__, i.e. it still targets Python 2.
        super(ModuleBase, self).__init__()

    def SetConvRequiresGrad(self, layer, requiregrad):
        """Set ``requires_grad`` on the weight (and bias) of one feature layer.

        Args:
            layer: index into ``self.features``. The module at that index
                must expose ``weight`` and ``bias`` attributes (e.g. a
                convolution); otherwise ``AttributeError`` is raised.
            requiregrad: value assigned to ``requires_grad`` on the weight
                and, when the layer has one, on the bias.
        """
        module = self.features[layer]
        module.weight.requires_grad = requiregrad
        bias = module.bias
        # Identity test, not ``!=``: comparing a tensor with ``!=``
        # dispatches to the tensor's own comparison operator instead of a
        # plain None check.
        if bias is not None:
            bias.requires_grad = requiregrad

    def ForwardLayer(self, x, layer=0):
        """Run ``x`` through ``self.features[0]`` up to ``self.features[layer]``.

        Args:
            x: input fed to the first feature layer.
            layer: index of the last layer to apply (inclusive). A negative
                value means "apply nothing" and ``x`` is returned unchanged.

        Returns:
            The output of ``self.features[layer]``.
        """
        if layer < 0:
            return x
        # Slicing relies on self.features returning a callable container
        # for a slice, as nn.Sequential does.
        return self.features[0:layer + 1](x)

    def PrintLayer(self):
        """Return one ``'<index> : <module>'`` string per feature layer.

        Despite the name, nothing is printed; the list is returned to the
        caller.
        """
        return [str(index) + ' : ' + str(module)
                for index, module in enumerate(self.features)]