2019-08-19 15:53:10 +08:00
|
|
|
from __future__ import print_function
|
|
|
|
import torch
|
|
|
|
import torch.nn as nn
|
|
|
|
import torch.nn.functional as F
|
|
|
|
import torch.optim as optim
|
|
|
|
import torchvision
|
|
|
|
from torchvision import datasets, transforms
|
|
|
|
import torchvision.models as models
|
|
|
|
import matplotlib.pyplot as plt
|
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
|
|
|
|
class ModuleBase(nn.Module):
    """Base ``nn.Module`` with helpers that operate on ``self.features``.

    ``self.features`` is never created here; subclasses are expected to
    assign it. The helpers index it, slice it and call the slice, so it
    presumably is an ``nn.Sequential`` (TODO confirm against subclasses).
    """

    def __init__(self):
        super(ModuleBase, self).__init__()

    def SetConvRequiresGrad(self, layer, requiregrad):
        """Enable or disable gradients for the parameters of one layer.

        Args:
            layer: Index into ``self.features``. The module there must
                have a ``weight`` attribute (e.g. a conv layer); otherwise
                ``AttributeError`` is raised.
            requiregrad: Value assigned to ``requires_grad`` of the
                layer's ``weight`` and, if present, its ``bias``.
        """
        module = self.features[layer]
        module.weight.requires_grad = requiregrad
        bias = module.bias
        # Layers built with bias=False expose ``bias`` as None. Use an
        # identity check rather than ``!=``, which would dispatch to the
        # tensor's own comparison operator.
        if bias is not None:
            bias.requires_grad = requiregrad

    def ForwardLayer(self, x, layer=0):
        """Run ``x`` through ``self.features[0]`` .. ``self.features[layer]``.

        Args:
            x: Input passed to the first feature layer.
            layer: Index of the last layer to apply (inclusive). A negative
                value applies no layers and returns ``x`` unchanged.

        Returns:
            The output of the last applied layer, or ``x`` itself when
            ``layer < 0``.
        """
        if layer < 0:
            return x
        # Slicing keeps layers 0..layer inclusive; the slice is then called
        # as a single module.
        prefix = self.features[0:layer + 1]
        return prefix(x)

    def PrintLayer(self):
        """Return one ``"<index> : <module repr>"`` string per feature layer.

        Despite the name, nothing is printed; the caller receives the list.
        """
        return [str(i) + ' : ' + str(module)
                for i, module in enumerate(self.features)]
|