From ab65f98e397bf59b0d8d41eb824c2d32a923b6a6 Mon Sep 17 00:00:00 2001 From: Colin Date: Fri, 27 Jun 2025 13:49:15 +0800 Subject: [PATCH] Add weight dump. --- binary/mnist.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/binary/mnist.py b/binary/mnist.py index 3dd4b51..1e2b42d 100644 --- a/binary/mnist.py +++ b/binary/mnist.py @@ -1,7 +1,10 @@ import os +import sys os.environ["CUDA_LAUNCH_BLOCKING"] = "1" +sys.path.append("..") +from tools import show import torch import torch.nn as nn @@ -268,6 +271,11 @@ class SimpleBNN(nn.Module): return x def printWeight(self): + show.DumpTensorToImage(self.lnn1.lut.weight.data.permute(1, 0).unsqueeze(-2), "./temp/lnn1.png") + show.DumpTensorToImage(self.lnn2.lut.weight.data.permute(1, 0).unsqueeze(-2), "./temp/lnn2.png") + show.DumpTensorToImage(self.lnn3.lut.weight.data.permute(1, 0).unsqueeze(-2), "./temp/lnn3.png") + show.DumpTensorToImage(self.lnn4.lut.weight.data.permute(1, 0).unsqueeze(-2), "./temp/lnn4.png") + show.DumpTensorToImage(self.lnn5.lut.weight.data.permute(1, 0).unsqueeze(-2), "./temp/lnn5.png") pass