witnn/CNNDemo/Resnet50.py

360 lines
11 KiB
Python
Raw Normal View History

2020-07-18 11:23:58 +08:00
from __future__ import print_function
import os
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
import torchvision.models as models
import matplotlib.pyplot as plt
import numpy as np
import struct
from struct import Struct
2020-07-20 15:45:02 +08:00
# m = nn.BatchNorm2d(1)
# m.weight.data = torch.ones(1)
# m.bias.data = torch.ones(1)
# m.running_mean.data = torch.ones(1)*2.0
# m.running_var.data = torch.zeros(1)
# # Without Learnable Parameters
# # m = nn.BatchNorm2d(1, affine=False)
# input = torch.ones(1, 1, 4, 4) * 2.0
# output = m(input)
# print(output)
2020-07-18 11:23:58 +08:00
# Directory of this script (with a trailing "/"); every input and output file
# used below is resolved relative to it.
CurrentPath = os.path.split(os.path.realpath(__file__))[0]+"/"

# params.pth caches the whole pickled ResNet-50 model.  Load it when present;
# otherwise fetch the pretrained torchvision weights once and create the cache,
# instead of downloading them on every run only to overwrite them.
# NOTE(review): torch.load unpickles arbitrary objects -- only load a
# params.pth you created yourself.  Newer PyTorch releases default to
# weights_only=True, which rejects whole-model pickles; pass
# weights_only=False there if loading fails (TODO confirm for your version).
ParamsPath = CurrentPath + 'params.pth'
if os.path.exists(ParamsPath):
    resnet50 = torch.load(ParamsPath)
else:
    resnet50 = models.resnet50(pretrained=True)
    torch.save(resnet50, ParamsPath)
# Evaluation mode: BatchNorm layers normalize with their running statistics,
# which are also the running_mean/running_var tensors printDick exports.
resnet50.eval()

print("===========================")
print("===========================")
print("===========================")
print(resnet50)
print("===========================")
print("===========================")
print("===========================")
# Layout of the torchvision ResNet-50 module tree, restricted to what printDick
# walks.  A dict value is a submodule to descend into ("_modules" holds the
# numbered children of an nn.Sequential); a string value names the module's
# class, which selects the tensors printDick exports for it.
# Key order is significant: printDick iterates this dict in order, and that
# order fixes where each tensor lands in the binary output.
def _bottleneckLayout(hasDownsample):
    """Return the layout of one bottleneck block.

    Every block holds three conv/bn pairs followed by a ReLU; the first block
    of each stage additionally carries a ``downsample`` Sequential made of a
    Conv2d and a BatchNorm2d.
    """
    layout = {}
    for n in (1, 2, 3):
        layout["conv%d" % n] = "Conv2d"
        layout["bn%d" % n] = "BatchNorm2d"
    layout["relu"] = "ReLU"
    if hasDownsample:
        layout["downsample"] = {"_modules": {"0": "Conv2d", "1": "BatchNorm2d"}}
    return layout


def _stageLayout(blockCount):
    """Return the layout of one ``layerN`` Sequential of ``blockCount`` blocks."""
    return {"_modules": {str(i): _bottleneckLayout(i == 0) for i in range(blockCount)}}


ResNet50 = {
    "conv1": "Conv2d",
    "bn1": "BatchNorm2d",
    "relu": "ReLU",
    "maxpool": "MaxPool2d",
    # Stages layer1..layer4 hold 3, 4, 6 and 3 bottleneck blocks respectively.
    **{"layer%d" % stage: _stageLayout(blocks)
       for stage, blocks in enumerate((3, 4, 6, 3), start=1)},
    "avgpool": "AdaptiveAvgPool2d",
    "fc": "Linear",
}
# Output files, both written next to this script:
#   ResNet50Weight.cc  - C declarations giving each exported tensor's byte range
#   ResNet50Weight.bin - the raw 4-byte floats of every exported tensor
weightfile = open(os.path.join(CurrentPath, 'ResNet50Weight.cc'), 'w')
binaryfile = open(os.path.join(CurrentPath, 'ResNet50Weight.bin'), 'wb')
# Running byte offset into ResNet50Weight.bin; every genData call advances it.
currentbyte = 0
2020-07-20 17:35:05 +08:00
def genData(name, data, currentbyte, binaryfile, strg):
    """Append tensor ``data`` to ``binaryfile`` and declare its byte range.

    The tensor is flattened in row-major order and written as consecutive
    4-byte floats in the host's native byte order (the same bytes as
    ``struct.pack("f", value)`` per element), starting at byte offset
    ``currentbyte`` of the blob.  A C declaration

        int <name>[] = { <first byte>,<last byte> };

    is appended to ``strg``.  The range is inclusive, so an empty tensor yields
    ``last == first - 1``.

    Args:
        name: identifier used for the C array.
        data: tensor to export; anything providing ``.cpu().detach().numpy()``.
        currentbyte: offset in the blob at which this tensor starts.
        binaryfile: writable binary file object receiving the raw floats.
        strg: declaration text accumulated so far.

    Returns:
        ``(next_byte, binaryfile, strg)``: the offset just past the written
        data, the same file object, and ``strg`` with the new declaration.
    """
    array = data.cpu().detach().numpy().reshape(-1)
    # A single bulk write instead of one struct.pack + write per element: the
    # model has millions of parameters, so the per-element loop was very slow.
    payload = array.astype(np.float32, copy=False).tobytes()
    binaryfile.write(payload)
    nextbyte = currentbyte + len(payload)
    strg = strg + "int " + name + "[] = { " + str(currentbyte) + "," + str(nextbyte - 1) + " };\n"
    return (nextbyte, binaryfile, strg)
