Add weight dump.
This commit is contained in:
parent
9acd89c98e
commit
ab65f98e39
|
@ -1,7 +1,10 @@
|
||||||
import os
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
|
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
|
||||||
|
|
||||||
|
sys.path.append("..")
|
||||||
|
from tools import show
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
@ -268,6 +271,11 @@ class SimpleBNN(nn.Module):
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def printWeight(self):
|
def printWeight(self):
|
||||||
|
show.DumpTensorToImage(self.lnn1.lut.weight.data.permute(1, 0).unsqueeze(-2), "./temp/lnn1.png")
|
||||||
|
show.DumpTensorToImage(self.lnn2.lut.weight.data.permute(1, 0).unsqueeze(-2), "./temp/lnn2.png")
|
||||||
|
show.DumpTensorToImage(self.lnn3.lut.weight.data.permute(1, 0).unsqueeze(-2), "./temp/lnn3.png")
|
||||||
|
show.DumpTensorToImage(self.lnn4.lut.weight.data.permute(1, 0).unsqueeze(-2), "./temp/lnn4.png")
|
||||||
|
show.DumpTensorToImage(self.lnn5.lut.weight.data.permute(1, 0).unsqueeze(-2), "./temp/lnn5.png")
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue