fix resnet to eval static , fix batch normal error
This commit is contained in:
parent
ee0a12f7dd
commit
7a3dfde615
Binary file not shown.
|
@ -15,12 +15,30 @@ import struct
|
||||||
from struct import Struct
|
from struct import Struct
|
||||||
|
|
||||||
|
|
||||||
|
# m = nn.BatchNorm2d(1)
|
||||||
|
|
||||||
|
# m.weight.data = torch.ones(1)
|
||||||
|
# m.bias.data = torch.ones(1)
|
||||||
|
# m.running_mean.data = torch.ones(1)*2.0
|
||||||
|
# m.running_var.data = torch.zeros(1)
|
||||||
|
|
||||||
|
# # Without Learnable Parameters
|
||||||
|
# # m = nn.BatchNorm2d(1, affine=False)
|
||||||
|
# input = torch.ones(1, 1, 4, 4) * 2.0
|
||||||
|
# output = m(input)
|
||||||
|
# print(output)
|
||||||
|
|
||||||
|
|
||||||
CurrentPath = os.path.split(os.path.realpath(__file__))[0]+"/"
|
CurrentPath = os.path.split(os.path.realpath(__file__))[0]+"/"
|
||||||
|
|
||||||
|
|
||||||
resnet50 = models.resnet50(pretrained=True)
|
resnet50 = models.resnet50(pretrained=True)
|
||||||
|
|
||||||
|
|
||||||
# torch.save(resnet50, CurrentPath+'params.pth')
|
# torch.save(resnet50, CurrentPath+'params.pth')
|
||||||
resnet50 = torch.load(CurrentPath+'params.pth')
|
resnet50 = torch.load(CurrentPath+'params.pth')
|
||||||
|
resnet50.eval()
|
||||||
|
|
||||||
|
|
||||||
print("===========================")
|
print("===========================")
|
||||||
print("===========================")
|
print("===========================")
|
||||||
print("===========================")
|
print("===========================")
|
||||||
|
@ -282,8 +300,26 @@ for batch_idx, (data, target) in enumerate(val_loader):
|
||||||
currentbyte,binaryfile,strg = genData("verify_input", data, currentbyte, binaryfile, strg)
|
currentbyte,binaryfile,strg = genData("verify_input", data, currentbyte, binaryfile, strg)
|
||||||
x = resnet50.conv1(data)
|
x = resnet50.conv1(data)
|
||||||
currentbyte,binaryfile,strg = genData("verify_conv1", x, currentbyte, binaryfile, strg)
|
currentbyte,binaryfile,strg = genData("verify_conv1", x, currentbyte, binaryfile, strg)
|
||||||
|
|
||||||
|
|
||||||
|
print("bias")
|
||||||
|
print(resnet50.bn1.bias.data)
|
||||||
|
print("weight")
|
||||||
|
print(resnet50.bn1.weight.data)
|
||||||
|
print("mean")
|
||||||
|
print(resnet50.bn1.running_mean.data)
|
||||||
|
print("var")
|
||||||
|
print(resnet50.bn1.running_var.data)
|
||||||
|
print("input")
|
||||||
|
print(x)
|
||||||
|
|
||||||
x = resnet50.bn1(x)
|
x = resnet50.bn1(x)
|
||||||
currentbyte,binaryfile,strg = genData("verify_bn1", x, currentbyte, binaryfile, strg)
|
currentbyte,binaryfile,strg = genData("verify_bn1", x, currentbyte, binaryfile, strg)
|
||||||
|
|
||||||
|
|
||||||
|
print("output")
|
||||||
|
print(x)
|
||||||
|
|
||||||
x = resnet50.relu(x)
|
x = resnet50.relu(x)
|
||||||
currentbyte,binaryfile,strg = genData("verify_relu", x, currentbyte, binaryfile, strg)
|
currentbyte,binaryfile,strg = genData("verify_relu", x, currentbyte, binaryfile, strg)
|
||||||
x = resnet50.maxpool(x)
|
x = resnet50.maxpool(x)
|
||||||
|
|
Loading…
Reference in New Issue