This commit is contained in:
Colin 2025-06-09 16:00:13 +08:00
parent 50fb9bf6dc
commit fa15680aa6
1 changed files with 36 additions and 0 deletions

View File

@ -232,6 +232,39 @@ class SimpleBNN(nn.Module):
pass
class SimpleLNN(nn.Module):
def __init__(self):
super(SimpleLNN, self).__init__()
# group, groupBits, groupRepeat
self.lutg1 = LutGroup(1, 10, 4)
self.lutg2 = LutGroup(1, 4, 10)
def forward(self, x, t):
batch = x.shape[0]
x = torch.zeros_like(t).unsqueeze(-1).repeat(1, 10)
x[torch.arange(0, batch), t] = 1
x = self.lutg1(x)
x = self.lutg2(x)
return x
def printWeight(self):
print("self.lutg1")
print(self.lutg1.weight[[1, 2, 4, 8, 16, 32, 64, 128, 256, 512], :].detach().cpu().numpy())
print("=============================")
print("=============================")
print("self.lutg1.grad")
print(self.lutg1.weight.grad[[1, 2, 4, 8, 16, 32, 64, 128, 256, 512], :].detach().cpu().numpy())
print("=============================")
print("=============================")
# print("self.lutg2")
# print(self.lutg2.weight.detach().cpu().numpy())
# print("=============================")
# print("=============================")
torch.autograd.set_detect_anomaly(True)
# model = SimpleCNN().to(device)
model = SimpleBNN().to(device)
@ -290,6 +323,8 @@ def test(epoch):
f"({accuracy:.0f}%)\n"
)
model.printWeight()
def profiler():
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
@ -305,6 +340,7 @@ def profiler():
prof.export_chrome_trace("local.json")
assert False
for epoch in range(1, 300):
train(epoch)
test(epoch)