diff --git a/binary/mnist.py b/binary/mnist.py index 302c99d..4d7e1f5 100644 --- a/binary/mnist.py +++ b/binary/mnist.py @@ -232,6 +232,39 @@ class SimpleBNN(nn.Module): pass +class SimpleLNN(nn.Module): + def __init__(self): + super(SimpleLNN, self).__init__() + # group, groupBits, groupRepeat + self.lutg1 = LutGroup(1, 10, 4) + self.lutg2 = LutGroup(1, 4, 10) + + def forward(self, x, t): + batch = x.shape[0] + + x = torch.zeros_like(t).unsqueeze(-1).repeat(1, 10) + x[torch.arange(0, batch), t] = 1 + + x = self.lutg1(x) + x = self.lutg2(x) + + return x + + def printWeight(self): + print("self.lutg1") + print(self.lutg1.weight[[1, 2, 4, 8, 16, 32, 64, 128, 256, 512], :].detach().cpu().numpy()) + print("=============================") + print("=============================") + print("self.lutg1.grad") + print(self.lutg1.weight.grad[[1, 2, 4, 8, 16, 32, 64, 128, 256, 512], :].detach().cpu().numpy()) + print("=============================") + print("=============================") + # print("self.lutg2") + # print(self.lutg2.weight.detach().cpu().numpy()) + # print("=============================") + # print("=============================") + + torch.autograd.set_detect_anomaly(True) # model = SimpleCNN().to(device) model = SimpleBNN().to(device) @@ -290,6 +323,8 @@ def test(epoch): f"({accuracy:.0f}%)\n" ) model.printWeight() + + def profiler(): for batch_idx, (data, target) in enumerate(train_loader): data, target = data.to(device), target.to(device) @@ -305,6 +340,7 @@ def profiler(): prof.export_chrome_trace("local.json") assert False + for epoch in range(1, 300): train(epoch) test(epoch)