2020-07-18 11:23:58 +08:00
def printDick(d, head, obj):
    """Recursively export the parameters of ``obj`` described by the spec ``d``.

    ``d`` maps attribute names (or container keys) of ``obj`` either to a
    nested spec dict, which is descended into, or to a module class name.
    For each leaf, these tensors are appended to the global ``binaryfile``
    through ``genData``, advancing the global ``currentbyte``, in the
    iteration order of ``d``:

    * ``"Conv2d"``      -> weight, plus bias when the conv has one
    * ``"BatchNorm2d"`` -> running_mean, running_var, weight, bias
    * ``"Linear"``      -> weight, bias

    Any other class name (e.g. ReLU, pooling) exports nothing.  Arrays are
    named ``head + "_" + key + "_" + tensor``.

    Args:
        d: spec dict describing ``obj`` (e.g. ``ResNet50``).
        head: name prefix for the generated arrays.
        obj: the module, or container of modules, that ``d`` describes.

    Returns:
        The C declarations (one ``int name[] = { first,last };`` line per
        exported tensor) produced by this call and its recursive calls.
    """
    global currentbyte
    global binaryfile
    # Unique sentinel: unlike '', it cannot compare equal to a real attribute.
    missing = object()
    strg = ""
    for item, kind in d.items():
        name = head + "_" + item
        # nn.Module children are attributes; containers such as a Sequential's
        # ``_modules`` mapping are indexed by key instead.
        objsub = getattr(obj, item, missing)
        if objsub is missing:
            objsub = obj[item]

        if isinstance(kind, dict):
            strg = strg + printDick(kind, name, objsub)
            continue

        if kind == "Conv2d":
            tensors = ["weight"]
            # A conv built with bias=False has bias None; export the bias only
            # when one exists instead of silently dropping it.
            if objsub.bias is not None:
                tensors.append("bias")
        elif kind == "BatchNorm2d":
            tensors = ["running_mean", "running_var", "weight", "bias"]
        elif kind == "Linear":
            tensors = ["weight", "bias"]
        else:
            # Parameter-free modules contribute nothing to the export.
            tensors = []

        for tensor in tensors:
            currentbyte, binaryfile, strg = genData(
                name + "_" + tensor, getattr(objsub, tensor),
                currentbyte, binaryfile, strg)
    return strg
2020-07-20 16:00:32 +08:00
2020-07-20 15:45:02 +08:00
# Export every tensor listed in the ResNet50 spec into the .bin blob and start
# the C declaration text with their byte ranges.
strg = printDick(ResNet50, "RN50", resnet50)

# Evaluation preprocessing: resize, take the central 224x224 crop, convert to
# a tensor, then normalize with the standard ImageNet per-channel mean/std.
ImageNetMean = [0.485, 0.456, 0.406]
ImageNetStd = [0.229, 0.224, 0.225]
normalize = transforms.Normalize(mean=ImageNetMean, std=ImageNetStd)
val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize,
])
val_dataset = datasets.ImageFolder(CurrentPath + 'ImageNet/', transform=val_transform)
# One image per batch, in dataset order (no shuffling).
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=1,
    pin_memory=True,
)
2020-07-18 11:23:58 +08:00
2020-07-20 15:45:02 +08:00
# Verification vectors: push the first validation image through ResNet-50 one
# top-level stage at a time, and export the input plus the activation after
# each stage as verify_<stage>, so an independent implementation can be
# checked step by step.  This is inference only, so no autograd graph is
# recorded.
VerifyStages = ("conv1", "bn1", "relu", "maxpool",
                "layer1", "layer2", "layer3", "layer4", "avgpool")
with torch.no_grad():
    for data, _ in val_loader:
        currentbyte, binaryfile, strg = genData(
            "verify_input", data, currentbyte, binaryfile, strg)
        x = data
        for stage in VerifyStages:
            x = getattr(resnet50, stage)(x)
            currentbyte, binaryfile, strg = genData(
                "verify_" + stage, x, currentbyte, binaryfile, strg)
        # Flatten everything after the batch dimension before the classifier.
        x = torch.flatten(x, 1)
        x = resnet50.fc(x)
        currentbyte, binaryfile, strg = genData(
            "verify_fc", x, currentbyte, binaryfile, strg)
        break  # a single image is enough for verification
2020-07-20 15:45:02 +08:00
2020-07-20 17:35:05 +08:00
# Export every validation image together with the network's final output for
# it, named input_<i> / output_<i>.  This is inference only, so skip building
# an autograd graph on every forward pass.
with torch.no_grad():
    for batch_idx, (data, _) in enumerate(val_loader):
        currentbyte, binaryfile, strg = genData(
            "input_" + str(batch_idx), data, currentbyte, binaryfile, strg)
        out = resnet50(data)
        currentbyte, binaryfile, strg = genData(
            "output_" + str(batch_idx), out, currentbyte, binaryfile, strg)
2020-07-20 15:45:02 +08:00
# Emit the collected C declarations, then release both output files.
weightfile.write(strg)
for handle in (binaryfile, weightfile):
    handle.close()

# Echo the declarations to stdout, followed by a three-line separator.
print(strg)
for _ in range(3):
    print("===========================")