fix resnet to eval static , fix batch normal error
This commit is contained in:
		
							parent
							
								
									ee0a12f7dd
								
							
						
					
					
						commit
						7a3dfde615
					
				
										
											Binary file not shown.
										
									
								
							|  | @ -15,12 +15,30 @@ import struct | ||||||
| from struct import Struct | from struct import Struct | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | # m = nn.BatchNorm2d(1) | ||||||
|  | 
 | ||||||
|  | # m.weight.data = torch.ones(1) | ||||||
|  | # m.bias.data = torch.ones(1) | ||||||
|  | # m.running_mean.data = torch.ones(1)*2.0 | ||||||
|  | # m.running_var.data = torch.zeros(1) | ||||||
|  | 
 | ||||||
|  | # # Without Learnable Parameters | ||||||
|  | # # m = nn.BatchNorm2d(1, affine=False) | ||||||
|  | # input = torch.ones(1, 1, 4, 4) * 2.0 | ||||||
|  | # output = m(input) | ||||||
|  | # print(output) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| CurrentPath = os.path.split(os.path.realpath(__file__))[0]+"/" | CurrentPath = os.path.split(os.path.realpath(__file__))[0]+"/" | ||||||
| 
 | 
 | ||||||
| 
 |  | ||||||
| resnet50 = models.resnet50(pretrained=True) | resnet50 = models.resnet50(pretrained=True) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| # torch.save(resnet50, CurrentPath+'params.pth') | # torch.save(resnet50, CurrentPath+'params.pth') | ||||||
| resnet50 = torch.load(CurrentPath+'params.pth') | resnet50 = torch.load(CurrentPath+'params.pth') | ||||||
|  | resnet50.eval() | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| print("===========================") | print("===========================") | ||||||
| print("===========================") | print("===========================") | ||||||
| print("===========================") | print("===========================") | ||||||
|  | @ -282,8 +300,26 @@ for batch_idx, (data, target) in enumerate(val_loader): | ||||||
|     currentbyte,binaryfile,strg = genData("verify_input", data, currentbyte, binaryfile, strg) |     currentbyte,binaryfile,strg = genData("verify_input", data, currentbyte, binaryfile, strg) | ||||||
|     x = resnet50.conv1(data) |     x = resnet50.conv1(data) | ||||||
|     currentbyte,binaryfile,strg = genData("verify_conv1", x, currentbyte, binaryfile, strg) |     currentbyte,binaryfile,strg = genData("verify_conv1", x, currentbyte, binaryfile, strg) | ||||||
|  |      | ||||||
|  | 
 | ||||||
|  |     print("bias") | ||||||
|  |     print(resnet50.bn1.bias.data) | ||||||
|  |     print("weight") | ||||||
|  |     print(resnet50.bn1.weight.data)   | ||||||
|  |     print("mean") | ||||||
|  |     print(resnet50.bn1.running_mean.data)   | ||||||
|  |     print("var") | ||||||
|  |     print(resnet50.bn1.running_var.data) | ||||||
|  |     print("input") | ||||||
|  |     print(x) | ||||||
|  | 
 | ||||||
|     x = resnet50.bn1(x) |     x = resnet50.bn1(x) | ||||||
|     currentbyte,binaryfile,strg = genData("verify_bn1", x, currentbyte, binaryfile, strg) |     currentbyte,binaryfile,strg = genData("verify_bn1", x, currentbyte, binaryfile, strg) | ||||||
|  |      | ||||||
|  | 
 | ||||||
|  |     print("output") | ||||||
|  |     print(x) | ||||||
|  |      | ||||||
|     x = resnet50.relu(x) |     x = resnet50.relu(x) | ||||||
|     currentbyte,binaryfile,strg = genData("verify_relu", x, currentbyte, binaryfile, strg) |     currentbyte,binaryfile,strg = genData("verify_relu", x, currentbyte, binaryfile, strg) | ||||||
|     x = resnet50.maxpool(x) |     x = resnet50.maxpool(x) | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue