witnn/tools/UniModule.py

27 lines
732 B
Python
Raw Normal View History

2019-08-19 15:53:10 +08:00
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
import torchvision.models as models
import matplotlib.pyplot as plt
import numpy as np
class ModuleBase(nn.Module):
def __init__(self):
    """Run the ``nn.Module`` base-class initialisation and nothing else.

    NOTE(review): other methods of this class read ``self.features``, but
    this constructor does not create it -- presumably each subclass assigns
    it after calling this constructor; confirm against the subclasses.
    """
    # Explicit two-argument super() keeps the Python 2 compatibility the
    # module opts into via ``from __future__ import print_function``.
    super(ModuleBase, self).__init__()
def SetConvRequiresGrad(self, layer, requiregrad):
    """Set ``requires_grad`` on the parameters of one layer of ``self.features``.

    Typically used to freeze (``False``) or unfreeze (``True``) a single
    convolution layer.

    Args:
        layer: Index of the target layer inside ``self.features``.
        requiregrad: Value assigned to ``requires_grad`` on the layer's
            ``weight`` and, when the layer has one, its ``bias``.

    Raises:
        AttributeError: if the indexed layer has no ``weight`` or ``bias``
            attribute (e.g. an activation layer).
    """
    # NOTE(review): assumes ``self.features`` is an indexable module container
    # (e.g. nn.Sequential) provided by the subclass -- confirm.
    module = self.features[layer]  # look the layer up once, not three times
    module.weight.requires_grad = requiregrad
    # ``bias`` is None for layers created without a bias (e.g. bias=False).
    # Use an identity check: ``!=`` on a Tensor dispatches to Tensor's
    # overloaded comparison operator instead of a plain None test.
    if module.bias is not None:
        module.bias.requires_grad = requiregrad
def forwardLayer(self, x, layer=0):
layers = self.features[0:layer+1]
x = layers(x)