Add weight dump.

This commit is contained in:
Colin 2025-06-27 13:49:15 +08:00
parent 9acd89c98e
commit ab65f98e39
1 changed files with 8 additions and 0 deletions

View File

@ -1,7 +1,10 @@
import os
import sys
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
sys.path.append("..")
from tools import show
import torch
import torch.nn as nn
@ -268,6 +271,11 @@ class SimpleBNN(nn.Module):
return x
def printWeight(self):
show.DumpTensorToImage(self.lnn1.lut.weight.data.permute(1, 0).unsqueeze(-2), "./temp/lnn1.png")
show.DumpTensorToImage(self.lnn2.lut.weight.data.permute(1, 0).unsqueeze(-2), "./temp/lnn2.png")
show.DumpTensorToImage(self.lnn3.lut.weight.data.permute(1, 0).unsqueeze(-2), "./temp/lnn3.png")
show.DumpTensorToImage(self.lnn4.lut.weight.data.permute(1, 0).unsqueeze(-2), "./temp/lnn4.png")
show.DumpTensorToImage(self.lnn5.lut.weight.data.permute(1, 0).unsqueeze(-2), "./temp/lnn5.png")
pass