from __future__ import print_function import os import sys import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim import torchvision from torchvision import datasets, transforms import torchvision.models as models import matplotlib.pyplot as plt import numpy as np import struct from struct import Struct # m = nn.BatchNorm2d(1) # m.weight.data = torch.ones(1) # m.bias.data = torch.ones(1) # m.running_mean.data = torch.ones(1)*2.0 # m.running_var.data = torch.zeros(1) # # Without Learnable Parameters # # m = nn.BatchNorm2d(1, affine=False) # input = torch.ones(1, 1, 4, 4) * 2.0 # output = m(input) # print(output) CurrentPath = os.path.split(os.path.realpath(__file__))[0]+"/" # resnet50 = models.resnet50(pretrained=True) # torch.save(resnet50, CurrentPath+'params.pth') resnet50 = torch.load(CurrentPath+'params.pth') resnet50.eval() # print("===========================") # print("===========================") # print("===========================") # print(resnet50) # print("===========================") # print("===========================") # print("===========================") ResNet50 = { "conv1": "Conv2d", "bn1": "BatchNorm2d", "relu": "ReLU", "maxpool": "MaxPool2d", "layer1": { "_modules": { "0": { "conv1": "Conv2d", "bn1": "BatchNorm2d", "conv2": "Conv2d", "bn2": "BatchNorm2d", "conv3": "Conv2d", "bn3": "BatchNorm2d", "relu": "ReLU", "downsample": { "_modules": { "0": "Conv2d", "1": "BatchNorm2d", } } }, "1": { "conv1": "Conv2d", "bn1": "BatchNorm2d", "conv2": "Conv2d", "bn2": "BatchNorm2d", "conv3": "Conv2d", "bn3": "BatchNorm2d", "relu": "ReLU", }, "2": { "conv1": "Conv2d", "bn1": "BatchNorm2d", "conv2": "Conv2d", "bn2": "BatchNorm2d", "conv3": "Conv2d", "bn3": "BatchNorm2d", "relu": "ReLU", }, } }, "layer2": { "_modules": { "0": { "conv1": "Conv2d", "bn1": "BatchNorm2d", "conv2": "Conv2d", "bn2": "BatchNorm2d", "conv3": "Conv2d", "bn3": "BatchNorm2d", "relu": "ReLU", "downsample": { "_modules": { "0": "Conv2d", "1": "BatchNorm2d", } } }, "1": { "conv1": "Conv2d", "bn1": "BatchNorm2d", "conv2": "Conv2d", "bn2": "BatchNorm2d", "conv3": "Conv2d", "bn3": "BatchNorm2d", "relu": "ReLU", }, "2": { "conv1": "Conv2d", "bn1": "BatchNorm2d", "conv2": "Conv2d", "bn2": "BatchNorm2d", "conv3": "Conv2d", "bn3": "BatchNorm2d", "relu": "ReLU", }, "3": { "conv1": "Conv2d", "bn1": "BatchNorm2d", "conv2": "Conv2d", "bn2": "BatchNorm2d", "conv3": "Conv2d", "bn3": "BatchNorm2d", "relu": "ReLU", } } }, "layer3": { "_modules": { "0": { "conv1": "Conv2d", "bn1": "BatchNorm2d", "conv2": "Conv2d", "bn2": "BatchNorm2d", "conv3": "Conv2d", "bn3": "BatchNorm2d", "relu": "ReLU", "downsample": { "_modules": { "0": "Conv2d", "1": "BatchNorm2d", } } }, "1": { "conv1": "Conv2d", "bn1": "BatchNorm2d", "conv2": "Conv2d", "bn2": "BatchNorm2d", "conv3": "Conv2d", "bn3": "BatchNorm2d", "relu": "ReLU", }, "2": { "conv1": "Conv2d", "bn1": "BatchNorm2d", "conv2": "Conv2d", "bn2": "BatchNorm2d", "conv3": "Conv2d", "bn3": "BatchNorm2d", "relu": "ReLU", }, "3": { "conv1": "Conv2d", "bn1": "BatchNorm2d", "conv2": "Conv2d", "bn2": "BatchNorm2d", "conv3": "Conv2d", "bn3": "BatchNorm2d", "relu": "ReLU", }, "4": { "conv1": "Conv2d", "bn1": "BatchNorm2d", "conv2": "Conv2d", "bn2": "BatchNorm2d", "conv3": "Conv2d", "bn3": "BatchNorm2d", "relu": "ReLU", }, "5": { "conv1": "Conv2d", "bn1": "BatchNorm2d", "conv2": "Conv2d", "bn2": "BatchNorm2d", "conv3": "Conv2d", "bn3": "BatchNorm2d", "relu": "ReLU", } } }, "layer4": { "_modules": { "0": { "conv1": "Conv2d", "bn1": "BatchNorm2d", "conv2": "Conv2d", "bn2": "BatchNorm2d", "conv3": "Conv2d", "bn3": "BatchNorm2d", "relu": "ReLU", "downsample": { "_modules": { "0": "Conv2d", "1": "BatchNorm2d", } } }, "1": { "conv1": "Conv2d", "bn1": "BatchNorm2d", "conv2": "Conv2d", "bn2": "BatchNorm2d", "conv3": "Conv2d", "bn3": "BatchNorm2d", "relu": "ReLU", }, "2": { "conv1": "Conv2d", "bn1": "BatchNorm2d", "conv2": "Conv2d", "bn2": "BatchNorm2d", "conv3": "Conv2d", "bn3": "BatchNorm2d", "relu": "ReLU", } } }, "avgpool": "AdaptiveAvgPool2d", "fc": "Linear" } weightfile = open(CurrentPath+'ResNet50Weight.cc', 'w') binaryfile = open(CurrentPath+'ResNet50Weight.bin', 'wb') currentbyte = 0 strg = '' def genData(name, data, currentbyte, binaryfile, strg): strg = strg + "int "+name+"[] = { " array = data.cpu().detach().numpy().reshape(-1) strg += str(currentbyte) + "," for a in array: bs = struct.pack("f", a) binaryfile.write(bs) currentbyte = currentbyte+4 strg += str(currentbyte-1) strg = strg + " };\n" strg = strg + "int "+name+"_shape[] = { " array = data.cpu().detach().numpy().shape for a in array: strg += str(a) + ", " strg = strg + " };\n" return (currentbyte,binaryfile,strg) def hook_fn(m, i, o): print(m) print("------------Input Grad------------") for grad in i: try: print(grad.shape) except AttributeError: print ("None found for Gradient") print("------------Output Grad------------") for grad in o: try: print(grad.shape) except AttributeError: print ("None found for Gradient") print("\n") def hook_print(name, m, i, o): global currentbyte global binaryfile global strg currentbyte, binaryfile, strg = genData( name+"_input", i[0], currentbyte, binaryfile, strg) currentbyte, binaryfile, strg = genData( name+"_output", o[0], currentbyte, binaryfile, strg) def printDick(d, head, obj): global currentbyte global binaryfile strg = "" for item in d: if type(d[item]).__name__ == 'dict': objsub = getattr(obj, item, '') if objsub == '': objsub = obj[item] strg = strg + printDick(d[item], head+"_"+item, objsub) else: objsub = getattr(obj, item, '') if objsub == '': objsub = obj[item] if d[item] == "Conv2d": currentbyte, binaryfile, strg = genData( head+"_"+item+"_weight", objsub.weight, currentbyte, binaryfile, strg) if d[item] == "BatchNorm2d": currentbyte, binaryfile, strg = genData( head+"_"+item+"_running_mean", objsub.running_mean, currentbyte, binaryfile, strg) currentbyte, binaryfile, strg = genData( head+"_"+item+"_running_var", objsub.running_var, currentbyte, binaryfile, strg) currentbyte, binaryfile, strg = genData( head+"_"+item+"_weight", objsub.weight, currentbyte, binaryfile, strg) currentbyte, binaryfile, strg = genData( head+"_"+item+"_bias", objsub.bias, currentbyte, binaryfile, strg) if d[item] == "Linear": currentbyte, binaryfile, strg = genData( head+"_"+item+"_weight", objsub.weight, currentbyte, binaryfile, strg) currentbyte, binaryfile, strg = genData( head+"_"+item+"_bias", objsub.bias, currentbyte, binaryfile, strg) return strg strg = printDick(ResNet50, "RN50", resnet50) normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) val_loader = torch.utils.data.DataLoader( datasets.ImageFolder(CurrentPath+'ImageNet/', transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), normalize, ])), batch_size=1, shuffle=False, num_workers=1, pin_memory=True) for batch_idx, (data, target) in enumerate(val_loader): currentbyte, binaryfile,strg = genData("input_"+str(batch_idx), data, currentbyte, binaryfile, strg) out = resnet50(data) currentbyte, binaryfile, strg = genData( "output_"+str(batch_idx), out, currentbyte, binaryfile, strg) for batch_idx, (data, target) in enumerate(val_loader): currentbyte,binaryfile,strg = genData("verify_input", data, currentbyte, binaryfile, strg) x = resnet50.conv1(data) currentbyte,binaryfile,strg = genData("verify_conv1", x, currentbyte, binaryfile, strg) x = resnet50.bn1(x) currentbyte,binaryfile,strg = genData("verify_bn1", x, currentbyte, binaryfile, strg) x = resnet50.relu(x) currentbyte,binaryfile,strg = genData("verify_relu", x, currentbyte, binaryfile, strg) x = resnet50.maxpool(x) currentbyte,binaryfile,strg = genData("verify_maxpool", x, currentbyte, binaryfile, strg) x = resnet50.layer1(x) currentbyte,binaryfile,strg = genData("verify_layer1", x, currentbyte, binaryfile, strg) x = resnet50.layer2(x) currentbyte,binaryfile,strg = genData("verify_layer2", x, currentbyte, binaryfile, strg) x = resnet50.layer3(x) currentbyte,binaryfile,strg = genData("verify_layer3", x, currentbyte, binaryfile, strg) x = resnet50.layer4(x) currentbyte,binaryfile,strg = genData("verify_layer4", x, currentbyte, binaryfile, strg) x = resnet50.avgpool(x) currentbyte,binaryfile,strg = genData("verify_avgpool", x, currentbyte, binaryfile, strg) x = torch.flatten(x, 1) x = resnet50.fc(x) currentbyte,binaryfile,strg = genData("verify_fc", x, currentbyte, binaryfile, strg) break resnet50.layer1._modules['0'].bn1.register_forward_hook(lambda m, i, o: hook_print("layer1_block0_bn1", m, i, o)) resnet50.layer1._modules['0'].bn2.register_forward_hook(lambda m, i, o: hook_print("layer1_block0_bn2", m, i, o)) resnet50.layer1._modules['0'].bn3.register_forward_hook(lambda m, i, o: hook_print("layer1_block0_bn3", m, i, o)) resnet50.layer1._modules['0'].conv1.register_forward_hook(lambda m, i, o: hook_print("layer1_block0_conv1", m, i, o)) resnet50.layer1._modules['0'].conv2.register_forward_hook(lambda m, i, o: hook_print("layer1_block0_conv2", m, i, o)) resnet50.layer1._modules['0'].conv3.register_forward_hook(lambda m, i, o: hook_print("layer1_block0_conv3", m, i, o)) resnet50.layer1._modules['0'].downsample._modules['0'].register_forward_hook(lambda m, i, o: hook_print("layer1_block0_downsample_conv", m, i, o)) resnet50.layer1._modules['0'].downsample._modules['1'].register_forward_hook(lambda m, i, o: hook_print("layer1_block0_downsample_bn", m, i, o)) resnet50.layer1._modules['1'].bn1.register_forward_hook(lambda m, i, o: hook_print("layer1_block1_bn1", m, i, o)) resnet50.layer1._modules['1'].bn2.register_forward_hook(lambda m, i, o: hook_print("layer1_block1_bn2", m, i, o)) resnet50.layer1._modules['1'].bn3.register_forward_hook(lambda m, i, o: hook_print("layer1_block1_bn3", m, i, o)) resnet50.layer1._modules['1'].conv1.register_forward_hook(lambda m, i, o: hook_print("layer1_block1_conv1", m, i, o)) resnet50.layer1._modules['1'].conv2.register_forward_hook(lambda m, i, o: hook_print("layer1_block1_conv2", m, i, o)) resnet50.layer1._modules['1'].conv3.register_forward_hook(lambda m, i, o: hook_print("layer1_block1_conv3", m, i, o)) resnet50.layer1._modules['2'].bn1.register_forward_hook(lambda m, i, o: hook_print("layer1_block2_bn1", m, i, o)) resnet50.layer1._modules['2'].bn2.register_forward_hook(lambda m, i, o: hook_print("layer1_block2_bn2", m, i, o)) resnet50.layer1._modules['2'].bn3.register_forward_hook(lambda m, i, o: hook_print("layer1_block2_bn3", m, i, o)) resnet50.layer1._modules['2'].conv1.register_forward_hook(lambda m, i, o: hook_print("layer1_block2_conv1", m, i, o)) resnet50.layer1._modules['2'].conv2.register_forward_hook(lambda m, i, o: hook_print("layer1_block2_conv2", m, i, o)) resnet50.layer1._modules['2'].conv3.register_forward_hook(lambda m, i, o: hook_print("layer1_block2_conv3", m, i, o)) for batch_idx, (data, target) in enumerate(val_loader): out = resnet50(data) break weightfile.write(strg) binaryfile.close() weightfile.close() # print(strg) print("===========================") print("===========================") print("===========================")