Add LNN.
This commit is contained in:
parent
50fb9bf6dc
commit
fa15680aa6
|
@ -232,6 +232,39 @@ class SimpleBNN(nn.Module):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class SimpleLNN(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(SimpleLNN, self).__init__()
|
||||||
|
# group, groupBits, groupRepeat
|
||||||
|
self.lutg1 = LutGroup(1, 10, 4)
|
||||||
|
self.lutg2 = LutGroup(1, 4, 10)
|
||||||
|
|
||||||
|
def forward(self, x, t):
|
||||||
|
batch = x.shape[0]
|
||||||
|
|
||||||
|
x = torch.zeros_like(t).unsqueeze(-1).repeat(1, 10)
|
||||||
|
x[torch.arange(0, batch), t] = 1
|
||||||
|
|
||||||
|
x = self.lutg1(x)
|
||||||
|
x = self.lutg2(x)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
def printWeight(self):
|
||||||
|
print("self.lutg1")
|
||||||
|
print(self.lutg1.weight[[1, 2, 4, 8, 16, 32, 64, 128, 256, 512], :].detach().cpu().numpy())
|
||||||
|
print("=============================")
|
||||||
|
print("=============================")
|
||||||
|
print("self.lutg1.grad")
|
||||||
|
print(self.lutg1.weight.grad[[1, 2, 4, 8, 16, 32, 64, 128, 256, 512], :].detach().cpu().numpy())
|
||||||
|
print("=============================")
|
||||||
|
print("=============================")
|
||||||
|
# print("self.lutg2")
|
||||||
|
# print(self.lutg2.weight.detach().cpu().numpy())
|
||||||
|
# print("=============================")
|
||||||
|
# print("=============================")
|
||||||
|
|
||||||
|
|
||||||
torch.autograd.set_detect_anomaly(True)
|
torch.autograd.set_detect_anomaly(True)
|
||||||
# model = SimpleCNN().to(device)
|
# model = SimpleCNN().to(device)
|
||||||
model = SimpleBNN().to(device)
|
model = SimpleBNN().to(device)
|
||||||
|
@ -290,6 +323,8 @@ def test(epoch):
|
||||||
f"({accuracy:.0f}%)\n"
|
f"({accuracy:.0f}%)\n"
|
||||||
)
|
)
|
||||||
model.printWeight()
|
model.printWeight()
|
||||||
|
|
||||||
|
|
||||||
def profiler():
|
def profiler():
|
||||||
for batch_idx, (data, target) in enumerate(train_loader):
|
for batch_idx, (data, target) in enumerate(train_loader):
|
||||||
data, target = data.to(device), target.to(device)
|
data, target = data.to(device), target.to(device)
|
||||||
|
@ -305,6 +340,7 @@ def profiler():
|
||||||
prof.export_chrome_trace("local.json")
|
prof.export_chrome_trace("local.json")
|
||||||
assert False
|
assert False
|
||||||
|
|
||||||
|
|
||||||
for epoch in range(1, 300):
|
for epoch in range(1, 300):
|
||||||
train(epoch)
|
train(epoch)
|
||||||
test(epoch)
|
test(epoch)
|
||||||
|
|
Loading…
Reference in New Issue