fix resnet to eval static , fix batch normal error

This commit is contained in:
colin 2020-07-23 15:26:40 +08:00
parent ee0a12f7dd
commit 7a3dfde615
2 changed files with 37 additions and 1 deletions

Binary file not shown.

View File

@ -15,12 +15,30 @@ import struct
from struct import Struct from struct import Struct
# m = nn.BatchNorm2d(1)
# m.weight.data = torch.ones(1)
# m.bias.data = torch.ones(1)
# m.running_mean.data = torch.ones(1)*2.0
# m.running_var.data = torch.zeros(1)
# # Without Learnable Parameters
# # m = nn.BatchNorm2d(1, affine=False)
# input = torch.ones(1, 1, 4, 4) * 2.0
# output = m(input)
# print(output)
CurrentPath = os.path.split(os.path.realpath(__file__))[0]+"/" CurrentPath = os.path.split(os.path.realpath(__file__))[0]+"/"
resnet50 = models.resnet50(pretrained=True) resnet50 = models.resnet50(pretrained=True)
# torch.save(resnet50, CurrentPath+'params.pth') # torch.save(resnet50, CurrentPath+'params.pth')
resnet50 = torch.load(CurrentPath+'params.pth') resnet50 = torch.load(CurrentPath+'params.pth')
resnet50.eval()
print("===========================") print("===========================")
print("===========================") print("===========================")
print("===========================") print("===========================")
@ -282,8 +300,26 @@ for batch_idx, (data, target) in enumerate(val_loader):
currentbyte,binaryfile,strg = genData("verify_input", data, currentbyte, binaryfile, strg) currentbyte,binaryfile,strg = genData("verify_input", data, currentbyte, binaryfile, strg)
x = resnet50.conv1(data) x = resnet50.conv1(data)
currentbyte,binaryfile,strg = genData("verify_conv1", x, currentbyte, binaryfile, strg) currentbyte,binaryfile,strg = genData("verify_conv1", x, currentbyte, binaryfile, strg)
print("bias")
print(resnet50.bn1.bias.data)
print("weight")
print(resnet50.bn1.weight.data)
print("mean")
print(resnet50.bn1.running_mean.data)
print("var")
print(resnet50.bn1.running_var.data)
print("input")
print(x)
x = resnet50.bn1(x) x = resnet50.bn1(x)
currentbyte,binaryfile,strg = genData("verify_bn1", x, currentbyte, binaryfile, strg) currentbyte,binaryfile,strg = genData("verify_bn1", x, currentbyte, binaryfile, strg)
print("output")
print(x)
x = resnet50.relu(x) x = resnet50.relu(x)
currentbyte,binaryfile,strg = genData("verify_relu", x, currentbyte, binaryfile, strg) currentbyte,binaryfile,strg = genData("verify_relu", x, currentbyte, binaryfile, strg)
x = resnet50.maxpool(x) x = resnet50.maxpool(x)