Compare commits

...

5 Commits

Author SHA1 Message Date
Colin a676025d20 Update speed test. 2025-06-09 17:54:05 +08:00
Colin fa15680aa6 Add LNN. 2025-06-09 16:00:13 +08:00
Colin 50fb9bf6dc Add profile function. 2025-06-09 15:57:53 +08:00
Colin 878c690ac4 Refine LUTCNN, keep accuracy to ~93 2025-06-09 15:56:03 +08:00
Colin 5d03634595 Revert mnist to new backwoard of LUT 855296be55. 2025-06-09 15:04:53 +08:00
2 changed files with 13 additions and 34 deletions

View File

@ -46,8 +46,8 @@ class Lut(torch.autograd.Function):
def forward(ctx, input, weight, index):
ind = ((input > 0).long() * index).sum(dim=-1)
output = torch.gather(weight, 0, ind)
output = (output > 0).float()
output = (output - 0.5) * 2.0
# output = (output > 0).float()
# output = (output - 0.5) * 2.0
ctx.save_for_backward(input, weight, ind, output)
return output
@ -62,31 +62,8 @@ class Lut(torch.autograd.Function):
grad_weight.scatter_add_(0, ind, grad_output)
if ctx.needs_input_grad[0]:
grad_input = grad_output * torch.gather(weight, 0, ind)
# grad_input = grad_output
grad_input = grad_input.unsqueeze(-1).repeat(1, 1, bits)
output = output.unsqueeze(-1).repeat(1, 1, bits)
in_sign = ((input > 0).float() - 0.5) * 2.0
grad_input = grad_input * in_sign
grad_input = grad_input * (((torch.rand_like(grad_input) - 0.5) / 100) + 1.0)
# grad_input = grad_output
# grad_input = grad_input.unsqueeze(-1).repeat(1, 1, bits)
# output = output.unsqueeze(-1).repeat(1, 1, bits)
# in_sign = ((input > 0).float() - 0.5) * 2.0
# out_sign = ((output > 0).float() - 0.5) * 2.0
# grad_sign = ((grad_input > 0).float() - 0.5) * 2.0
# grad_input = grad_input * in_sign * (out_sign * grad_sign)
# grad_input = grad_input * (((torch.rand_like(grad_input) - 0.5) / 100) + 1.0)
# 需要一个动态的调整系数
# 能稳定的收敛
# print(in_sign[0].detach().cpu().numpy())
# print(out_sign[0].detach().cpu().numpy())
# print(grad_sign[0].detach().cpu().numpy())
# print(grad_input[0].detach().cpu().numpy())
return grad_input, grad_weight, None, None
@ -182,11 +159,11 @@ class SimpleBNN(nn.Module):
self.b = nn.Parameter(torch.zeros(3, 784))
# channel_repeat, input_shape, kernel_size, stride, dilation, fc
self.lnn1 = LutCnn(8, (BS, 1, 28, 28), 2, 2, 1, False)
self.lnn2 = LutCnn(1, (BS, 8, 14, 14), 2, 2, 1, False)
self.lnn3 = LutCnn(1, (BS, 8, 7, 7), 3, 1, 1, False)
self.lnn4 = LutCnn(1, (BS, 8, 5, 5), 3, 1, 1, False)
self.lnn5 = LutCnn(10, (BS, 8, 3, 3), 3, 1, 1)
self.lnn1 = LutCnn(80, (BS, 1, 28, 28), 2, 2, 1, False)
self.lnn2 = LutCnn(1, (BS, 80, 14, 14), 2, 2, 1, False)
self.lnn3 = LutCnn(1, (BS, 80, 7, 7), 3, 1, 1, False)
self.lnn4 = LutCnn(1, (BS, 80, 5, 5), 3, 1, 1, False)
self.lnn5 = LutCnn(10, (BS, 80, 3, 3), 3, 1, 1)
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
@ -290,8 +267,8 @@ class SimpleLNN(nn.Module):
torch.autograd.set_detect_anomaly(True)
# model = SimpleCNN().to(device)
# model = SimpleBNN().to(device)
model = SimpleLNN().to(device)
model = SimpleBNN().to(device)
# model = SimpleLNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
@ -364,8 +341,6 @@ def profiler():
assert False
# profiler()
for epoch in range(1, 300):
train(epoch)
test(epoch)

View File

@ -6,3 +6,7 @@
v100_32G_PCIE batch=1 bf16=false 5.5s/it
v100_32G_PCIE batch=1 bf16=true 8.7s/it
v100_32G_PCIE batch=4 bf16=false 21.53s/it
A100_40G_PCIE batch=4 bf16=false 3.58s/it