witnn/CNNDemo/CheckResnet50.py

68 lines
1.3 KiB
Python
Raw Normal View History

2020-07-27 20:27:00 +08:00
from __future__ import print_function
import os
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
import torchvision.models as models
import matplotlib.pyplot as plt
import numpy as np
import struct
from struct import Struct
# Directory containing this script, with a trailing "/" so file names can be
# appended by plain string concatenation.
CurrentPath = os.path.split(os.path.realpath(__file__))[0] + "/"

# params.pth is presumably produced once from torchvision's pretrained model:
#     resnet50 = models.resnet50(pretrained=True)
#     torch.save(resnet50, CurrentPath + 'params.pth')
# which stores a whole pickled nn.Module rather than a state_dict.
try:
    # PyTorch >= 2.6 defaults to weights_only=True, which refuses to unpickle
    # a full nn.Module. This file is generated locally, so full unpickling is
    # acceptable here -- never point this at an untrusted file (pickle can
    # execute arbitrary code).
    resnet50 = torch.load(CurrentPath + 'params.pth', weights_only=False)
except TypeError:
    # Older PyTorch (< 1.13) has no weights_only parameter.
    resnet50 = torch.load(CurrentPath + 'params.pth')
resnet50.eval()

# NOTE(review): weightfile is not read anywhere in the visible script and is
# never closed; it is kept only so the module-level name still exists.
# Consider removing it.
weightfile = open(CurrentPath + 'ResNet50Weight.cc', 'r')

# Load the whole binary dump into memory; toNumpy() decodes floats from it.
# The context manager closes the handle as soon as the bytes are read.
with open(CurrentPath + 'ResNet50Weight.bin', 'rb') as binaryfile:
    bindata = binaryfile.read()
def toNumpy(start, end, data=None):
    """Decode native-byte-order float32 values from a byte buffer.

    One 4-byte float is read at every byte offset in ``range(start, end, 4)``.
    Consequently either the exclusive end (``start + 4*n``) or the inclusive
    index of the last byte (``start + 4*n - 1``) yields ``n`` values.

    Args:
        start: Byte offset of the first float.
        end: Upper bound for the per-float start offsets (see above).
        data: Bytes-like buffer to decode. Defaults to the module-level
            ``bindata``.

    Returns:
        A writable 1-D float64 numpy array of ``len(range(start, end, 4))``
        values (empty when ``start >= end``).

    Raises:
        ValueError: If the requested floats run past the end of ``data`` or
            ``start`` is negative.
    """
    if data is None:
        data = bindata
    count = len(range(start, end, 4))
    if count <= 0:
        return np.array([])
    # One vectorized decode instead of a struct.unpack call per value;
    # np.float32 uses native byte order, like struct's plain "f" format.
    values = np.frombuffer(data, dtype=np.float32, count=count, offset=start)
    # Building the array from Python floats (the previous approach) yielded
    # float64, so keep that dtype; astype also returns a writable copy of
    # the read-only frombuffer view.
    return values.astype(np.float64)
# NOTE(review): these byte offsets depend on the layout of ResNet50Weight.bin,
# which cannot be checked from this file. Each span does decode to exactly the
# element count its reshape below needs:
#   weight:    ceil((55039 - 38656) / 4)          = 4096 floats
#   inp, outp: (126376575 - 125573760 + 1) / 4    = 200704 floats = 64*56*56
weight = toNumpy(38656, 55039)
inp = toNumpy(125573760, 126376575)
outp = toNumpy(126376576, 127179391)

# Check a single layer: conv1 of the first block in resnet50.layer1.
m = resnet50.layer1._modules['0'].conv1
weight = weight.reshape(m.weight.shape)
inp = inp.reshape(1, 64, 56, 56)
outp = outp.reshape(64, 56, 56)

# Replace the layer's weights with the dumped ones and run it on the dumped
# input. toNumpy returns float64, so this layer now computes in double.
m.weight.data = torch.from_numpy(weight)
with torch.no_grad():
    output = m(torch.from_numpy(inp))
output = output.cpu().detach().numpy()

# output keeps the batch axis, (1, 64, 56, 56), while outp is (64, 56, 56):
# index both at the same place (row 0 of channel 0) so the prints are
# comparable, then report the worst element-wise mismatch over all channels.
print(output[0, 0, 0, :])
print(outp[0, 0, :])
print("max abs diff:", np.abs(output[0] - outp).max())
# NOTE(review): i is not used anywhere in the visible script.
i = 0