diff --git a/binary/mnist.py b/binary/mnist.py index 3dd4b51..1e2b42d 100644 --- a/binary/mnist.py +++ b/binary/mnist.py @@ -1,7 +1,10 @@ import os +import sys os.environ["CUDA_LAUNCH_BLOCKING"] = "1" +sys.path.append("..") +from tools import show import torch import torch.nn as nn @@ -268,6 +271,11 @@ class SimpleBNN(nn.Module): return x def printWeight(self): + show.DumpTensorToImage(self.lnn1.lut.weight.data.permute(1, 0).unsqueeze(-2), "./temp/lnn1.png") + show.DumpTensorToImage(self.lnn2.lut.weight.data.permute(1, 0).unsqueeze(-2), "./temp/lnn2.png") + show.DumpTensorToImage(self.lnn3.lut.weight.data.permute(1, 0).unsqueeze(-2), "./temp/lnn3.png") + show.DumpTensorToImage(self.lnn4.lut.weight.data.permute(1, 0).unsqueeze(-2), "./temp/lnn4.png") + show.DumpTensorToImage(self.lnn5.lut.weight.data.permute(1, 0).unsqueeze(-2), "./temp/lnn5.png") pass