|
|
|
@ -46,8 +46,8 @@ class Lut(torch.autograd.Function):
|
|
|
|
|
def forward(ctx, input, weight, index):
|
|
|
|
|
ind = ((input > 0).long() * index).sum(dim=-1)
|
|
|
|
|
output = torch.gather(weight, 0, ind)
|
|
|
|
|
output = (output > 0).float()
|
|
|
|
|
output = (output - 0.5) * 2.0
|
|
|
|
|
# output = (output > 0).float()
|
|
|
|
|
# output = (output - 0.5) * 2.0
|
|
|
|
|
ctx.save_for_backward(input, weight, ind, output)
|
|
|
|
|
return output
|
|
|
|
|
|
|
|
|
@ -62,31 +62,8 @@ class Lut(torch.autograd.Function):
|
|
|
|
|
grad_weight.scatter_add_(0, ind, grad_output)
|
|
|
|
|
|
|
|
|
|
if ctx.needs_input_grad[0]:
|
|
|
|
|
|
|
|
|
|
grad_input = grad_output * torch.gather(weight, 0, ind)
|
|
|
|
|
# grad_input = grad_output
|
|
|
|
|
grad_input = grad_input.unsqueeze(-1).repeat(1, 1, bits)
|
|
|
|
|
output = output.unsqueeze(-1).repeat(1, 1, bits)
|
|
|
|
|
in_sign = ((input > 0).float() - 0.5) * 2.0
|
|
|
|
|
grad_input = grad_input * in_sign
|
|
|
|
|
grad_input = grad_input * (((torch.rand_like(grad_input) - 0.5) / 100) + 1.0)
|
|
|
|
|
|
|
|
|
|
# grad_input = grad_output
|
|
|
|
|
# grad_input = grad_input.unsqueeze(-1).repeat(1, 1, bits)
|
|
|
|
|
# output = output.unsqueeze(-1).repeat(1, 1, bits)
|
|
|
|
|
# in_sign = ((input > 0).float() - 0.5) * 2.0
|
|
|
|
|
# out_sign = ((output > 0).float() - 0.5) * 2.0
|
|
|
|
|
# grad_sign = ((grad_input > 0).float() - 0.5) * 2.0
|
|
|
|
|
# grad_input = grad_input * in_sign * (out_sign * grad_sign)
|
|
|
|
|
# grad_input = grad_input * (((torch.rand_like(grad_input) - 0.5) / 100) + 1.0)
|
|
|
|
|
|
|
|
|
|
# Needs a dynamic adjustment coefficient
|
|
|
|
|
# so that it can converge stably
|
|
|
|
|
|
|
|
|
|
# print(in_sign[0].detach().cpu().numpy())
|
|
|
|
|
# print(out_sign[0].detach().cpu().numpy())
|
|
|
|
|
# print(grad_sign[0].detach().cpu().numpy())
|
|
|
|
|
# print(grad_input[0].detach().cpu().numpy())
|
|
|
|
|
|
|
|
|
|
return grad_input, grad_weight, None, None
|
|
|
|
|
|
|
|
|
@ -182,11 +159,11 @@ class SimpleBNN(nn.Module):
|
|
|
|
|
self.b = nn.Parameter(torch.zeros(3, 784))
|
|
|
|
|
|
|
|
|
|
# channel_repeat, input_shape, kernel_size, stride, dilation, fc
|
|
|
|
|
self.lnn1 = LutCnn(8, (BS, 1, 28, 28), 2, 2, 1, False)
|
|
|
|
|
self.lnn2 = LutCnn(1, (BS, 8, 14, 14), 2, 2, 1, False)
|
|
|
|
|
self.lnn3 = LutCnn(1, (BS, 8, 7, 7), 3, 1, 1, False)
|
|
|
|
|
self.lnn4 = LutCnn(1, (BS, 8, 5, 5), 3, 1, 1, False)
|
|
|
|
|
self.lnn5 = LutCnn(10, (BS, 8, 3, 3), 3, 1, 1)
|
|
|
|
|
self.lnn1 = LutCnn(80, (BS, 1, 28, 28), 2, 2, 1, False)
|
|
|
|
|
self.lnn2 = LutCnn(1, (BS, 80, 14, 14), 2, 2, 1, False)
|
|
|
|
|
self.lnn3 = LutCnn(1, (BS, 80, 7, 7), 3, 1, 1, False)
|
|
|
|
|
self.lnn4 = LutCnn(1, (BS, 80, 5, 5), 3, 1, 1, False)
|
|
|
|
|
self.lnn5 = LutCnn(10, (BS, 80, 3, 3), 3, 1, 1)
|
|
|
|
|
|
|
|
|
|
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
|
|
|
|
|
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
|
|
|
|
@ -290,8 +267,8 @@ class SimpleLNN(nn.Module):
|
|
|
|
|
|
|
|
|
|
torch.autograd.set_detect_anomaly(True)
|
|
|
|
|
# model = SimpleCNN().to(device)
|
|
|
|
|
# model = SimpleBNN().to(device)
|
|
|
|
|
model = SimpleLNN().to(device)
|
|
|
|
|
model = SimpleBNN().to(device)
|
|
|
|
|
# model = SimpleLNN().to(device)
|
|
|
|
|
criterion = nn.CrossEntropyLoss()
|
|
|
|
|
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
|
|
|
|
|
|
|
|
|
@ -364,8 +341,6 @@ def profiler():
|
|
|
|
|
assert False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# profiler()
|
|
|
|
|
|
|
|
|
|
for epoch in range(1, 300):
|
|
|
|
|
train(epoch)
|
|
|
|
|
test(epoch)
|
|
|
|
|